
Objectives

gpjax.objectives

tfd = tfp.distributions module-attribute
ConjugatePosterior = TypeVar('ConjugatePosterior', bound='gpjax.gps.ConjugatePosterior') module-attribute
NonConjugatePosterior = TypeVar('NonConjugatePosterior', bound='gpjax.gps.NonConjugatePosterior') module-attribute
VariationalFamily = TypeVar('VariationalFamily', bound='gpjax.variational_families.AbstractVariationalFamily') module-attribute
NonConjugateMLL = LogPosteriorDensity module-attribute
AbstractObjective dataclass

Bases: Module

Abstract base class for objectives.

negative: bool = static_field(False) class-attribute instance-attribute
constant: ScalarFloat = static_field(init=False, repr=False) class-attribute instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self

Replace the values of the fields of the object.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the fields of the object.

{}
Returns
Module: with the fields replaced.
replace_meta(**kwargs: Any) -> Self

Replace the metadata of the fields.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields replaced.
update_meta(**kwargs: Any) -> Self

Update the metadata of the fields. The metadata must already exist.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the fields of the object.

{}
Returns
Module: with the fields replaced.
replace_trainable(**kwargs: Dict[str, bool]) -> Self

Replace the trainability status of local nodes of the Module.

replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self

Replace the bijectors of local nodes of the Module.

constrain() -> Self

Transform model parameters to the constrained space according to their defined bijectors.

Returns
Module: transformed to the constrained space.
unconstrain() -> Self

Transform model parameters to the unconstrained space according to their defined bijectors.

Returns
Module: transformed to the unconstrained space.
stop_gradient() -> Self

Stop gradients flowing through the Module.

Returns
Module: with gradients stopped.
trainables() -> Self
__init__(negative: bool = static_field(False), constant: ScalarFloat = static_field(init=False, repr=False)) -> None
__post_init__() -> None
__hash__()
__call__(*args, **kwargs) -> ScalarFloat
step(*args, **kwargs) -> ScalarFloat abstractmethod
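Concrete objectives subclass AbstractObjective and implement step. As a hedged sketch (the class name, penalty term, and field-free body are illustrative and not part of GPJax), a custom objective might look as follows:

    from dataclasses import dataclass

    import gpjax as gpx
    from gpjax.objectives import AbstractObjective
    from gpjax.typing import ScalarFloat


    @dataclass
    class PenalisedMLL(AbstractObjective):
        """Hypothetical objective: the conjugate MLL minus an L2 penalty on the kernel lengthscale."""

        def step(self, posterior, train_data) -> ScalarFloat:
            # Reuse the exact conjugate marginal log-likelihood...
            mll = gpx.objectives.ConjugateMLL().step(posterior, train_data)
            # ...and subtract a small, illustrative penalty on the lengthscale.
            penalty = (posterior.prior.kernel.lengthscale ** 2).sum()
            return mll - 0.01 * penalty

Assuming the negative flag behaves as for the built-in objectives, PenalisedMLL(negative=True) can then be passed directly to a gradient-based minimiser.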
ConjugateMLL dataclass

Bases: AbstractObjective

negative: bool = static_field(False) class-attribute instance-attribute
constant: ScalarFloat = static_field(init=False, repr=False) class-attribute instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self

Replace the values of the fields of the object.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the fields of the object.

{}
Returns
Module: with the fields replaced.
replace_meta(**kwargs: Any) -> Self

Replace the metadata of the fields.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields replaced.
update_meta(**kwargs: Any) -> Self

Update the metadata of the fields. The metadata must already exist.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the fields of the object.

{}
Returns
Module: with the fields replaced.
replace_trainable(**kwargs: Dict[str, bool]) -> Self

Replace the trainability status of local nodes of the Module.

replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self

Replace the bijectors of local nodes of the Module.

constrain() -> Self

Transform model parameters to the constrained space according to their defined bijectors.

Returns
Module: transformed to the constrained space.
unconstrain() -> Self

Transform model parameters to the unconstrained space according to their defined bijectors.

Returns
Module: transformed to the unconstrained space.
stop_gradient() -> Self

Stop gradients flowing through the Module.

Returns
Module: with gradients stopped.
trainables() -> Self
__post_init__() -> None
__hash__()
__call__(*args, **kwargs) -> ScalarFloat
__init__(negative: bool = static_field(False), constant: ScalarFloat = static_field(init=False, repr=False)) -> None
step(posterior: ConjugatePosterior, train_data: Dataset) -> ScalarFloat

Evaluate the marginal log-likelihood of the Gaussian process.

Compute the marginal log-likelihood function of the Gaussian process. The returned function can then be used for gradient based optimisation of the model's parameters or for model comparison. The implementation given here enables exact estimation of the Gaussian process' latent function values.

For a training dataset $\{x_n, y_n\}_{n=1}^N$, a set of test inputs $\mathbf{x}^{\star}$, and corresponding latent function evaluations $\mathbf{f} = f(\mathbf{x})$ and $\mathbf{f}^{\star} = f(\mathbf{x}^{\star})$, the marginal log-likelihood is given by:

$$
\log p(\mathbf{y}) = \log \int p(\mathbf{y}\mid\mathbf{f})\,p(\mathbf{f})\,\mathrm{d}\mathbf{f} = 0.5\left(-\mathbf{y}^{\top}\left(k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_N\right)^{-1}\mathbf{y} - \log\lvert k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_N\rvert - N\log 2\pi\right).
$$

For a given ConjugatePosterior object, the following code snippet shows how the marginal log-likelihood can be evaluated.

Example:

    >>> import jax.numpy as jnp
    >>> import gpjax as gpx
    >>>
    >>> xtrain = jnp.linspace(0, 1).reshape(-1, 1)
    >>> ytrain = jnp.sin(xtrain)
    >>> D = gpx.Dataset(X=xtrain, y=ytrain)
    >>>
    >>> meanf = gpx.mean_functions.Constant()
    >>> kernel = gpx.kernels.RBF()
    >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
    >>> prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
    >>> posterior = prior * likelihood
    >>>
    >>> mll = gpx.objectives.ConjugateMLL(negative=True)
    >>> mll(posterior, train_data=D)

Our goal is to maximise the marginal log-likelihood. Therefore, when optimising the model's parameters with a gradient-based minimiser, we use the negative marginal log-likelihood. This can be realised through

    mll = gpx.objectives.ConjugateMLL(negative=True)

For optimal performance, the marginal log-likelihood should be jax.jit compiled.

    from jax import jit
    mll = jit(gpx.objectives.ConjugateMLL(negative=True))
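
Continuing from the example above, a hedged sketch of obtaining gradients of the jitted objective is given below; this treats the posterior as a JAX pytree and differentiates in the constrained parameter space, whereas in practice GPJax's training utilities (e.g. gpx.fit) handle this for you.

    import jax

    negative_mll = jax.jit(gpx.objectives.ConjugateMLL(negative=True))

    # Value of the negative marginal log-likelihood and its gradients with
    # respect to the posterior's parameters (argument 0 of the objective).
    value, grads = jax.value_and_grad(negative_mll)(posterior, train_data=D)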

Parameters:

Name Type Description Default
posterior ConjugatePosterior

The posterior distribution for which we want to compute the marginal log-likelihood.

required
train_data Dataset

The training dataset used to compute the marginal log-likelihood.

required
Returns
ScalarFloat: The marginal log-likelihood of the Gaussian process for the
    current parameter set.
ConjugateLOOCV dataclass

Bases: AbstractObjective

negative: bool = static_field(False) class-attribute instance-attribute
constant: ScalarFloat = static_field(init=False, repr=False) class-attribute instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self

Replace the values of the fields of the object.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the fields of the object.

{}
Returns
Module: with the fields replaced.
replace_meta(**kwargs: Any) -> Self

Replace the metadata of the fields.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields replaced.
update_meta(**kwargs: Any) -> Self

Update the metadata of the fields. The metadata must already exist.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the fields of the object.

{}
Returns
Module: with the fields replaced.
replace_trainable(**kwargs: Dict[str, bool]) -> Self

Replace the trainability status of local nodes of the Module.

replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self

Replace the bijectors of local nodes of the Module.

constrain() -> Self

Transform model parameters to the constrained space according to their defined bijectors.

Returns
Module: transformed to the constrained space.
unconstrain() -> Self

Transform model parameters to the unconstrained space according to their defined bijectors.

Returns
Module: transformed to the unconstrained space.
stop_gradient() -> Self

Stop gradients flowing through the Module.

Returns
Module: with gradients stopped.
trainables() -> Self
__post_init__() -> None
__hash__()
__call__(*args, **kwargs) -> ScalarFloat
__init__(negative: bool = static_field(False), constant: ScalarFloat = static_field(init=False, repr=False)) -> None
step(posterior: ConjugatePosterior, train_data: Dataset) -> ScalarFloat

Evaluate the leave-one-out log predictive probability of the Gaussian process, following section 5.4.2 of Rasmussen and Williams (2006), Gaussian Processes for Machine Learning. This metric measures the average predictive performance of the models obtained by training on all but one data point and then predicting the held-out point.

The returned metric can then be used for gradient based optimisation of the model's parameters or for model comparison. The implementation given here enables exact estimation of the Gaussian process' latent function values.

For a given ConjugatePosterior object, the following code snippet shows how the leave-one-out log predictive probability can be evaluated.

Example:

    >>> import jax.numpy as jnp
    >>> import gpjax as gpx
    >>>
    >>> xtrain = jnp.linspace(0, 1).reshape(-1, 1)
    >>> ytrain = jnp.sin(xtrain)
    >>> D = gpx.Dataset(X=xtrain, y=ytrain)
    >>>
    >>> meanf = gpx.mean_functions.Constant()
    >>> kernel = gpx.kernels.RBF()
    >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
    >>> prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
    >>> posterior = prior * likelihood
    >>>
    >>> loocv = gpx.objectives.ConjugateLOOCV(negative=True)
    >>> loocv(posterior, train_data=D)

Our goal is to maximise the leave-one-out log predictive probability. Therefore, when optimising the model's parameters with a gradient-based minimiser, we use its negative. This can be realised through

    loocv = gpx.objectives.ConjugateLOOCV(negative=True)

For optimal performance, the objective should be jax.jit compiled.

    from jax import jit
    loocv = jit(gpx.objectives.ConjugateLOOCV(negative=True))
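
Continuing from the example above, a hedged sketch of the model-comparison use case is given below; the Matern52 alternative is an illustrative choice.

    import jax

    loocv = jax.jit(gpx.objectives.ConjugateLOOCV())

    # Compare two kernels under the same mean function and likelihood; a higher
    # leave-one-out log predictive probability indicates a better out-of-sample fit.
    rbf_posterior = gpx.gps.Prior(mean_function=meanf, kernel=gpx.kernels.RBF()) * likelihood
    matern_posterior = gpx.gps.Prior(mean_function=meanf, kernel=gpx.kernels.Matern52()) * likelihood

    print(loocv(rbf_posterior, train_data=D), loocv(matern_posterior, train_data=D))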

Parameters:

Name Type Description Default
posterior ConjugatePosterior

The posterior distribution for which we want to compute the leave-one-out log predictive probability.

required
train_data Dataset

The training dataset used to compute the leave-one-out log predictive probability.

required
Returns
ScalarFloat: The leave-one-out log predictive probability of the Gaussian
    process for the current parameter set.
LogPosteriorDensity dataclass

Bases: AbstractObjective

The log-posterior density of a non-conjugate Gaussian process. This is sometimes referred to as the marginal log-likelihood.

negative: bool = static_field(False) class-attribute instance-attribute
constant: ScalarFloat = static_field(init=False, repr=False) class-attribute instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self

Replace the values of the fields of the object.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the fields of the object.

{}
Returns
Module: with the fields replaced.
replace_meta(**kwargs: Any) -> Self

Replace the metadata of the fields.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields replaced.
update_meta(**kwargs: Any) -> Self

Update the metadata of the fields. The metadata must already exist.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the fields of the object.

{}
Returns
Module: with the fields replaced.
replace_trainable(**kwargs: Dict[str, bool]) -> Self

Replace the trainability status of local nodes of the Module.

replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self

Replace the bijectors of local nodes of the Module.

constrain() -> Self

Transform model parameters to the constrained space according to their defined bijectors.

Returns
Module: transformed to the constrained space.
unconstrain() -> Self

Transform model parameters to the unconstrained space according to their defined bijectors.

Returns
Module: transformed to the unconstrained space.
stop_gradient() -> Self

Stop gradients flowing through the Module.

Returns
Module: with gradients stopped.
trainables() -> Self
__post_init__() -> None
__hash__()
__call__(*args, **kwargs) -> ScalarFloat
__init__(negative: bool = static_field(False), constant: ScalarFloat = static_field(init=False, repr=False)) -> None
step(posterior: NonConjugatePosterior, data: Dataset) -> ScalarFloat

Evaluate the log-posterior density of a Gaussian process.

Compute the marginal log-likelihood, or log-posterior density, of the Gaussian process. The returned function can then be used for gradient-based optimisation of the model's parameters or for model comparison. The implementation given here is general and will work for any likelihood supported by GPJax.

Unlike the marginal log-likelihood of a ConjugatePosterior, the log-posterior density of a NonConjugatePosterior is not an exact marginal log-likelihood. Instead, the NonConjugatePosterior represents the posterior distribution as a function of the model's hyperparameters and the latent function values. Markov chain Monte Carlo, variational inference, or a Laplace approximation can then be used to sample from, or optimise an approximation to, the posterior distribution.
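
For a given NonConjugatePosterior object, a hedged sketch of evaluating the log-posterior density for a Bernoulli (classification) model is given below; the synthetic data and labels are illustrative.

    import jax.numpy as jnp
    import gpjax as gpx

    # Illustrative binary classification data.
    X = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
    y = jnp.where(X > 0.0, 1.0, 0.0)
    D = gpx.Dataset(X=X, y=y)

    prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Constant(), kernel=gpx.kernels.RBF())
    likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n)
    posterior = prior * likelihood  # a NonConjugatePosterior

    lpd = gpx.objectives.LogPosteriorDensity(negative=True)
    lpd(posterior, D)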

Parameters:

Name Type Description Default
posterior NonConjugatePosterior

The posterior distribution for which we want to compute the log-posterior density.

required
data Dataset

The training dataset used to compute the log-posterior density.

required
Returns
ScalarFloat: The log-posterior density of the Gaussian process for the
    current parameter set.
ELBO dataclass

Bases: AbstractObjective

negative: bool = static_field(False) class-attribute instance-attribute
constant: ScalarFloat = static_field(init=False, repr=False) class-attribute instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self

Replace the values of the fields of the object.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the fields of the object.

{}
Returns
Module: with the fields replaced.
replace_meta(**kwargs: Any) -> Self

Replace the metadata of the fields.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields replaced.
update_meta(**kwargs: Any) -> Self

Update the metadata of the fields. The metadata must already exist.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the fields of the object.

{}
Returns
Module: with the fields replaced.
replace_trainable(**kwargs: Dict[str, bool]) -> Self

Replace the trainability status of local nodes of the Module.

replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self

Replace the bijectors of local nodes of the Module.

constrain() -> Self

Transform model parameters to the constrained space according to their defined bijectors.

Returns
Module: transformed to the constrained space.
unconstrain() -> Self

Transform model parameters to the unconstrained space according to their defined bijectors.

Returns
Module: transformed to the unconstrained space.
stop_gradient() -> Self

Stop gradients flowing through the Module.

Returns
Module: with gradients stopped.
trainables() -> Self
__post_init__() -> None
__hash__()
__call__(*args, **kwargs) -> ScalarFloat
__init__(negative: bool = static_field(False), constant: ScalarFloat = static_field(init=False, repr=False)) -> None
step(variational_family: VariationalFamily, train_data: Dataset) -> ScalarFloat

Compute the evidence lower bound of a variational approximation.

Compute the evidence lower bound under this model. In short, this requires evaluating the expectation of the model's log-likelihood under the variational approximation. To this, we add the KL divergence from the variational posterior to the prior. When mini-batching is used, the expectation term is rescaled by the ratio of the full dataset size to the batch size so that the bound remains an unbiased estimate.
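
A hedged sketch of evaluating the ELBO for a sparse variational approximation is given below; the data, inducing-point locations, and Gaussian likelihood are illustrative choices.

    import jax.numpy as jnp
    import gpjax as gpx

    X = jnp.linspace(0.0, 1.0, 50).reshape(-1, 1)
    y = jnp.sin(X)
    D = gpx.Dataset(X=X, y=y)

    prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Constant(), kernel=gpx.kernels.RBF())
    posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

    # Ten illustrative inducing inputs on [0, 1].
    z = jnp.linspace(0.0, 1.0, 10).reshape(-1, 1)
    q = gpx.variational_families.VariationalGaussian(posterior=posterior, inducing_inputs=z)

    elbo = gpx.objectives.ELBO(negative=True)
    elbo(q, train_data=D)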

Parameters:

Name Type Description Default
variational_family AbstractVariationalFamily

The variational approximation whose parameters the ELBO is maximised with respect to.

required
train_data Dataset

The training data on which the ELBO is evaluated.

required
Returns
ScalarFloat: The evidence lower bound of the variational approximation for
    the current model parameter set.
CollapsedELBO dataclass

Bases: AbstractObjective

The collapsed evidence lower bound.

Collapsed variational inference for a sparse Gaussian process regression model. The key reference is Titsias (2009), Variational Learning of Inducing Variables in Sparse Gaussian Processes.

negative: bool = static_field(False) class-attribute instance-attribute
constant: ScalarFloat = static_field(init=False, repr=False) class-attribute instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self

Replace the values of the fields of the object.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the fields of the object.

{}
Returns
Module: with the fields replaced.
replace_meta(**kwargs: Any) -> Self

Replace the metadata of the fields.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields replaced.
update_meta(**kwargs: Any) -> Self

Update the metadata of the fields. The metadata must already exist.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the fields of the object.

{}
Returns
Module: with the fields replaced.
replace_trainable(**kwargs: Dict[str, bool]) -> Self

Replace the trainability status of local nodes of the Module.

replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self

Replace the bijectors of local nodes of the Module.

constrain() -> Self

Transform model parameters to the constrained space according to their defined bijectors.

Returns
Module: transformed to the constrained space.
unconstrain() -> Self

Transform model parameters to the unconstrained space according to their defined bijectors.

Returns
Module: transformed to the unconstrained space.
stop_gradient() -> Self

Stop gradients flowing through the Module.

Returns
Module: with gradients stopped.
trainables() -> Self
__post_init__() -> None
__hash__()
__call__(*args, **kwargs) -> ScalarFloat
__init__(negative: bool = static_field(False), constant: ScalarFloat = static_field(init=False, repr=False)) -> None
step(variational_family: VariationalFamily, train_data: Dataset) -> ScalarFloat

Compute a single step of the collapsed evidence lower bound.

Compute the evidence lower bound under this model. In short, this requires evaluating the expectation of the model's log-likelihood under the variational approximation. To this, we add the KL divergence from the variational posterior to the prior. When mini-batching is used, the expectation term is rescaled by the ratio of the full dataset size to the batch size so that the bound remains an unbiased estimate.
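
A hedged sketch of evaluating the collapsed bound for sparse GP regression is given below; the data and inducing inputs are illustrative, and the posterior's likelihood must be Gaussian for this bound to apply.

    import jax.numpy as jnp
    import gpjax as gpx

    X = jnp.linspace(0.0, 1.0, 50).reshape(-1, 1)
    y = jnp.sin(X)
    D = gpx.Dataset(X=X, y=y)

    prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Constant(), kernel=gpx.kernels.RBF())
    posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

    z = jnp.linspace(0.0, 1.0, 10).reshape(-1, 1)
    q = gpx.variational_families.CollapsedVariationalGaussian(posterior=posterior, inducing_inputs=z)

    collapsed_elbo = gpx.objectives.CollapsedELBO(negative=True)
    collapsed_elbo(q, train_data=D)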

Parameters:

Name Type Description Default
variational_family AbstractVariationalFamily

The variational approximation whose parameters the ELBO is maximised with respect to.

required
train_data Dataset

The training data on which the ELBO is evaluated.

required
Returns
ScalarFloat: The evidence lower bound of the variational approximation for
    the current model parameter set.
variational_expectation(variational_family: VariationalFamily, train_data: Dataset) -> Float[Array, ' N']

Compute the variational expectation.

Compute the expectation of our model's log-likelihood under our variational distribution. Batching can be done here to speed up computation.

Parameters:

Name Type Description Default
variational_family AbstractVariationalFamily

The variational family that we are using to approximate the posterior.

required
train_data Dataset

The batch for which the expectation should be computed.

required
Returns
Array: The expectation of the model's log-likelihood under our variational
    distribution.