
Typing

gpjax.typing

OldKeyArray = UInt32[JAXArray, '2'] module-attribute
JAXKeyArray = Key[JAXArray, ''] module-attribute
KeyArray = Union[OldKeyArray, JAXKeyArray] module-attribute
Array = Union[JAXArray, NumpyArray] module-attribute
ScalarBool = Union[bool, Bool[Array, '']] module-attribute
ScalarInt = Union[int, Int[Array, '']] module-attribute
ScalarFloat = Union[float, Float[Array, '']] module-attribute
VecNOrMatNM = Union[Float[Array, ' N'], Float[Array, 'N M']] module-attribute
FunctionalSample = Callable[[Float[Array, 'N D']], Float[Array, 'N B']] module-attribute

Type alias for functions representing B samples from a model. Such a function can be evaluated on any set of N inputs (each of dimension D) and returns an N × B array holding the evaluations of each (potentially approximate) sample draw across these inputs.
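To make the shapes concrete, here is a minimal sketch of a function that fits the FunctionalSample signature. It assumes approximate prior samples from a GP with an RBF kernel, built from random Fourier features; the helper `make_rff_prior_samples` and its parameters are hypothetical and not part of gpjax. To stay dependency-light it uses NumPy arrays (which the `Array` alias admits) and a plain `Callable` stand-in instead of importing `gpjax.typing`.

```python
from typing import Callable

import numpy as np

# Hypothetical stand-in for gpjax.typing.FunctionalSample, which is
# Callable[[Float[Array, "N D"]], Float[Array, "N B"]].
FunctionalSampleLike = Callable[[np.ndarray], np.ndarray]


def make_rff_prior_samples(
    num_samples: int,         # B: number of sample functions
    input_dim: int,           # D: input dimension
    num_features: int = 500,  # M: number of random Fourier features
    lengthscale: float = 1.0,
    variance: float = 1.0,
    seed: int = 0,
) -> FunctionalSampleLike:
    """Hypothetical helper: B approximate RBF-kernel GP prior draws via random Fourier features."""
    rng = np.random.default_rng(seed)
    # Spectral frequencies of the RBF kernel: omega ~ N(0, I / lengthscale^2).
    omega = rng.normal(size=(input_dim, num_features)) / lengthscale
    phase = rng.uniform(0.0, 2.0 * np.pi, size=num_features)
    # One weight column per sample; drawn once so each sample is a fixed function.
    weights = rng.normal(size=(num_features, num_samples))

    def sample_fn(x: np.ndarray) -> np.ndarray:
        # x has shape (N, D); the feature matrix has shape (N, M).
        features = np.sqrt(2.0 * variance / num_features) * np.cos(x @ omega + phase)
        # Shape (N, B): every sample evaluated at every input.
        return features @ weights

    return sample_fn


x = np.linspace(0.0, 1.0, 20).reshape(10, 2)  # N = 10 inputs, D = 2
f = make_rff_prior_samples(num_samples=3, input_dim=2)
values = f(x)
print(values.shape)  # (10, 3)
```

Because the frequencies, phases, and weights are drawn once, each of the B columns is a fixed function: evaluating on a subset of inputs returns the matching rows of a larger evaluation, which is what allows a sample to be queried at arbitrary new inputs.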

__all__ = ['KeyArray', 'ScalarBool', 'ScalarInt', 'ScalarFloat', 'FunctionalSample'] module-attribute
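The exported `KeyArray` alias accepts both of JAX's PRNG key representations: the legacy raw key (`OldKeyArray`, a `uint32` array of shape `(2,)`) and the newer typed key (`JAXKeyArray`, a scalar with a key dtype). The sketch below shows how each is created and told apart with standard `jax.random` calls; the `describe_key` helper is hypothetical, and the printed shapes assume JAX's default configuration (custom PRNG mode not enabled) and the default threefry implementation.

```python
import jax
import jax.numpy as jnp

# Legacy raw key: uint32 array of shape (2,), matching OldKeyArray.
old_key = jax.random.PRNGKey(0)
# New-style typed key: scalar with a key dtype, matching JAXKeyArray.
new_key = jax.random.key(0)


def describe_key(key) -> str:
    """Hypothetical helper: report which KeyArray variant a key is."""
    if jnp.issubdtype(key.dtype, jax.dtypes.prng_key):
        return f"typed key, shape {key.shape}"
    return f"raw uint32 key, shape {key.shape}"


print(describe_key(old_key))  # raw uint32 key, shape (2,)
print(describe_key(new_key))  # typed key, shape ()

# Converting between the two representations.
raw_from_new = jax.random.key_data(new_key)         # uint32 array, shape (2,)
typed_from_old = jax.random.wrap_key_data(old_key)  # typed key, shape ()

# Both variants are accepted by jax.random sampling functions.
print(jax.random.normal(old_key, (3,)).shape)  # (3,)
print(jax.random.normal(new_key, (3,)).shape)  # (3,)
```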