Objectives
gpjax.objectives
Module attributes:

- tfd = tfp.distributions
- ConjugatePosterior = TypeVar('ConjugatePosterior', bound='gpjax.gps.ConjugatePosterior')
- NonConjugatePosterior = TypeVar('NonConjugatePosterior', bound='gpjax.gps.NonConjugatePosterior')
- VariationalFamily = TypeVar('VariationalFamily', bound='gpjax.variational_families.AbstractVariationalFamily')
- NonConjugateMLL = LogPosteriorDensity
AbstractObjective (dataclass)
Bases: Module
Abstract base class for objectives.
negative: bool = static_field(False) (class attribute, instance attribute)
constant: ScalarFloat = static_field(init=False, repr=False) (class attribute, instance attribute)
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__post_init__() -> None
__hash__()
__call__(*args, **kwargs) -> ScalarFloat
step(*args, **kwargs) -> ScalarFloat (abstract method)
ConjugateMLL
Bases: AbstractObjective
negative: bool = static_field(False) (class attribute, instance attribute)
constant: ScalarFloat = static_field(init=False, repr=False) (class attribute, instance attribute)
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__post_init__() -> None
__hash__()
__call__(*args, **kwargs) -> ScalarFloat
step(posterior: ConjugatePosterior, train_data: Dataset) -> ScalarFloat
Evaluate the marginal log-likelihood of the Gaussian process.
Compute the marginal log-likelihood function of the Gaussian process. The returned function can then be used for gradient-based optimisation of the model's parameters or for model comparison. The implementation given here enables exact estimation of the Gaussian process' latent function values.
For a training dataset $\{\mathbf{x}_n, y_n\}_{n=1}^{N}$ and a set of test inputs $\mathbf{x}^{\star}$, the corresponding latent function evaluations are given by $\mathbf{f} = f(\mathbf{x})$ and $\mathbf{f}^{\star} = f(\mathbf{x}^{\star})$, and the marginal log-likelihood is given by:

$$
\log p(\mathbf{y}) = \int p(\mathbf{y}\mid\mathbf{f})\,p(\mathbf{f})\,\mathrm{d}\mathbf{f}
= -\frac{1}{2}\mathbf{y}^{\top}\left(k(\mathbf{x}, \mathbf{x}') + \sigma^2 \mathbf{I}_N\right)^{-1}\mathbf{y}
- \frac{1}{2}\log\left\lvert k(\mathbf{x}, \mathbf{x}') + \sigma^2 \mathbf{I}_N\right\rvert
- \frac{N}{2}\log 2\pi.
$$
For a given ConjugatePosterior object, the following code snippet shows how the marginal log-likelihood can be evaluated.
Example:
>>> import gpjax as gpx
>>> import jax.numpy as jnp
>>>
>>> xtrain = jnp.linspace(0, 1).reshape(-1, 1)
>>> ytrain = jnp.sin(xtrain)
>>> D = gpx.Dataset(X=xtrain, y=ytrain)
>>>
>>> meanf = gpx.mean_functions.Constant()
>>> kernel = gpx.kernels.RBF()
>>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
>>> prior = gpx.gps.Prior(mean_function = meanf, kernel=kernel)
>>> posterior = prior * likelihood
>>>
>>> mll = gpx.objectives.ConjugateMLL(negative=True)
>>> mll(posterior, train_data = D)
Our goal is to maximise the marginal log-likelihood. Therefore, when optimising the model's parameters, we minimise the negative marginal log-likelihood. This can be realised through the objective's negative flag, as shown below.
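For instance, a minimal sketch reusing the posterior and D objects constructed in the example above:

>>> negative_mll = gpx.objectives.ConjugateMLL(negative=True)
>>> negative_mll(posterior, train_data=D)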
For optimal performance, the marginal log-likelihood should be jax.jit compiled.
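A minimal sketch of doing so, reusing the mll, posterior, and D objects from the example above:

>>> import jax
>>>
>>> compiled_mll = jax.jit(mll)  # compile once, reuse for repeated evaluations during optimisation
>>> compiled_mll(posterior, train_data=D)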
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `posterior` | `ConjugatePosterior` | The posterior distribution for which we want to compute the marginal log-likelihood. | required |
| `train_data` | `Dataset` | The training dataset used to compute the marginal log-likelihood. | required |
Returns
ScalarFloat: The marginal log-likelihood of the Gaussian process for the
current parameter set.
ConjugateLOOCV
Bases: AbstractObjective
negative: bool = static_field(False) (class attribute, instance attribute)
constant: ScalarFloat = static_field(init=False, repr=False) (class attribute, instance attribute)
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__post_init__() -> None
__hash__()
__call__(*args, **kwargs) -> ScalarFloat
step(posterior: ConjugatePosterior, train_data: Dataset) -> ScalarFloat
Evaluate the leave-one-out log predictive probability of the Gaussian process following Section 5.4.2 of Rasmussen & Williams (2006), Gaussian Processes for Machine Learning. This metric calculates the average performance of all models that can be obtained by training on all but one data point and then predicting the left-out data point.
The returned metric can then be used for gradient-based optimisation of the model's parameters or for model comparison. The implementation given here enables exact estimation of the Gaussian process' latent function values.
For a given ConjugatePosterior object, the following code snippet shows how the leave-one-out log predictive probability can be evaluated.
Example:
>>> import gpjax as gpx
>>> import jax.numpy as jnp
>>>
>>> xtrain = jnp.linspace(0, 1).reshape(-1, 1)
>>> ytrain = jnp.sin(xtrain)
>>> D = gpx.Dataset(X=xtrain, y=ytrain)
>>>
>>> meanf = gpx.mean_functions.Constant()
>>> kernel = gpx.kernels.RBF()
>>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
>>> prior = gpx.gps.Prior(mean_function = meanf, kernel=kernel)
>>> posterior = prior * likelihood
>>>
>>> loocv = gpx.objectives.ConjugateLOOCV(negative=True)
>>> loocv(posterior, train_data = D)
Our goal is to maximise the leave-one-out log predictive probability. Therefore, when optimising the model's parameters, we minimise its negative. As with the marginal log-likelihood, this can be realised through the objective's negative flag, as in the example above.
For optimal performance, the objective should be jax.jit compiled.
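A minimal sketch of doing so, reusing the loocv, posterior, and D objects from the example above:

>>> import jax
>>>
>>> compiled_loocv = jax.jit(loocv)  # compile the objective before repeated evaluation
>>> compiled_loocv(posterior, train_data=D)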
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `posterior` | `ConjugatePosterior` | The posterior distribution for which we want to compute the leave-one-out log predictive probability. | required |
| `train_data` | `Dataset` | The training dataset used to compute the leave-one-out log predictive probability. | required |
Returns
ScalarFloat: The leave-one-out log predictive probability of the Gaussian
process for the current parameter set.
LogPosteriorDensity
Bases: AbstractObjective
The log-posterior density of a non-conjugate Gaussian process. This is sometimes referred to as the marginal log-likelihood.
negative: bool = static_field(False) (class attribute, instance attribute)
constant: ScalarFloat = static_field(init=False, repr=False) (class attribute, instance attribute)
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__post_init__() -> None
__hash__()
__call__(*args, **kwargs) -> ScalarFloat
step(posterior: NonConjugatePosterior, data: Dataset) -> ScalarFloat
Evaluate the log-posterior density of a Gaussian process.
Compute the marginal log-likelihood, or log-posterior density, of the Gaussian process. The returned function can then be used for gradient-based optimisation of the model's parameters or for model comparison. The implementation given here is general and will work for any likelihood supported by GPJax.
Unlike the marginal_log_likelihood function of the ConjugatePosterior object, the marginal_log_likelihood function of the NonConjugatePosterior object does not provide an exact marginal log-likelihood. Instead, the NonConjugatePosterior object represents the posterior distribution as a function of the model's hyperparameters and the latent function. Markov chain Monte Carlo, variational inference, or Laplace approximations can then be used to sample from, or optimise an approximation to, the posterior distribution.
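As an illustration, a minimal sketch of evaluating this objective, assuming a Bernoulli likelihood and an illustrative binary toy dataset (all names below are hypothetical):

>>> import gpjax as gpx
>>> import jax.numpy as jnp
>>>
>>> # illustrative binary labels; any {0, 1}-valued targets would do
>>> x = jnp.linspace(-1.0, 1.0, 50).reshape(-1, 1)
>>> y = jnp.where(x > 0.0, 1.0, 0.0)
>>> D = gpx.Dataset(X=x, y=y)
>>>
>>> prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Constant(), kernel=gpx.kernels.RBF())
>>> likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n)
>>> posterior = prior * likelihood
>>>
>>> lpd = gpx.objectives.LogPosteriorDensity(negative=True)
>>> lpd(posterior, D)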
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `posterior` | `NonConjugatePosterior` | The posterior distribution for which we want to compute the marginal log-likelihood. | required |
| `data` | `Dataset` | The training dataset used to compute the marginal log-likelihood. | required |
Returns
ScalarFloat: The log-posterior density of the Gaussian process for the
current parameter set.
ELBO
Bases: AbstractObjective
negative: bool = static_field(False) (class attribute, instance attribute)
constant: ScalarFloat = static_field(init=False, repr=False) (class attribute, instance attribute)
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__post_init__() -> None
__hash__()
__call__(*args, **kwargs) -> ScalarFloat
step(variational_family: VariationalFamily, train_data: Dataset) -> ScalarFloat
Compute the evidence lower bound of a variational approximation.
Compute the evidence lower bound under this model. In short, this requires evaluating the expectation of the model's log-likelihood under the variational approximation. From this expectation, we subtract the KL divergence from the variational posterior to the prior. When mini-batching, the expected log-likelihood term is rescaled by the ratio of the full dataset size to the batch size.
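As a minimal sketch of evaluating this bound, assuming a posterior and dataset D constructed as in the earlier regression examples and a set of inducing inputs (the names z and q below are illustrative):

>>> import jax.numpy as jnp
>>>
>>> z = jnp.linspace(0.0, 1.0, 20).reshape(-1, 1)  # illustrative inducing inputs
>>> q = gpx.variational_families.VariationalGaussian(posterior=posterior, inducing_inputs=z)
>>>
>>> elbo = gpx.objectives.ELBO(negative=True)
>>> elbo(q, D)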
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `variational_family` | `AbstractVariationalFamily` | The variational approximation whose parameters the ELBO should be maximised with respect to. | required |
| `train_data` | `Dataset` | The training data used to evaluate the ELBO. | required |
Returns
ScalarFloat: The evidence lower bound of the variational approximation for
the current model parameter set.
CollapsedELBO
Bases: AbstractObjective
The collapsed evidence lower bound.
Collapsed variational inference for a sparse Gaussian process regression model. The key reference is Titsias (2009), Variational Learning of Inducing Variables in Sparse Gaussian Processes.
negative: bool = static_field(False) (class attribute, instance attribute)
constant: ScalarFloat = static_field(init=False, repr=False) (class attribute, instance attribute)
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__post_init__() -> None
__hash__()
__call__(*args, **kwargs) -> ScalarFloat
step(variational_family: VariationalFamily, train_data: Dataset) -> ScalarFloat
Compute a single step of the collapsed evidence lower bound.
Compute the evidence lower bound under this model. In short, this requires evaluating the expectation of the model's log-likelihood under the variational approximation. From this expectation, we subtract the KL divergence from the variational posterior to the prior. When mini-batching, the expected log-likelihood term is rescaled by the ratio of the full dataset size to the batch size.
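A minimal sketch, assuming the Gaussian-likelihood posterior, dataset D, and inducing inputs z from the sketch above (names are illustrative):

>>> collapsed_q = gpx.variational_families.CollapsedVariationalGaussian(posterior=posterior, inducing_inputs=z)
>>>
>>> collapsed_elbo = gpx.objectives.CollapsedELBO(negative=True)
>>> collapsed_elbo(collapsed_q, D)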
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `variational_family` | `AbstractVariationalFamily` | The variational approximation whose parameters the ELBO should be maximised with respect to. | required |
| `train_data` | `Dataset` | The training data used to evaluate the ELBO. | required |
Returns
ScalarFloat: The evidence lower bound of the variational approximation for
the current model parameter set.
variational_expectation(variational_family: VariationalFamily, train_data: Dataset) -> Float[Array, ' N']
Compute the variational expectation.
Compute the expectation of our model's log-likelihood under our variational distribution. Batching can be done here to speed up computation.
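As a sketch of how batching might look from user code, assuming the dataset D and variational family q from the sketches above (the slice below is purely illustrative):

>>> batch = gpx.Dataset(X=D.X[:32], y=D.y[:32])  # one illustrative mini-batch
>>> elbo = gpx.objectives.ELBO(negative=True)
>>> elbo(q, batch)  # the expectation is computed on this batch only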
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `variational_family` | `AbstractVariationalFamily` | The variational family that we are using to approximate the posterior. | required |
| `train_data` | `Dataset` | The batch of data for which the expectation should be computed. | required |
Returns
Array: The expectation of the model's log-likelihood under our variational
distribution.