
Mean Functions

gpjax.mean_functions

SumMeanFunction = partial(CombinationMeanFunction, operator=partial(jnp.sum, axis=0)) module-attribute
ProductMeanFunction = partial(CombinationMeanFunction, operator=partial(jnp.prod, axis=0)) module-attribute
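Both aliases use `functools.partial` twice: once to bind the reduction axis, and once to pre-fill the `operator` argument of `CombinationMeanFunction`. The reduction runs over axis 0, which (assuming the component outputs are stacked along a new leading axis) is the axis indexing the individual mean functions. A sum-based combination reduces with `jnp.sum` over that axis; a product-based one would reduce with `jnp.prod` over the same axis. A minimal sketch of the two reductions, using NumPy as a stand-in for `jax.numpy` (both expose `sum` and `prod` with an `axis` argument):

```python
from functools import partial

import numpy as np

# Outputs of M = 2 mean functions at N = 3 points with O = 1 output,
# stacked along a new leading axis: shape (M, N, O) = (2, 3, 1).
stacked = np.stack([
    np.full((3, 1), 2.0),  # first mean function returns 2 everywhere
    np.full((3, 1), 3.0),  # second mean function returns 3 everywhere
])

# Reduce across mean functions (axis 0), leaving one (N, O) array.
sum_op = partial(np.sum, axis=0)
prod_op = partial(np.prod, axis=0)

print(sum_op(stacked).shape)     # (3, 1)
print(sum_op(stacked).ravel())   # [5. 5. 5.]
print(prod_op(stacked).ravel())  # [6. 6. 6.]
```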
AbstractMeanFunction dataclass

Bases: Module

Mean function that is used to parameterise the Gaussian process.

__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self

Replace the values of the fields of the object.

Parameters:

**kwargs (Any): keyword arguments to replace the fields of the object. Default: {}.
Returns
Module: with the fields replaced.
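The `mutable: bool = False` default of `__init_subclass__` suggests that modules are immutable by default, in which case `replace` returns an updated copy rather than modifying the object in place. The standard library offers the same pattern through `dataclasses.replace` on a frozen dataclass. The sketch below uses a hypothetical `ToyConstant` class; that GPJax's `Module` works this way internally is an assumption:

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyConstant:
    """Hypothetical frozen stand-in for a Module with one field."""

    constant: np.ndarray = dataclasses.field(
        default_factory=lambda: np.array([0.0])
    )


original = ToyConstant()
# Build a new object with one field swapped; the original is untouched.
updated = dataclasses.replace(original, constant=np.array([2.5]))

print(original.constant)  # [0.]
print(updated.constant)   # [2.5]

# Direct assignment is rejected because the dataclass is frozen.
try:
    original.constant = np.array([1.0])
except dataclasses.FrozenInstanceError as err:
    print("frozen:", type(err).__name__)
```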
replace_meta(**kwargs: Any) -> Self

Replace the metadata of the fields.

Parameters:

**kwargs (Any): keyword arguments to replace the metadata of the fields of the object. Default: {}.
Returns
Module: with the metadata of the fields replaced.
update_meta(**kwargs: Any) -> Self

Update the metadata of the fields. The metadata must already exist.

Parameters:

**kwargs (Any): keyword arguments to update the existing metadata of the fields of the object. Default: {}.
Returns
Module: with the metadata of the fields updated.
replace_trainable(**kwargs: Dict[str, bool]) -> Self

Replace the trainability status of local nodes of the Module.

replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self

Replace the bijectors of local nodes of the Module.

constrain() -> Self

Transform model parameters to the constrained space according to their defined bijectors.

Returns
Module: transformed to the constrained space.
unconstrain() -> Self

Transform model parameters to the unconstrained space according to their defined bijectors.

Returns
Module: transformed to the unconstrained space.
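Each parameter's bijector maps an unconstrained real value onto the parameter's valid range, so an optimiser can work freely in unconstrained space. A sketch of the `unconstrain`/`constrain` round trip, assuming a softplus bijector for a positive parameter; the bijector choice, the parameter name and the dict-of-arrays representation are illustrative assumptions, not GPJax internals:

```python
import numpy as np


def softplus(u: np.ndarray) -> np.ndarray:
    """Forward bijector: unconstrained real line -> positive reals."""
    return np.logaddexp(0.0, u)  # log(1 + exp(u)), numerically stable


def softplus_inverse(c: np.ndarray) -> np.ndarray:
    """Inverse bijector: positive reals -> unconstrained real line."""
    return c + np.log(-np.expm1(-c))  # log(exp(c) - 1), stable for large c


params = {"lengthscale": np.array([0.5, 2.0])}  # hypothetical positive parameter

# "unconstrain": map each parameter through the inverse of its bijector.
unconstrained = {k: softplus_inverse(v) for k, v in params.items()}
# ... an optimiser would update the unconstrained values freely here ...
# "constrain": map back through the forward bijector.
constrained = {k: softplus(v) for k, v in unconstrained.items()}

print(np.allclose(constrained["lengthscale"], params["lengthscale"]))  # True
```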
stop_gradient() -> Self

Stop gradients flowing through the Module.

Returns
Module: with gradients stopped.
trainables() -> Self
__init__() -> None
__call__(x: Num[Array, 'N D']) -> Float[Array, 'N O'] abstractmethod

Evaluate the mean function at the given points. Subclasses must implement this method.

Parameters:

x (Num[Array, 'N D']): the N input points, one per row, at which to evaluate the mean function. Required.
Returns
Float[Array, 'N O']: the evaluated mean function, one row per input point.
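A concrete mean function only has to implement `__call__`, mapping an `(N, D)` batch of inputs to an `(N, O)` array. The sketch below illustrates that contract with plain `abc` classes; `ToyAbstractMean` and the linear mean are hypothetical examples. A real GPJax subclass would presumably inherit from `AbstractMeanFunction` instead and declare trainable fields with `param_field`, as `Constant` does:

```python
from abc import ABC, abstractmethod

import numpy as np


class ToyAbstractMean(ABC):
    """Hypothetical stand-in for AbstractMeanFunction."""

    @abstractmethod
    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Map inputs of shape (N, D) to mean values of shape (N, O)."""


class ToyLinearMean(ToyAbstractMean):
    """Hypothetical linear mean m(x) = x @ w + b with a single output."""

    def __init__(self, w: np.ndarray, b: float = 0.0):
        self.w = w.reshape(-1, 1)  # shape (D, 1)
        self.b = b

    def __call__(self, x: np.ndarray) -> np.ndarray:
        return x @ self.w + self.b  # shape (N, 1)


x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # N = 3, D = 2
mean = ToyLinearMean(w=np.array([0.5, -1.0]), b=1.0)
print(mean(x).shape)  # (3, 1)

# The abstract base cannot be instantiated directly.
try:
    ToyAbstractMean()
except TypeError:
    print("abstract")
```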
__add__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction

Add two mean functions.

Parameters:

other (Union[AbstractMeanFunction, Float[Array, ' O']]): the other mean function, or an array of constant values, to add. Required.
Returns
AbstractMeanFunction: The sum of the two mean functions.
__radd__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction

Add two mean functions.

Parameters:

other (Union[AbstractMeanFunction, Float[Array, ' O']]): the other mean function, or an array of constant values, to add. Required.
Returns
AbstractMeanFunction: The sum of the two mean functions.
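The union in the signatures means the other operand may be a mean function or an array of shape `(O,)`. One plausible way to handle both, assuming a raw operand is first wrapped in a constant mean, is sketched below with hypothetical toy classes (`ToyMean`, `ToyConstant`, `ToySum`). `__radd__` lets Python fall back to the mean function's addition when the left operand is a plain number, so `2.0 + m` behaves like `m + 2.0`:

```python
from typing import List, Union

import numpy as np


class ToyMean:
    """Hypothetical base class showing the operator-overloading pattern."""

    def __add__(self, other: Union["ToyMean", np.ndarray, float]) -> "ToySum":
        # Wrap raw arrays or scalars in a constant mean before combining.
        if not isinstance(other, ToyMean):
            other = ToyConstant(np.atleast_1d(other))
        return ToySum([self, other])

    def __radd__(self, other):
        # Called for `other + self` when `other` cannot add a ToyMean.
        return self.__add__(other)


class ToyConstant(ToyMean):
    def __init__(self, constant: np.ndarray):
        self.constant = constant  # shape (O,)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        return np.ones((x.shape[0], 1)) * self.constant  # shape (N, O)


class ToySum(ToyMean):
    def __init__(self, means: List[ToyMean]):
        self.means = means

    def __call__(self, x: np.ndarray) -> np.ndarray:
        return np.sum(np.stack([m(x) for m in self.means]), axis=0)


x = np.zeros((2, 1))
m = ToyConstant(np.array([1.0]))
print((m + 2.0)(x).ravel())  # [3. 3.]
print((2.0 + m)(x).ravel())  # [3. 3.], via __radd__
```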
__mul__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction

Multiply two mean functions.

Parameters:

other (Union[AbstractMeanFunction, Float[Array, ' O']]): the other mean function, or an array of constant values, to multiply by. Required.
Returns
AbstractMeanFunction: The product of the two mean functions.
__rmul__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction

Multiply two mean functions.

Parameters:

other (Union[AbstractMeanFunction, Float[Array, ' O']]): the other mean function, or an array of constant values, to multiply by. Required.
Returns
AbstractMeanFunction: The product of the two mean functions.
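Because both operands produce outputs of shape `(N, O)`, a product acts elementwise, so multiplying by an array of shape `(O,)` scales each output dimension by its own factor. A NumPy sketch, assuming the array operand behaves like a constant mean and that the product is taken by reducing the stacked outputs over axis 0 (both are assumptions for illustration):

```python
import numpy as np

# Hypothetical outputs of a mean function at N = 3 points with O = 2 outputs.
m1 = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # shape (N, O)
# A constant operand of shape (O,), broadcast to (N, O) like a constant mean.
m2 = np.ones((3, 1)) * np.array([10.0, 0.5])

# The product reduces the stacked outputs over axis 0, i.e. elementwise.
product = np.prod(np.stack([m1, m2]), axis=0)
print(product)  # first column scaled by 10, second column by 0.5
```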
Constant dataclass

Bases: AbstractMeanFunction

Constant mean function.

A constant mean function. This function returns the same constant value for all inputs. The constant can be treated as a model hyperparameter and learned during training; it defaults to 0.0.
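Evaluation can be sketched as broadcasting the stored constant of shape `(O,)` across the batch dimension, so the values of `x` are ignored and only its number of rows matters. This NumPy free function is an assumption about the implementation, written standalone for brevity (GPJax itself uses `jax.numpy`):

```python
import numpy as np


def constant_mean(x: np.ndarray, constant: np.ndarray) -> np.ndarray:
    """Repeat a constant of shape (O,) for each of the N rows of x."""
    return np.ones((x.shape[0], 1)) * constant  # shape (N, O)


x = np.random.default_rng(0).normal(size=(4, 3))  # N = 4 inputs, D = 3
out = constant_mean(x, np.array([0.0]))
print(out.shape)  # (4, 1)

# With O = 2 outputs, each column holds its own constant.
print(constant_mean(x, np.array([1.0, 2.0])).shape)  # (4, 2)
```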

constant: Float[Array, ' O'] = param_field(jnp.array([0.0])) class-attribute instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self

Replace the values of the fields of the object.

Parameters:

**kwargs (Any): keyword arguments to replace the fields of the object. Default: {}.
Returns
Module: with the fields replaced.
replace_meta(**kwargs: Any) -> Self

Replace the metadata of the fields.

Parameters:

**kwargs (Any): keyword arguments to replace the metadata of the fields of the object. Default: {}.
Returns
Module: with the metadata of the fields replaced.
update_meta(**kwargs: Any) -> Self

Update the metadata of the fields. The metadata must already exist.

Parameters:

**kwargs (Any): keyword arguments to update the existing metadata of the fields of the object. Default: {}.
Returns
Module: with the metadata of the fields updated.
replace_trainable(**kwargs: Dict[str, bool]) -> Self

Replace the trainability status of local nodes of the Module.

replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self

Replace the bijectors of local nodes of the Module.

constrain() -> Self

Transform model parameters to the constrained space according to their defined bijectors.

Returns
Module: transformed to the constrained space.
unconstrain() -> Self

Transform model parameters to the unconstrained space according to their defined bijectors.

Returns
Module: transformed to the unconstrained space.
stop_gradient() -> Self

Stop gradients flowing through the Module.

Returns
Module: with gradients stopped.
trainables() -> Self
__add__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction

Add two mean functions.

Parameters:

other (Union[AbstractMeanFunction, Float[Array, ' O']]): the other mean function, or an array of constant values, to add. Required.
Returns
AbstractMeanFunction: The sum of the two mean functions.
__radd__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction

Add two mean functions.

Parameters:

other (Union[AbstractMeanFunction, Float[Array, ' O']]): the other mean function, or an array of constant values, to add. Required.
Returns
AbstractMeanFunction: The sum of the two mean functions.
__mul__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction

Multiply two mean functions.

Parameters:

other (Union[AbstractMeanFunction, Float[Array, ' O']]): the other mean function, or an array of constant values, to multiply by. Required.
Returns
AbstractMeanFunction: The product of the two mean functions.
__rmul__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction

Multiply two mean functions.

Parameters:

other (Union[AbstractMeanFunction, Float[Array, ' O']]): the other mean function, or an array of constant values, to multiply by. Required.
Returns
AbstractMeanFunction: The product of the two mean functions.
__init__(constant: Float[Array, ' O'] = param_field(jnp.array([0.0]))) -> None
__call__(x: Num[Array, 'N D']) -> Float[Array, 'N O']

Evaluate the mean function at the given points.

Parameters:

x (Num[Array, 'N D']): the N input points, one per row, at which to evaluate the mean function. Required.
Returns
Float[Array, 'N O']: the evaluated mean function, one row per input point.
Zero dataclass

Bases: Constant

Zero mean function.

The zero mean function. This function returns zero for all inputs. Unlike the Constant mean function, the zero value is fixed: it is not a model hyperparameter and cannot be learned during training.
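In the signature below, `Zero` redeclares `constant` with `static_field(..., init=False)`, so the value is fixed rather than trainable and is not a constructor argument. The `init=False` half of that behaviour can be mimicked with the standard library's `dataclasses.field`; `ToyConstant` and `ToyZero` are hypothetical stand-ins, and the static (non-trainable) aspect is not modelled here:

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyConstant:
    """Hypothetical constant mean whose value can be set at construction."""

    constant: np.ndarray = dataclasses.field(
        default_factory=lambda: np.array([0.0])
    )

    def __call__(self, x: np.ndarray) -> np.ndarray:
        return np.ones((x.shape[0], 1)) * self.constant  # shape (N, O)


@dataclasses.dataclass(frozen=True)
class ToyZero(ToyConstant):
    """Hypothetical zero mean: the constant is excluded from __init__."""

    constant: np.ndarray = dataclasses.field(
        default_factory=lambda: np.array([0.0]), init=False
    )


x = np.ones((3, 2))
print(ToyConstant(np.array([2.0]))(x).ravel())  # [2. 2. 2.]
print(ToyZero()(x).ravel())                     # [0. 0. 0.]

try:
    ToyZero(np.array([2.0]))  # the constant cannot be passed in
except TypeError as err:
    print("TypeError:", err)
```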

constant: Float[Array, ' O'] = static_field(jnp.array([0.0]), init=False) class-attribute instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self

Replace the values of the fields of the object.

Parameters:

**kwargs (Any): keyword arguments to replace the fields of the object. Default: {}.
Returns
Module: with the fields replaced.
replace_meta(**kwargs: Any) -> Self

Replace the metadata of the fields.

Parameters:

**kwargs (Any): keyword arguments to replace the metadata of the fields of the object. Default: {}.
Returns
Module: with the metadata of the fields replaced.
update_meta(**kwargs: Any) -> Self

Update the metadata of the fields. The metadata must already exist.

Parameters:

**kwargs (Any): keyword arguments to update the existing metadata of the fields of the object. Default: {}.
Returns
Module: with the metadata of the fields updated.
replace_trainable(**kwargs: Dict[str, bool]) -> Self

Replace the trainability status of local nodes of the Module.

replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self

Replace the bijectors of local nodes of the Module.

constrain() -> Self

Transform model parameters to the constrained space according to their defined bijectors.

Returns
Module: transformed to the constrained space.
unconstrain() -> Self

Transform model parameters to the unconstrained space according to their defined bijectors.

Returns
Module: transformed to the unconstrained space.
stop_gradient() -> Self

Stop gradients flowing through the Module.

Returns
Module: with gradients stopped.
trainables() -> Self
__call__(x: Num[Array, 'N D']) -> Float[Array, 'N O']

Evaluate the mean function at the given points.

Parameters:

x (Num[Array, 'N D']): the N input points, one per row, at which to evaluate the mean function. Required.
Returns
Float[Array, 'N O']: the evaluated mean function, one row per input point.
__add__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction

Add two mean functions.

Parameters:

other (Union[AbstractMeanFunction, Float[Array, ' O']]): the other mean function, or an array of constant values, to add. Required.
Returns
AbstractMeanFunction: The sum of the two mean functions.
__radd__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction

Add two mean functions.

Parameters:

other (Union[AbstractMeanFunction, Float[Array, ' O']]): the other mean function, or an array of constant values, to add. Required.
Returns
AbstractMeanFunction: The sum of the two mean functions.
__mul__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction

Multiply two mean functions.

Parameters:

other (Union[AbstractMeanFunction, Float[Array, ' O']]): the other mean function, or an array of constant values, to multiply by. Required.
Returns
AbstractMeanFunction: The product of the two mean functions.
__rmul__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction

Multiply two mean functions.

Parameters:

other (Union[AbstractMeanFunction, Float[Array, ' O']]): the other mean function, or an array of constant values, to multiply by. Required.
Returns
AbstractMeanFunction: The product of the two mean functions.
__init__(constant: Float[Array, ' O'] = static_field(jnp.array([0.0]), init=False)) -> None
CombinationMeanFunction dataclass

Bases: AbstractMeanFunction

A base class for products or sums of AbstractMeanFunctions.

means: List[AbstractMeanFunction] = items_list instance-attribute
operator: Callable = operator class-attribute instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self

Replace the values of the fields of the object.

Parameters:

**kwargs (Any): keyword arguments to replace the fields of the object. Default: {}.
Returns
Module: with the fields replaced.
replace_meta(**kwargs: Any) -> Self

Replace the metadata of the fields.

Parameters:

**kwargs (Any): keyword arguments to replace the metadata of the fields of the object. Default: {}.
Returns
Module: with the metadata of the fields replaced.
update_meta(**kwargs: Any) -> Self

Update the metadata of the fields. The metadata must already exist.

Parameters:

**kwargs (Any): keyword arguments to update the existing metadata of the fields of the object. Default: {}.
Returns
Module: with the metadata of the fields updated.
replace_trainable(**kwargs: Dict[str, bool]) -> Self

Replace the trainability status of local nodes of the Module.

replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self

Replace the bijectors of local nodes of the Module.

constrain() -> Self

Transform model parameters to the constrained space according to their defined bijectors.

Returns
Module: transformed to the constrained space.
unconstrain() -> Self

Transform model parameters to the unconstrained space according to their defined bijectors.

Returns
Module: transformed to the unconstrained space.
stop_gradient() -> Self

Stop gradients flowing through the Module.

Returns
Module: with gradients stopped.
trainables() -> Self
__add__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction

Add two mean functions.

Parameters:

other (Union[AbstractMeanFunction, Float[Array, ' O']]): the other mean function, or an array of constant values, to add. Required.
Returns
AbstractMeanFunction: The sum of the two mean functions.
__radd__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction

Add two mean functions.

Parameters:

other (Union[AbstractMeanFunction, Float[Array, ' O']]): the other mean function, or an array of constant values, to add. Required.
Returns
AbstractMeanFunction: The sum of the two mean functions.
__mul__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction

Multiply two mean functions.

Parameters:

other (Union[AbstractMeanFunction, Float[Array, ' O']]): the other mean function, or an array of constant values, to multiply by. Required.
Returns
AbstractMeanFunction: The product of the two mean functions.
__rmul__(other: Union[AbstractMeanFunction, Float[Array, ' O']]) -> AbstractMeanFunction

Multiply two mean functions.

Parameters:

other (Union[AbstractMeanFunction, Float[Array, ' O']]): the other mean function, or an array of constant values, to multiply by. Required.
Returns
AbstractMeanFunction: The product of the two mean functions.
__init__(means: List[AbstractMeanFunction], operator: Callable, **kwargs) -> None
__call__(x: Num[Array, 'N D']) -> Float[Array, 'N O']

Evaluate the combination mean function at the given points.

Parameters:

x (Num[Array, 'N D']): the N input points, one per row, at which to evaluate the mean function. Required.
Returns
Float[Array, 'N O']: the evaluated mean function, one row per input point.
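Evaluation can be pictured as calling every component on the same inputs, stacking the `(N, O)` results along a new leading axis, and handing the stack to `operator`. A minimal sketch with NumPy and a hypothetical `ToyCombination` class; the stacking step is an assumption consistent with the `axis=0` reductions in the module-level aliases. Since a combination is itself callable on `(N, D)` inputs, combinations can be nested:

```python
from functools import partial
from typing import Callable, List

import numpy as np


class ToyCombination:
    """Hypothetical stand-in for CombinationMeanFunction."""

    def __init__(self, means: List[Callable], operator: Callable):
        self.means = means
        self.operator = operator

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # Each component maps (N, D) -> (N, O); stacking gives (M, N, O).
        stacked = np.stack([m(x) for m in self.means])
        # The operator reduces over axis 0, returning (N, O).
        return self.operator(stacked)


def row_sum(x: np.ndarray) -> np.ndarray:
    return x.sum(axis=1, keepdims=True)  # hypothetical mean, shape (N, 1)


def offset(x: np.ndarray) -> np.ndarray:
    return np.full((x.shape[0], 1), 0.5)  # hypothetical constant mean


x = np.array([[1.0, 2.0], [3.0, 4.0]])  # N = 2, D = 2
total = ToyCombination([row_sum, offset], operator=partial(np.sum, axis=0))
print(total(x).ravel())  # [3.5 7.5]

# Nest a combination inside a product-style combination.
nested = ToyCombination([total, offset], operator=partial(np.prod, axis=0))
print(nested(x).ravel())  # [1.75 3.75]
```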