Mean Functions
gpjax.mean_functions
SumMeanFunction = partial(CombinationMeanFunction, operator=partial(jnp.sum, axis=0))
module-attribute
ProductMeanFunction = partial(CombinationMeanFunction, operator=partial(jnp.prod, axis=0))
module-attribute
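Both aliases pre-bind the `operator` argument of `CombinationMeanFunction` to a reduction over axis 0. The sketch below assumes that the component outputs, each of shape `(N, O)`, are stacked along a new leading axis before that reduction, and shows what a sum and a product then produce. NumPy's `sum`, `prod` and `stack` stand in for their `jax.numpy` counterparts, and the component outputs are made-up values.

```python
from functools import partial

import numpy as np

# Outputs of two hypothetical component mean functions on N = 3 inputs, O = 1 output.
mean_a = np.full((3, 1), 2.0)
mean_b = np.full((3, 1), 5.0)

# Stack into shape (num_means, N, O) so that axis 0 indexes the component means.
stacked = np.stack([mean_a, mean_b])

sum_operator = partial(np.sum, axis=0)    # reduction for a sum of means
prod_operator = partial(np.prod, axis=0)  # reduction for a product of means

summed = sum_operator(stacked)       # shape (3, 1), every entry 2 + 5
multiplied = prod_operator(stacked)  # shape (3, 1), every entry 2 * 5
print(summed.ravel(), multiplied.ravel())  # [7. 7. 7.] [10. 10. 10.]
```

Reducing over axis 0 collapses only the component axis, so the combined mean keeps the `(N, O)` shape of its parts.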
AbstractMeanFunction
dataclass
Bases: Module
Abstract base class for the mean functions used to parameterise a Gaussian process.
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__call__(x: Num[Array, 'N D']) -> Float[Array, 'N O']
abstractmethod
Evaluate the mean function at the given points. Every subclass must implement this method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Num[Array, 'N D'] | The N input points, each with D features, at which to evaluate the mean function. | required |
Returns
Float[Array, 'N O']: The mean function evaluated at each input point, with O outputs per point.
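The shape contract above can be illustrated with a small stand-alone sketch. `LinearMean` is a hypothetical mean function written as a plain dataclass with NumPy (whose API `jax.numpy` mirrors) so that it runs without GPJax; a real implementation would instead subclass `AbstractMeanFunction` and declare its parameters with `param_field`, as `Constant` does.

```python
from dataclasses import dataclass, field

import numpy as np


@dataclass
class LinearMean:
    """Hypothetical mean function m(x) = x @ weights + bias, shown outside GPJax."""

    # Shapes: weights (D, O) and bias (O,). These defaults give D = 2 and O = 1.
    weights: np.ndarray = field(default_factory=lambda: np.array([[1.0], [2.0]]))
    bias: np.ndarray = field(default_factory=lambda: np.array([0.5]))

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # x has shape (N, D); the result has shape (N, O), matching the documented contract.
        return x @ self.weights + self.bias


x = np.array([[0.0, 0.0], [1.0, 1.0], [1.0, -1.0]])  # N = 3 points in D = 2
out = LinearMean()(x)
print(out.shape)  # (3, 1)
```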
__add__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction
Add two mean functions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other | Union[AbstractMeanFunction, Float[Array, ' O']] | The other mean function, or an array, to add. | required |
Returns
AbstractMeanFunction: The sum of the two mean functions.
__radd__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction
Add two mean functions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other | Union[AbstractMeanFunction, Float[Array, ' O']] | The other mean function, or an array, to add. | required |
Returns
AbstractMeanFunction: The sum of the two mean functions.
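The following sketch illustrates how `+` can dispatch both between two mean functions and between a mean function and an array, on either side of the operator. All class names are hypothetical, NumPy stands in for `jax.numpy`, and promoting a bare array to a constant mean is an assumption suggested by the `Union[AbstractMeanFunction, Float[Array, ' O']]` signature rather than a copy of GPJax's code.

```python
import numpy as np


class SketchMean:
    """Hypothetical stand-in for AbstractMeanFunction that only models `+`."""

    # Makes NumPy's binary operators return NotImplemented for this class,
    # so `array + mean` falls through to SketchMean.__radd__.
    __array_ufunc__ = None

    def __call__(self, x: np.ndarray) -> np.ndarray:
        raise NotImplementedError

    def __add__(self, other):
        # Assumption: a bare array is promoted to a constant mean before summing.
        if not isinstance(other, SketchMean):
            other = SketchConstant(np.asarray(other))
        return SketchSum([self, other])

    def __radd__(self, other):
        # Addition is commutative, so `array + mean` reuses __add__.
        return self.__add__(other)


class SketchConstant(SketchMean):
    def __init__(self, constant: np.ndarray):
        self.constant = constant  # shape (O,)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # Repeat the constant for each of the N rows of x.
        return np.ones((x.shape[0], 1)) * self.constant


class SketchSum(SketchMean):
    def __init__(self, means):
        self.means = means

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # Stack the (N, O) component outputs and sum over the component axis.
        return np.sum(np.stack([m(x) for m in self.means]), axis=0)


x = np.zeros((4, 2))
left = SketchConstant(np.array([1.0])) + np.array([2.0])   # uses __add__
right = np.array([2.0]) + SketchConstant(np.array([1.0]))  # uses __radd__
print(left(x).ravel(), right(x).ravel())  # [3. 3. 3. 3.] [3. 3. 3. 3.]
```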
__mul__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction
Multiply two mean functions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other | Union[AbstractMeanFunction, Float[Array, ' O']] | The other mean function, or an array, to multiply by. | required |
Returns
AbstractMeanFunction: The product of the two mean functions.
__rmul__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction
Multiply two mean functions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other | Union[AbstractMeanFunction, Float[Array, ' O']] | The other mean function, or an array, to multiply by. | required |
Returns
AbstractMeanFunction: The product of the two mean functions.
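A companion sketch for `*`: here an array on the left scales a hypothetical mean that returns the first input column. As with addition, the class names are hypothetical, NumPy stands in for `jax.numpy`, and promoting the array to a constant mean is an assumption based on the `Union` type in the signature.

```python
import numpy as np


class SketchMean:
    """Hypothetical stand-in for AbstractMeanFunction that only models `*`."""

    # Makes NumPy's binary operators return NotImplemented for this class,
    # so `array * mean` falls through to SketchMean.__rmul__.
    __array_ufunc__ = None

    def __mul__(self, other):
        # Assumption: a bare array is promoted to a constant mean before multiplying.
        if not isinstance(other, SketchMean):
            other = SketchConstant(np.asarray(other))
        return SketchProduct([self, other])

    def __rmul__(self, other):
        # Multiplication is commutative, so `array * mean` reuses __mul__.
        return self.__mul__(other)


class SketchConstant(SketchMean):
    def __init__(self, constant: np.ndarray):
        self.constant = constant  # shape (O,)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        return np.ones((x.shape[0], 1)) * self.constant


class SketchIdentity(SketchMean):
    """Hypothetical mean m(x) = first input column, returned with shape (N, 1)."""

    def __call__(self, x: np.ndarray) -> np.ndarray:
        return x[:, :1]


class SketchProduct(SketchMean):
    def __init__(self, means):
        self.means = means

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # Stack the (N, O) component outputs and multiply over the component axis.
        return np.prod(np.stack([m(x) for m in self.means]), axis=0)


x = np.array([[1.0], [2.0], [3.0]])
scaled = np.array([10.0]) * SketchIdentity()  # uses __rmul__
print(scaled(x).ravel())  # [10. 20. 30.]
```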
Constant
dataclass
Bases: AbstractMeanFunction
Constant mean function.
A constant mean function. It returns the same value, the `constant` parameter, at every input. The constant can be treated as a model hyperparameter and learned during training; it defaults to 0.0.
constant: Float[Array, ' O'] = param_field(jnp.array([0.0]))
class-attribute
instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__add__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction
Add two mean functions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other | Union[AbstractMeanFunction, Float[Array, ' O']] | The other mean function, or an array, to add. | required |
Returns
AbstractMeanFunction: The sum of the two mean functions.
__radd__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction
Add two mean functions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other | Union[AbstractMeanFunction, Float[Array, ' O']] | The other mean function, or an array, to add. | required |
Returns
AbstractMeanFunction: The sum of the two mean functions.
__mul__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction
Multiply two mean functions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other | Union[AbstractMeanFunction, Float[Array, ' O']] | The other mean function, or an array, to multiply by. | required |
Returns
AbstractMeanFunction: The product of the two mean functions.
__rmul__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction
Multiply two mean functions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other | Union[AbstractMeanFunction, Float[Array, ' O']] | The other mean function, or an array, to multiply by. | required |
Returns
AbstractMeanFunction: The product of the two mean functions.
__call__(x: Num[Array, 'N D']) -> Float[Array, 'N O']
Evaluate the mean function at the given points.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Num[Array, 'N D'] | The N input points, each with D features, at which to evaluate the mean function. | required |
Returns
Float[Array, 'N O']: The constant repeated once for each input point.
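Evaluating a constant mean depends only on how many inputs there are, not on their features. The sketch below is a hypothetical NumPy re-implementation that assumes the constant is broadcast against a column of ones; it is meant to show the shapes involved, not to reproduce GPJax's exact code.

```python
import numpy as np


def constant_mean(x: np.ndarray, constant: np.ndarray) -> np.ndarray:
    """Hypothetical re-implementation: repeat `constant` (shape (O,)) for each row of `x` (shape (N, D))."""
    # (N, 1) * (O,) broadcasts to (N, O); only the number of rows in x matters.
    return np.ones((x.shape[0], 1)) * constant


x = np.arange(15.0).reshape(5, 3)             # N = 5 inputs with D = 3 features
default_out = constant_mean(x, np.array([0.0]))  # the default constant
learned_out = constant_mean(x, np.array([2.5]))  # e.g. a value reached after training
print(default_out.shape)       # (5, 1)
print(learned_out[:, 0])       # [2.5 2.5 2.5 2.5 2.5]
```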
Zero
dataclass
Bases: Constant
Zero mean function.
The zero mean function. It returns zero at every input. Unlike the Constant mean function, the zero value is fixed: it cannot be treated as a model hyperparameter or learned during training.
constant: Float[Array, ' O'] = static_field(jnp.array([0.0]), init=False)
class-attribute
instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__call__(x: Num[Array, 'N D']) -> Float[Array, 'N O']
Evaluate the mean function at the given points.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Num[Array, 'N D'] | The N input points, each with D features, at which to evaluate the mean function. | required |
Returns
Float[Array, 'N O']: Zeros, one row per input point.
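Because the zero constant is declared with `static_field(..., init=False)`, it is not a constructor argument. The sketch below reproduces that behaviour with the standard library's `dataclasses.field(init=False)`, on the assumption that `static_field` forwards `init` to a dataclass field in the same way; `ZeroMeanSketch` is a hypothetical name.

```python
from dataclasses import dataclass, field

import numpy as np


@dataclass
class ZeroMeanSketch:
    """Hypothetical analogue of Zero: the constant is fixed at 0 and excluded from __init__."""

    # init=False removes `constant` from the generated constructor,
    # echoing static_field(..., init=False) in the signature above.
    constant: np.ndarray = field(default_factory=lambda: np.array([0.0]), init=False)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # Same computation as a constant mean, but the constant is always zero.
        return np.ones((x.shape[0], 1)) * self.constant


zero = ZeroMeanSketch()
zero_out = zero(np.ones((3, 2)))
print(zero_out.ravel())  # [0. 0. 0.]

rejected = False
try:
    ZeroMeanSketch(constant=np.array([1.0]))  # not accepted by the generated __init__
except TypeError:
    rejected = True
print("constant rejected:", rejected)
```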
__add__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction
Add two mean functions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other | Union[AbstractMeanFunction, Float[Array, ' O']] | The other mean function, or an array, to add. | required |
Returns
AbstractMeanFunction: The sum of the two mean functions.
__radd__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction
Add two mean functions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other | Union[AbstractMeanFunction, Float[Array, ' O']] | The other mean function, or an array, to add. | required |
Returns
AbstractMeanFunction: The sum of the two mean functions.
__mul__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction
Multiply two mean functions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other | Union[AbstractMeanFunction, Float[Array, ' O']] | The other mean function, or an array, to multiply by. | required |
Returns
AbstractMeanFunction: The product of the two mean functions.
__rmul__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction
Multiply two mean functions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other | Union[AbstractMeanFunction, Float[Array, ' O']] | The other mean function, or an array, to multiply by. | required |
Returns
AbstractMeanFunction: The product of the two mean functions.
CombinationMeanFunction
dataclass
Bases: AbstractMeanFunction
A base class for products or sums of AbstractMeanFunctions.
means: List[AbstractMeanFunction] = items_list
instance-attribute
operator: Callable = operator
class-attribute
instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__add__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction
Add two mean functions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other | Union[AbstractMeanFunction, Float[Array, ' O']] | The other mean function, or an array, to add. | required |
Returns
AbstractMeanFunction: The sum of the two mean functions.
__radd__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction
Add two mean functions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other | Union[AbstractMeanFunction, Float[Array, ' O']] | The other mean function, or an array, to add. | required |
Returns
AbstractMeanFunction: The sum of the two mean functions.
__mul__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction
Multiply two mean functions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other | Union[AbstractMeanFunction, Float[Array, ' O']] | The other mean function, or an array, to multiply by. | required |
Returns
AbstractMeanFunction: The product of the two mean functions.
__rmul__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction
Multiply two mean functions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other | Union[AbstractMeanFunction, Float[Array, ' O']] | The other mean function, or an array, to multiply by. | required |
Returns
AbstractMeanFunction: The product of the two mean functions.
__init__(means: List[AbstractMeanFunction], operator: Callable, **kwargs) -> None
__call__(x: Num[Array, 'N D']) -> Float[Array, 'N O']
Evaluate the combination mean function at the given points.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Num[Array, 'N D'] | The N input points, each with D features, at which to evaluate the mean function. | required |
Returns
Float[Array, 'N O']: The combined mean function evaluated at each input point.
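The sketch below is a hypothetical, simplified analogue of the combination's `__call__`: it evaluates every component on the same inputs, stacks the `(N, O)` outputs along a new leading axis, and hands the stack to `operator`. It assumes that evaluation order, uses plain functions as component means, and uses NumPy in place of `jax.numpy`.

```python
from functools import partial
from typing import Callable, List

import numpy as np

MeanFn = Callable[[np.ndarray], np.ndarray]


class CombinationSketch:
    """Hypothetical analogue of CombinationMeanFunction.__call__."""

    def __init__(self, means: List[MeanFn], operator: Callable):
        self.means = means
        self.operator = operator

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # Each component maps (N, D) -> (N, O); stacking gives (num_means, N, O).
        stacked = np.stack([mean(x) for mean in self.means])
        # The operator reduces over axis 0, leaving (N, O).
        return self.operator(stacked)


def first_column(x: np.ndarray) -> np.ndarray:
    return x[:, :1]


def constant_three(x: np.ndarray) -> np.ndarray:
    return np.full((x.shape[0], 1), 3.0)


x = np.array([[1.0, 0.0], [2.0, 0.0]])
total = CombinationSketch([first_column, constant_three], partial(np.sum, axis=0))
product = CombinationSketch([first_column, constant_three], partial(np.prod, axis=0))
print(total(x).ravel(), product(x).ravel())  # [4. 5.] [3. 6.]
```

In this sketch `operator` is just a callable applied to the stack, so any reduction over axis 0 that leaves an `(N, O)` array fits the same pattern.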