GPs
gpjax.gps
Kernel = TypeVar('Kernel', bound=AbstractKernel)
module-attribute
MeanFunction = TypeVar('MeanFunction', bound=AbstractMeanFunction)
module-attribute
Likelihood = TypeVar('Likelihood', bound=AbstractLikelihood)
module-attribute
NonGaussianLikelihood = TypeVar('NonGaussianLikelihood', bound=NonGaussian)
module-attribute
GaussianLikelihood = TypeVar('GaussianLikelihood', bound=Gaussian)
module-attribute
PriorType = TypeVar('PriorType', bound=AbstractPrior)
module-attribute
__all__ = ['AbstractPrior', 'Prior', 'AbstractPosterior', 'ConjugatePosterior', 'NonConjugatePosterior', 'construct_posterior']
module-attribute
AbstractPrior
dataclass
Bases: Module
, Generic[MeanFunction, Kernel]
Abstract Gaussian process prior.
kernel: Kernel
instance-attribute
mean_function: MeanFunction
instance-attribute
jitter: float = static_field(1e-06)
class-attribute
instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution
Evaluate the Gaussian process at the given points.
The output of this function is a TensorFlow Probability distribution from which the latent function's mean and covariance can be evaluated and from which samples can be drawn.
Under the hood, `__call__` calls the object's `predict` method. For this reason, classes inheriting from `AbstractPrior` should not override `__call__` and should instead define a `predict` method.
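For subclass authors, the contract looks roughly like the following sketch (illustrative only; `MyPrior` is a hypothetical class, not part of GPJax):
>>> import gpjax as gpx
>>> from dataclasses import dataclass
>>>
>>> @dataclass
... class MyPrior(gpx.gps.AbstractPrior):
...     def predict(self, *args, **kwargs):
...         # Build and return the latent GaussianDistribution here;
...         # the inherited __call__ will delegate to this method.
...         ...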
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Any` | The arguments to pass to the GP's `predict` method. | `()` |
| `**kwargs` | `Any` | The keyword arguments to pass to the GP's `predict` method. | `{}` |
Returns
GaussianDistribution: A multivariate normal random variable representation
of the Gaussian process.
predict(*args: Any, **kwargs: Any) -> GaussianDistribution
abstractmethod
Evaluate the predictive distribution.
Compute the latent function's multivariate normal distribution for a given set of parameters. Any class inheriting from `AbstractPrior` must implement this method.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Any` | Arguments to the predict method. | `()` |
| `**kwargs` | `Any` | Keyword arguments to the predict method. | `{}` |
Returns
GaussianDistribution: A multivariate normal random variable representation
of the Gaussian process.
Prior
dataclass
Bases: AbstractPrior[MeanFunction, Kernel]
A Gaussian process prior object.
The GP is parameterised by a mean function $m(\cdot)$ and a kernel function $k(\cdot, \cdot)$, giving the prior $p(f(\cdot)) = \mathcal{GP}(m(\cdot), k(\cdot, \cdot))$.
To invoke a `Prior` distribution, a kernel and mean function must be specified.
Example:
>>> import gpjax as gpx
>>> kernel = gpx.kernels.RBF()
>>> meanf = gpx.mean_functions.Zero()
>>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
kernel: Kernel
instance-attribute
mean_function: MeanFunction
instance-attribute
jitter: float = static_field(1e-06)
class-attribute
instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution
Evaluate the Gaussian process at the given points.
The output of this function is a TensorFlow Probability distribution from which the latent function's mean and covariance can be evaluated and from which samples can be drawn.
Under the hood, `__call__` calls the object's `predict` method. For this reason, classes inheriting from `AbstractPrior` should not override `__call__` and should instead define a `predict` method.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Any` | The arguments to pass to the GP's `predict` method. | `()` |
| `**kwargs` | `Any` | The keyword arguments to pass to the GP's `predict` method. | `{}` |
Returns
GaussianDistribution: A multivariate normal random variable representation
of the Gaussian process.
__mul__(other)
Combine the prior with a likelihood to form a posterior distribution.
The product of a prior and likelihood is proportional to the posterior distribution. By computing the product of a GP prior and a likelihood object, a posterior GP object will be returned. Mathematically, this can be described by $p(f(\cdot) \mid \mathbf{y}) \propto p(\mathbf{y} \mid f(\cdot))\, p(f(\cdot))$, where $p(\mathbf{y} \mid f(\cdot))$ is the likelihood and $p(f(\cdot))$ is the prior.
Example:
>>> import gpjax as gpx
>>>
>>> meanf = gpx.mean_functions.Zero()
>>> kernel = gpx.kernels.RBF()
>>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
>>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100)
>>>
>>> prior * likelihood
Returns
Posterior: The relevant GP posterior for the given prior and
likelihood. Special cases are accounted for where the model
is conjugate.
__rmul__(other)
Combine the prior with a likelihood to form a posterior distribution.
Reimplement the multiplication operator to allow for an order-invariant product of a likelihood and a prior, i.e., `likelihood * prior`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `other` | `Likelihood` | The likelihood distribution of the observed dataset. | *required* |
Returns
Posterior: The relevant GP posterior for the given prior and
likelihood. Special cases are accounted for where the model
is conjugate.
predict(test_inputs: Num[Array, 'N D']) -> GaussianDistribution
Compute the predictive prior distribution for a given set of parameters. The output of this function is a TFP distribution evaluated at the supplied test inputs.
In the following example, we compute the predictive prior distribution and then evaluate it on the interval $[0, 1]$:
Example:
>>> import gpjax as gpx
>>> import jax.numpy as jnp
>>>
>>> kernel = gpx.kernels.RBF()
>>> meanf = gpx.mean_functions.Zero()
>>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
>>>
>>> prior.predict(jnp.linspace(0, 1, 100).reshape(-1, 1))
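The returned object is a multivariate normal; a rough sketch of querying it (method names assume the TFP-style interface described above):
>>> import jax.random as jr
>>>
>>> dist = prior.predict(jnp.linspace(0, 1, 100).reshape(-1, 1))
>>> mean = dist.mean()
>>> cov = dist.covariance()
>>> samples = dist.sample(seed=jr.PRNGKey(0), sample_shape=(5,))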
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `test_inputs` | `Num[Array, 'N D']` | The inputs at which to evaluate the prior distribution. | *required* |
Returns
GaussianDistribution: A multivariate normal random variable representation
of the Gaussian process.
sample_approx(num_samples: int, key: KeyArray, num_features: Optional[int] = 100) -> FunctionalSample
Approximate samples from the Gaussian process prior.
Build an approximate sample from the Gaussian process prior. This method provides a function that returns the evaluations of a sample across any given inputs.
In particular, we approximate the Gaussian process prior as the finite feature approximation $\hat{f}(x) = \sum_{i=1}^{m} \phi_i(x) \theta_i$, where $\phi_i$ are $m$ features sampled from the Fourier feature decomposition of the model's kernel and $\theta_i \sim \mathcal{N}(0, 1)$ are samples from a unit Gaussian.
A key property of such functional samples is that the same sample draw is evaluated for all queries. Consistency is a property that is prohibitively costly to ensure when sampling exactly from the GP prior, as the cost of exact sampling scales cubically with the size of the sample. In contrast, finite feature representations can be evaluated with constant cost regardless of the required number of queries.
In the following example, we build 10 such samples for a `prior` distribution and then evaluate them over the interval $[0, 1]$:
Example:
>>> import gpjax as gpx
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> key = jr.PRNGKey(123)
>>>
>>> meanf = gpx.mean_functions.Zero()
>>> kernel = gpx.kernels.RBF()
>>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
>>>
>>> sample_fn = prior.sample_approx(10, key)
>>> sample_fn(jnp.linspace(0, 1, 100).reshape(-1, 1))
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `num_samples` | `int` | The desired number of samples. | *required* |
| `key` | `KeyArray` | The random seed used for the sample(s). | *required* |
| `num_features` | `int` | The number of features used when approximating the kernel. | `100` |
Returns
FunctionalSample: A function representing an approximate sample from the
Gaussian process prior.
AbstractPosterior
dataclass
Bases: Module
, Generic[PriorType, Likelihood]
Abstract Gaussian process posterior.
The base GP posterior object conditioned on an observed dataset. All posterior objects should inherit from this class.
prior: AbstractPrior[MeanFunction, Kernel]
instance-attribute
likelihood: Likelihood
instance-attribute
jitter: float = static_field(1e-06)
class-attribute
instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution
Evaluate the Gaussian process posterior at the given points.
The output of this function is a TFP distribution from which the latent function's mean and covariance can be evaluated and from which samples can be drawn.
Under the hood, `__call__` calls the object's `predict` method. For this reason, classes inheriting from `AbstractPosterior` should not override `__call__` and should instead define a `predict` method.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Any` | The arguments to pass to the GP's `predict` method. | `()` |
| `**kwargs` | `Any` | The keyword arguments to pass to the GP's `predict` method. | `{}` |
Returns
GaussianDistribution: A multivariate normal random variable representation
of the Gaussian process.
predict(*args: Any, **kwargs: Any) -> GaussianDistribution
abstractmethod
Compute the latent function's multivariate normal distribution for a given set of parameters. Any class inheriting from `AbstractPosterior` must implement this method.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Any` | Arguments to the predict method. | `()` |
| `**kwargs` | `Any` | Keyword arguments to the predict method. | `{}` |
Returns
GaussianDistribution: A multivariate normal random variable representation
of the Gaussian process.
ConjugatePosterior
dataclass
Bases: AbstractPosterior[PriorType, GaussianLikelihood]
A conjugate Gaussian process posterior object.
A Gaussian process posterior distribution when the constituent likelihood function is a Gaussian distribution. In such cases, the latent function values can be analytically integrated out of the posterior distribution. As such, many computational operations can be simplified; something we make use of in this object.
For a Gaussian process prior $p(\mathbf{f})$ and a Gaussian likelihood $p(\mathbf{y} \mid \mathbf{f}) = \mathcal{N}(\mathbf{y} \mid \mathbf{f}, \sigma^2 \mathbf{I}_n)$ where $\mathbf{f} = f(\mathbf{x})$, the predictive posterior distribution at a set of inputs $\mathbf{x}^{\star}$ is given by

$$p(\mathbf{f}^{\star} \mid \mathbf{y}) = \mathcal{N}\big(\mathbf{f}^{\star};\, \boldsymbol{\mu}_{\mid \mathbf{y}},\, \boldsymbol{\Sigma}_{\mid \mathbf{y}}\big),$$

where

$$\begin{aligned} \boldsymbol{\mu}_{\mid \mathbf{y}} &= k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}') + \sigma^2 \mathbf{I}_n\right)^{-1}\mathbf{y}, \\ \boldsymbol{\Sigma}_{\mid \mathbf{y}} &= k(\mathbf{x}^{\star}, \mathbf{x}^{\star\prime}) - k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}') + \sigma^2 \mathbf{I}_n\right)^{-1} k(\mathbf{x}, \mathbf{x}^{\star}). \end{aligned}$$
Example
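A conjugate posterior is constructed by pairing a `Prior` with a `Gaussian` likelihood (an illustrative snippet, consistent with the `__mul__` example above):
>>> import gpjax as gpx
>>>
>>> prior = gpx.gps.Prior(
...     mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
... )
>>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100)
>>>
>>> posterior = prior * likelihood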
prior: AbstractPrior[MeanFunction, Kernel]
instance-attribute
likelihood: Likelihood
instance-attribute
jitter: float = static_field(1e-06)
class-attribute
instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution
Evaluate the Gaussian process posterior at the given points.
The output of this function is a TFP distribution from which the latent function's mean and covariance can be evaluated and from which samples can be drawn.
Under the hood, `__call__` calls the object's `predict` method. For this reason, classes inheriting from `AbstractPosterior` should not override `__call__` and should instead define a `predict` method.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Any` | The arguments to pass to the GP's `predict` method. | `()` |
| `**kwargs` | `Any` | The keyword arguments to pass to the GP's `predict` method. | `{}` |
Returns
GaussianDistribution: A multivariate normal random variable representation
of the Gaussian process.
predict(test_inputs: Num[Array, 'N D'], train_data: Dataset) -> GaussianDistribution
Query the predictive posterior distribution.
Conditional on a training data set, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the corresponding predictive density.
The predictive distribution of a conjugate GP is given by

$$\begin{aligned} p(\mathbf{f}^{\star} \mid \mathbf{y}) &= \int p(\mathbf{f}^{\star}, \mathbf{f} \mid \mathbf{y})\, \mathrm{d}\mathbf{f} \\ &= \mathcal{N}\big(\mathbf{f}^{\star};\, \boldsymbol{\mu}_{\mid \mathbf{y}},\, \boldsymbol{\Sigma}_{\mid \mathbf{y}}\big), \end{aligned}$$

where

$$\begin{aligned} \boldsymbol{\mu}_{\mid \mathbf{y}} &= k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}') + \sigma^2 \mathbf{I}_n\right)^{-1}\mathbf{y}, \\ \boldsymbol{\Sigma}_{\mid \mathbf{y}} &= k(\mathbf{x}^{\star}, \mathbf{x}^{\star\prime}) - k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}') + \sigma^2 \mathbf{I}_n\right)^{-1} k(\mathbf{x}, \mathbf{x}^{\star}). \end{aligned}$$
The conditioning set is a GPJax `Dataset` object, whilst predictions are made on a regular JAX array.
Example
For a `posterior` distribution, the following code snippet will evaluate the predictive distribution.
>>> import gpjax as gpx
>>> import jax.numpy as jnp
>>>
>>> xtrain = jnp.linspace(0, 1).reshape(-1, 1)
>>> ytrain = jnp.sin(xtrain)
>>> D = gpx.Dataset(X=xtrain, y=ytrain)
>>> xtest = jnp.linspace(0, 1).reshape(-1, 1)
>>>
>>> prior = gpx.gps.Prior(mean_function = gpx.mean_functions.Zero(), kernel = gpx.kernels.RBF())
>>> posterior = prior * gpx.likelihoods.Gaussian(num_datapoints = D.n)
>>> predictive_dist = posterior(xtest, D)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `test_inputs` | `Num[Array, 'N D']` | A JAX array of test inputs at which the predictive distribution is evaluated. | *required* |
| `train_data` | `Dataset` | A `Dataset` object containing the observed inputs and outputs used for conditioning. | *required* |
Returns
GaussianDistribution: The predictive posterior distribution evaluated at
the test inputs, as a `GaussianDistribution`.
sample_approx(num_samples: int, train_data: Dataset, key: KeyArray, num_features: Optional[int] = 100) -> FunctionalSample
Draw approximate samples from the Gaussian process posterior.
Build an approximate sample from the Gaussian process posterior. This method provides a function that returns the evaluations of a sample across any given inputs.
Unlike when building approximate samples from a Gaussian process prior, decompositions based on Fourier features alone rarely give accurate samples. Therefore, we must also include an additional set of features (known as canonical features) to better model the transition from Gaussian process prior to Gaussian process posterior. For more details see Wilson et al. (2020).
In particular, we approximate the Gaussian process posterior as the finite feature approximation $\hat{f}(x) = \sum_{i=1}^{m} \phi_i(x) \theta_i + \sum_{j=1}^{N} v_j k(x, x_j)$, where $\phi_i$ are $m$ features sampled from the Fourier feature decomposition of the model's kernel, $k(\cdot, x_j)$ are $N$ canonical features, and the Fourier weights $\theta_i$ are samples from a unit Gaussian. See Wilson et al. (2020) for expressions for the canonical weights $v_j$.
A key property of such functional samples is that the same sample draw is evaluated for all queries. Consistency is a property that is prohibitively costly to ensure when sampling exactly from the GP prior, as the cost of exact sampling scales cubically with the size of the sample. In contrast, finite feature representations can be evaluated with constant cost regardless of the required number of queries.
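Example:
For a `posterior` built as in the `predict` example above, the following illustrative snippet builds and evaluates an approximate posterior sample (the toy dataset is an assumption for demonstration):
>>> import gpjax as gpx
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> key = jr.PRNGKey(123)
>>>
>>> xtrain = jnp.linspace(0, 1).reshape(-1, 1)
>>> ytrain = jnp.sin(xtrain)
>>> D = gpx.Dataset(X=xtrain, y=ytrain)
>>>
>>> prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
>>> posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)
>>>
>>> sample_fn = posterior.sample_approx(10, D, key)
>>> sample_fn(jnp.linspace(0, 1, 100).reshape(-1, 1))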
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `num_samples` | `int` | The desired number of samples. | *required* |
| `train_data` | `Dataset` | The training dataset used to condition the posterior. | *required* |
| `key` | `KeyArray` | The random seed used for the sample(s). | *required* |
| `num_features` | `int` | The number of features used when approximating the kernel. | `100` |
Returns
FunctionalSample: A function representing an approximate sample from the
Gaussian process posterior.
NonConjugatePosterior
dataclass
Bases: AbstractPosterior[PriorType, NonGaussianLikelihood]
A non-conjugate Gaussian process posterior object.
A Gaussian process posterior object for models where the likelihood is non-Gaussian. Unlike the `ConjugatePosterior` object, the `NonConjugatePosterior` object does not provide an exact marginal log-likelihood function. Instead, the `NonConjugatePosterior` object represents the posterior distribution as a function of the model's hyperparameters and the latent function. Markov chain Monte Carlo, variational inference, or Laplace approximations can then be used to sample from, or optimise an approximation to, the posterior distribution.
prior: AbstractPrior[MeanFunction, Kernel]
instance-attribute
likelihood: Likelihood
instance-attribute
jitter: float = static_field(1e-06)
class-attribute
instance-attribute
latent: Union[Float[Array, 'N 1'], None] = param_field(None)
class-attribute
instance-attribute
key: KeyArray = static_field(PRNGKey(42))
class-attribute
instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution
Evaluate the Gaussian process posterior at the given points.
The output of this function is a TFP distribution from which the latent function's mean and covariance can be evaluated and from which samples can be drawn.
Under the hood, `__call__` calls the object's `predict` method. For this reason, classes inheriting from `AbstractPosterior` should not override `__call__` and should instead define a `predict` method.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Any` | The arguments to pass to the GP's `predict` method. | `()` |
| `**kwargs` | `Any` | The keyword arguments to pass to the GP's `predict` method. | `{}` |
Returns
GaussianDistribution: A multivariate normal random variable representation
of the Gaussian process.
__post_init__()
predict(test_inputs: Num[Array, 'N D'], train_data: Dataset) -> GaussianDistribution
Query the predictive posterior distribution.
Conditional on a set of training data, compute the GP's posterior predictive distribution at a given set of test inputs. Note that to obtain predictions on the scale of the original data, the returned distribution must be transformed through the likelihood function's inverse link function.
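An illustrative sketch (the `Bernoulli` likelihood and the inverse-link transformation via calling the likelihood on the latent distribution are assumptions for demonstration; this mirrors the conjugate example above):
>>> import gpjax as gpx
>>> import jax.numpy as jnp
>>>
>>> xtrain = jnp.linspace(0, 1).reshape(-1, 1)
>>> ytrain = jnp.where(jnp.sin(10 * xtrain) > 0.0, 1.0, 0.0)
>>> D = gpx.Dataset(X=xtrain, y=ytrain)
>>> xtest = jnp.linspace(0, 1).reshape(-1, 1)
>>>
>>> prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
>>> posterior = prior * gpx.likelihoods.Bernoulli(num_datapoints=D.n)
>>>
>>> latent_dist = posterior.predict(xtest, train_data=D)
>>> predictive_dist = posterior.likelihood(latent_dist)  # transform through the inverse link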
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `test_inputs` | `Num[Array, 'N D']` | A JAX array of test inputs at which the predictive distribution is evaluated. | *required* |
| `train_data` | `Dataset` | A `Dataset` object containing the observed inputs and outputs used for conditioning. | *required* |
Returns
GaussianDistribution: The latent function's predictive distribution
evaluated at the test inputs, as a `GaussianDistribution`.
construct_posterior(prior, likelihood)
Utility function for constructing a posterior object from a prior and likelihood. The function will automatically select the correct posterior object based on the likelihood.
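For instance (an illustrative snippet; mirrors the constructions shown earlier on this page):
>>> import gpjax as gpx
>>>
>>> prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
>>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100)
>>>
>>> posterior = gpx.gps.construct_posterior(prior, likelihood)  # a ConjugatePosterior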
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `prior` | `Prior` | The `Prior` distribution. | *required* |
| `likelihood` | `AbstractLikelihood` | The likelihood that represents our beliefs around the distribution of the data. | *required* |
Returns
AbstractPosterior: A posterior distribution. If the likelihood is
Gaussian, then a `ConjugatePosterior` will be returned. Otherwise,
a `NonConjugatePosterior` will be returned.