
GPs

gpjax.gps

Kernel = TypeVar('Kernel', bound=AbstractKernel) module-attribute
MeanFunction = TypeVar('MeanFunction', bound=AbstractMeanFunction) module-attribute
Likelihood = TypeVar('Likelihood', bound=AbstractLikelihood) module-attribute
NonGaussianLikelihood = TypeVar('NonGaussianLikelihood', bound=NonGaussian) module-attribute
GaussianLikelihood = TypeVar('GaussianLikelihood', bound=Gaussian) module-attribute
PriorType = TypeVar('PriorType', bound=AbstractPrior) module-attribute
__all__ = ['AbstractPrior', 'Prior', 'AbstractPosterior', 'ConjugatePosterior', 'NonConjugatePosterior', 'construct_posterior'] module-attribute
AbstractPrior dataclass

Bases: Module, Generic[MeanFunction, Kernel]

Abstract Gaussian process prior.

kernel: Kernel instance-attribute
mean_function: MeanFunction instance-attribute
jitter: float = static_field(1e-06) class-attribute instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self

Replace the values of the fields of the object.

Parameters:
    **kwargs (Any): Keyword arguments to replace the fields of the object.

Returns
Module: with the fields replaced.
replace_meta(**kwargs: Any) -> Self

Replace the metadata of the fields.

Parameters:
    **kwargs (Any): Keyword arguments to replace the metadata of the fields of the object.

Returns
Module: with the metadata of the fields replaced.
update_meta(**kwargs: Any) -> Self

Update the metadata of the fields. The metadata must already exist.

Parameters:
    **kwargs (Any): Keyword arguments mapping field names to the metadata entries to update.

Returns
Module: with the metadata of the fields updated.
replace_trainable(**kwargs: Dict[str, bool]) -> Self

Replace the trainability status of local nodes of the Module.

replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self

Replace the bijectors of local nodes of the Module.

constrain() -> Self

Transform model parameters to the constrained space according to their defined bijectors.

Returns
Module: transformed to the constrained space.
unconstrain() -> Self

Transform model parameters to the unconstrained space according to their defined bijectors.

Returns
Module: transformed to the unconstrained space.
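As a rough illustration of what constrain and unconstrain do, the sketch below round-trips two positive parameters through a softplus bijector written by hand in JAX. It assumes softplus as the positivity bijector (a common choice; in GPJax each parameter carries its own TensorFlow Probability bijector), and the parameter names and the helpers softplus and inverse_softplus are illustrative, not part of GPJax.

```python
import jax.numpy as jnp


# Hypothetical stand-in for a positivity bijector (softplus is assumed here).
def softplus(x):
    # Maps the real line to (0, inf): moves a parameter to the constrained space.
    return jnp.logaddexp(x, 0.0)


def inverse_softplus(y):
    # Maps (0, inf) back to the real line: log(exp(y) - 1), written stably.
    return y + jnp.log(-jnp.expm1(-y))


params = {"lengthscale": jnp.array(0.5), "variance": jnp.array(2.0)}

# unconstrain: positive values -> unconstrained reals (where an optimiser works)
unconstrained = {k: inverse_softplus(v) for k, v in params.items()}
# constrain: unconstrained reals -> positive values used by the model
constrained = {k: softplus(v) for k, v in unconstrained.items()}
```

An optimiser can then move the unconstrained values anywhere on the real line, while the model only ever sees positive ones.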
stop_gradient() -> Self

Stop gradients flowing through the Module.

Returns
Module: with gradients stopped.
trainables() -> Self
__init__(kernel: Kernel, mean_function: MeanFunction, jitter: float = static_field(1e-06)) -> None
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution

Evaluate the Gaussian process at the given points.

The output of this function is a TensorFlow Probability distribution from which the latent function's mean and covariance can be evaluated and from which samples can be drawn.

Under the hood, __call__ calls the object's predict method. For this reason, classes inheriting from AbstractPrior should not override the __call__ method and should instead define a predict method.

Parameters:
    *args (Any): The arguments to pass to the GP's predict method.
    **kwargs (Any): The keyword arguments to pass to the GP's predict method.

Returns
GaussianDistribution: A multivariate normal random variable representation
    of the Gaussian process.
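The __call__/predict split is a template-method pattern. The plain-Python sketch below mirrors the contract described above using hypothetical class names (AbstractPriorSketch, ConstantPriorSketch); it is not GPJax code and returns a dictionary rather than a GaussianDistribution.

```python
from abc import ABC, abstractmethod
from typing import Any


class AbstractPriorSketch(ABC):
    """Hypothetical, simplified stand-in for AbstractPrior's call/predict split."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # __call__ is defined once in the base class and simply forwards to predict.
        return self.predict(*args, **kwargs)

    @abstractmethod
    def predict(self, *args: Any, **kwargs: Any) -> Any:
        """Subclasses implement their predictive distribution here."""


class ConstantPriorSketch(AbstractPriorSketch):
    # Hypothetical subclass: it overrides predict only, never __call__.
    def predict(self, test_inputs):
        return {"mean": [0.0 for _ in test_inputs]}


prior = ConstantPriorSketch()
result = prior([1.0, 2.0])  # dispatched to ConstantPriorSketch.predict
```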
predict(*args: Any, **kwargs: Any) -> GaussianDistribution abstractmethod

Evaluate the predictive distribution.

Compute the latent function's multivariate normal distribution for a given set of parameters. Any class inheriting from AbstractPrior must implement this method.

Parameters:
    *args (Any): Arguments to the predict method.
    **kwargs (Any): Keyword arguments to the predict method.

Returns
GaussianDistribution: A multivariate normal random variable representation
    of the Gaussian process.
Prior dataclass

Bases: AbstractPrior[MeanFunction, Kernel]

A Gaussian process prior object.

The GP is parameterised by a mean and kernel function.

A Gaussian process prior parameterised by a mean function $m(\cdot)$ and a kernel function $k(\cdot, \cdot)$ is given by $p(f(\cdot)) = \mathcal{GP}(m(\cdot), k(\cdot, \cdot))$.

To instantiate a Prior, both a kernel and a mean function must be specified.

Example:

    >>> import gpjax as gpx
    >>>
    >>> kernel = gpx.kernels.RBF()
    >>> meanf = gpx.mean_functions.Zero()
    >>> prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)

kernel: Kernel instance-attribute
mean_function: MeanFunction instance-attribute
jitter: float = static_field(1e-06) class-attribute instance-attribute
__init_subclass__, replace, replace_meta, update_meta, replace_trainable, replace_bijector, constrain, unconstrain, stop_gradient, trainables and __call__ are inherited unchanged and are documented under AbstractPrior.
__init__(kernel: Kernel, mean_function: MeanFunction, jitter: float = static_field(1e-06)) -> None
__mul__(other)

Combine the prior with a likelihood to form a posterior distribution.

The product of a prior and likelihood is proportional to the posterior distribution. By computing the product of a GP prior and a likelihood object, a posterior GP object will be returned. Mathematically, this can be described by:

$$p(f(\cdot) \mid y) \propto p(y \mid f(\cdot))\, p(f(\cdot)),$$

where $p(y \mid f(\cdot))$ is the likelihood and $p(f(\cdot))$ is the prior.

Example:

    >>> import gpjax as gpx
    >>>
    >>> meanf = gpx.mean_functions.Zero()
    >>> kernel = gpx.kernels.RBF()
    >>> prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
    >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100)
    >>>
    >>> prior * likelihood

Parameters:
    other (Likelihood): The likelihood distribution of the observed dataset.

Returns
Posterior: The relevant GP posterior for the given prior and
    likelihood. Special cases are accounted for where the model
    is conjugate.
__rmul__(other)

Combine the prior with a likelihood to form a posterior distribution.

Reimplements the multiplication operator so that the product is order-invariant, i.e. likelihood * prior returns the same posterior as prior * likelihood.

Parameters:
    other (Likelihood): The likelihood distribution of the observed dataset.

Returns
Posterior: The relevant GP posterior for the given prior and
    likelihood. Special cases are accounted for where the model
    is conjugate.
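A minimal plain-Python sketch of how __mul__ and __rmul__ together make the product order-invariant. The classes are hypothetical placeholders, and the sketch assumes the likelihood class does not define its own __mul__, so that Python falls back to the prior's __rmul__ when evaluating likelihood * prior.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LikelihoodSketch:
    # Hypothetical likelihood placeholder; note it defines no __mul__.
    name: str


@dataclass(frozen=True)
class PosteriorSketch:
    # Hypothetical posterior placeholder holding both factors.
    prior: "PriorSketch"
    likelihood: LikelihoodSketch


@dataclass(frozen=True)
class PriorSketch:
    name: str

    def __mul__(self, other: LikelihoodSketch) -> PosteriorSketch:
        # Handles prior * likelihood.
        return PosteriorSketch(prior=self, likelihood=other)

    def __rmul__(self, other: LikelihoodSketch) -> PosteriorSketch:
        # Handles likelihood * prior: Python calls this because the left operand
        # has no __mul__, so we delegate to __mul__ to get the same result.
        return self.__mul__(other)


prior = PriorSketch("gp")
lik = LikelihoodSketch("gaussian")
same = (prior * lik) == (lik * prior)
```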
predict(test_inputs: Num[Array, 'N D']) -> GaussianDistribution

Compute the predictive prior distribution for a given set of parameters. The output of this function is a GaussianDistribution over the latent function values at test_inputs.

In the following example, we compute the predictive prior distribution at 100 evenly spaced points in the interval $[0, 1]$:

Example:

    >>> import gpjax as gpx
    >>> import jax.numpy as jnp
    >>>
    >>> kernel = gpx.kernels.RBF()
    >>> meanf = gpx.mean_functions.Zero()
    >>> prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
    >>>
    >>> prior.predict(jnp.linspace(0, 1, 100).reshape(-1, 1))

Parameters:
    test_inputs (Num[Array, 'N D']): The inputs at which to evaluate the prior distribution.

Returns
GaussianDistribution: A multivariate normal random variable representation
    of the Gaussian process.
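To make concrete what predict returns for a zero-mean prior, the JAX sketch below builds the mean vector and the jittered Gram matrix of an RBF kernel at 100 inputs, then draws one exact sample through a Cholesky factor. The unit lengthscale and variance, the helper names, and the use of 64-bit precision (a jitter of 1e-6 is generally too small to stabilise a float32 Cholesky of a smooth kernel) are assumptions of the sketch, not GPJax internals.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# 64-bit precision so that a jitter of 1e-6 keeps the Cholesky stable.
jax.config.update("jax_enable_x64", True)


def rbf_gram(x1, x2, lengthscale=1.0, variance=1.0):
    # k(x, x') = variance * exp(-||x - x'||^2 / (2 * lengthscale^2)) for all row pairs.
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def prior_predict(test_inputs, jitter=1e-6):
    """Mean vector and covariance matrix of a zero-mean RBF prior at test_inputs (N x D)."""
    n = test_inputs.shape[0]
    mean = jnp.zeros(n)  # zero mean function: m(x) = 0
    # Gram matrix with jitter added to the diagonal for numerical stability.
    cov = rbf_gram(test_inputs, test_inputs) + jitter * jnp.eye(n)
    return mean, cov


x = jnp.linspace(0.0, 1.0, 100).reshape(-1, 1)  # N x D inputs, as predict expects
mean, cov = prior_predict(x)

# One exact draw f = m + L z with L = chol(K) and z ~ N(0, I); this costs O(N^3).
L = jnp.linalg.cholesky(cov)
sample = mean + L @ jr.normal(jr.PRNGKey(0), (x.shape[0],))
```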
sample_approx(num_samples: int, key: KeyArray, num_features: Optional[int] = 100) -> FunctionalSample

Approximate samples from the Gaussian process prior.

Build an approximate sample from the Gaussian process prior. This method provides a function that returns the evaluations of a sample across any given inputs.

In particular, we approximate the Gaussian process prior by the finite feature approximation $\hat{f}(x) = \sum_{i=1}^m \phi_i(x)\theta_i$, where the $\phi_i$ are $m$ features sampled from the Fourier feature decomposition of the model's kernel and the $\theta_i$ are samples from a unit Gaussian.

A key property of such functional samples is that the same sample draw is evaluated for all queries. This consistency is prohibitively costly to ensure when sampling exactly from the GP prior, as the cost of exact sampling scales cubically with the number of points at which the sample is evaluated. In contrast, a finite feature representation can be evaluated at a cost that grows only linearly with the number of queries.

In the following example, we build 10 such samples from a prior distribution and then evaluate them over the interval $[0, 1]$:

Example:

    >>> import gpjax as gpx
    >>> import jax.numpy as jnp
    >>> import jax.random as jr
    >>> key = jr.key(123)
    >>>
    >>> meanf = gpx.mean_functions.Zero()
    >>> kernel = gpx.kernels.RBF()
    >>> prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
    >>>
    >>> sample_fn = prior.sample_approx(10, key)
    >>> sample_fn(jnp.linspace(0, 1, 100).reshape(-1, 1))

Parameters:
    num_samples (int): The desired number of samples.
    key (KeyArray): The random key used to draw the sample(s).
    num_features (int): The number of features used when approximating the kernel. Defaults to 100.

Returns
FunctionalSample: A function representing an approximate sample from the
    Gaussian process prior.
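The finite feature approximation above can be sketched directly in JAX for an RBF kernel using random Fourier features. This is an illustrative implementation under stated assumptions (an RBF kernel, a hypothetical helper named rff_prior_sample, and features of the form $\sqrt{2\sigma^2/m}\cos(\omega_i^\top x + b_i)$), not GPJax's sample_approx; it returns one sample function plus the feature map so the approximation can be inspected.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

jax.config.update("jax_enable_x64", True)


def rff_prior_sample(key, num_features=100, lengthscale=1.0, variance=1.0, dim=1):
    """Return one approximate RBF-prior draw as a function, via random Fourier features."""
    k_omega, k_bias, k_theta = jr.split(key, 3)
    # Spectral frequencies of the RBF kernel: omega ~ N(0, I / lengthscale^2).
    omega = jr.normal(k_omega, (dim, num_features)) / lengthscale
    # Random phases b ~ Uniform(0, 2 * pi).
    bias = jr.uniform(k_bias, (num_features,), maxval=2 * jnp.pi)
    # Weights theta_i ~ N(0, 1), drawn once so every query sees the same function.
    theta = jr.normal(k_theta, (num_features,))

    def features(x):
        # phi_i(x) = sqrt(2 * variance / m) * cos(omega_i^T x + b_i); x has shape (N, D).
        return jnp.sqrt(2.0 * variance / num_features) * jnp.cos(x @ omega + bias)

    def sample_fn(x):
        # f_hat(x) = sum_i phi_i(x) theta_i; cost O(N m), linear in the number of queries.
        return features(x) @ theta

    return sample_fn, features


sample_fn, features = rff_prior_sample(jr.PRNGKey(123))
x = jnp.linspace(0.0, 1.0, 100).reshape(-1, 1)
f_x = sample_fn(x)
```

Because omega, bias and theta are drawn once, calling sample_fn again on any inputs evaluates the same function, which is the consistency property described above.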
AbstractPosterior dataclass

Bases: Module, Generic[PriorType, Likelihood]

Abstract Gaussian process posterior.

The base GP posterior object conditioned on an observed dataset. All posterior objects should inherit from this class.

prior: AbstractPrior[MeanFunction, Kernel] instance-attribute
likelihood: Likelihood instance-attribute
jitter: float = static_field(1e-06) class-attribute instance-attribute
__init_subclass__, replace, replace_meta, update_meta, replace_trainable, replace_bijector, constrain, unconstrain, stop_gradient and trainables are inherited from Module and behave exactly as documented under AbstractPrior.
__init__(prior: AbstractPrior[MeanFunction, Kernel], likelihood: Likelihood, jitter: float = static_field(1e-06)) -> None
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution

Evaluate the Gaussian process posterior at the given points.

The output of this function is a TFP distribution from which the latent function's mean and covariance can be evaluated and from which samples can be drawn.

Under the hood, __call__ calls the object's predict method. For this reason, classes inheriting from AbstractPosterior should not override the __call__ method and should instead define a predict method.

Parameters:
    *args (Any): The arguments to pass to the GP's predict method.
    **kwargs (Any): The keyword arguments to pass to the GP's predict method.

Returns
GaussianDistribution: A multivariate normal random variable representation
    of the Gaussian process.
predict(*args: Any, **kwargs: Any) -> GaussianDistribution abstractmethod

Compute the latent function's multivariate normal distribution for a given set of parameters. Any class inheriting from AbstractPosterior must implement this method.

Parameters:
    *args (Any): Arguments to the predict method.
    **kwargs (Any): Keyword arguments to the predict method.

Returns
GaussianDistribution: A multivariate normal random variable representation
    of the Gaussian process.
ConjugatePosterior dataclass

Bases: AbstractPosterior[PriorType, GaussianLikelihood]

A conjugate Gaussian process posterior object.

A Gaussian process posterior distribution for the case where the constituent likelihood function is a Gaussian distribution. In such cases, the latent function values $\mathbf{f}$ can be analytically integrated out of the posterior distribution. As a result, many computations simplify, which this object exploits.

For a Gaussian process prior $p(\mathbf{f})$ and a Gaussian likelihood $p(\mathbf{y} \mid \mathbf{f}) = \mathcal{N}(\mathbf{y} \mid \mathbf{f}, \sigma^2\mathbf{I}_n)$, where $\mathbf{f} = f(\mathbf{x})$, the predictive posterior distribution at a set of test inputs $\mathbf{x}^{\star}$ is given by

$$
\begin{aligned}
p(\mathbf{f}^{\star}\mid \mathbf{y}) & = \int p(\mathbf{f}^{\star}, \mathbf{f} \mid \mathbf{y})\, \mathrm{d}\mathbf{f} \\
& = \mathcal{N}(\mathbf{f}^{\star} \mid \boldsymbol{\mu}_{\mid \mathbf{y}}, \boldsymbol{\Sigma}_{\mid \mathbf{y}}),
\end{aligned}
$$

where, for a zero mean function,

$$
\begin{aligned}
\boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}')+\sigma^2\mathbf{I}_n\right)^{-1}\mathbf{y}, \\
\boldsymbol{\Sigma}_{\mid \mathbf{y}} & = k(\mathbf{x}^{\star}, \mathbf{x}^{\star\prime}) - k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_n\right)^{-1}k(\mathbf{x}, \mathbf{x}^{\star}).
\end{aligned}
$$

Example:

    >>> import gpjax as gpx
    >>>
    >>> prior = gpx.gps.Prior(
    ...     mean_function=gpx.mean_functions.Zero(),
    ...     kernel=gpx.kernels.RBF(),
    ... )
    >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100)
    >>>
    >>> posterior = prior * likelihood
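The two predictive equations for the conjugate case translate almost line by line into JAX. The sketch below assumes a zero mean function and an RBF kernel with unit lengthscale and variance, reuses a single Cholesky factorisation of $k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_n$ for both solves, and uses hypothetical helper names; it illustrates the formulas rather than reproducing ConjugatePosterior.predict.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

jax.config.update("jax_enable_x64", True)


def rbf(x1, x2, lengthscale=1.0, variance=1.0):
    # Squared-exponential kernel evaluated between every row of x1 and x2.
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def conjugate_predict(x_test, x_train, y_train, noise_variance=0.1):
    """Posterior mean and covariance of f* for a zero-mean GP with Gaussian noise."""
    n = x_train.shape[0]
    # Factorise k(x, x') + sigma^2 I once and reuse it for both solves.
    chol = cho_factor(rbf(x_train, x_train) + noise_variance * jnp.eye(n), lower=True)
    k_star = rbf(x_test, x_train)  # k(x*, x), shape (N*, n)
    # mu = k(x*, x) (k(x, x') + sigma^2 I)^{-1} y
    mean = k_star @ cho_solve(chol, y_train)
    # Sigma = k(x*, x*') - k(x*, x) (k(x, x') + sigma^2 I)^{-1} k(x, x*)
    cov = rbf(x_test, x_test) - k_star @ cho_solve(chol, k_star.T)
    return mean, cov


x_train = jnp.linspace(0.0, 3.0, 4).reshape(-1, 1)
y_train = jnp.sin(x_train)  # shape (n, 1)
mean, cov = conjugate_predict(x_train, x_train, y_train, noise_variance=1e-6)
```

With a noise variance this small, the posterior mean interpolates the training targets and the posterior variance at those inputs is close to zero; far from the data, both revert to the prior.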
prior: AbstractPrior[MeanFunction, Kernel] instance-attribute
likelihood: Likelihood instance-attribute
jitter: float = static_field(1e-06) class-attribute instance-attribute
__init_subclass__, replace, replace_meta, update_meta, replace_trainable, replace_bijector, constrain, unconstrain, stop_gradient, trainables and __call__ are inherited unchanged and are documented under AbstractPrior and AbstractPosterior.
__init__(prior: AbstractPrior[MeanFunction, Kernel], likelihood: Likelihood, jitter: float = static_field(1e-06)) -> None
predict(test_inputs: Num[Array, 'N D'], train_data: Dataset) -> GaussianDistribution

Query the predictive posterior distribution.

Conditional on a training dataset, compute the GP's posterior predictive distribution at test_inputs for the current set of parameters.

The predictive distribution of a conjugate GP is given by

$$
\begin{aligned}
p(\mathbf{f}^{\star}\mid \mathbf{y}) & = \int p(\mathbf{f}^{\star}, \mathbf{f} \mid \mathbf{y})\, \mathrm{d}\mathbf{f} \\
& = \mathcal{N}(\mathbf{f}^{\star} \mid \boldsymbol{\mu}_{\mid \mathbf{y}}, \boldsymbol{\Sigma}_{\mid \mathbf{y}}),
\end{aligned}
$$

where, for a zero mean function,

$$
\begin{aligned}
\boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}')+\sigma^2\mathbf{I}_n\right)^{-1}\mathbf{y}, \\
\boldsymbol{\Sigma}_{\mid \mathbf{y}} & = k(\mathbf{x}^{\star}, \mathbf{x}^{\star\prime}) - k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_n\right)^{-1}k(\mathbf{x}, \mathbf{x}^{\star}).
\end{aligned}
$$

The conditioning set is a GPJax Dataset object, whilst predictions are made on a regular JAX array.

Example:

For a posterior distribution, the following code snippet will evaluate the predictive distribution.

    >>> import gpjax as gpx
    >>> import jax.numpy as jnp
    >>>
    >>> xtrain = jnp.linspace(0, 1).reshape(-1, 1)
    >>> ytrain = jnp.sin(xtrain)
    >>> D = gpx.Dataset(X=xtrain, y=ytrain)
    >>> xtest = jnp.linspace(0, 1).reshape(-1, 1)
    >>>
    >>> prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
    >>> posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)
    >>> predictive_dist = posterior(xtest, D)

Parameters:
    test_inputs (Num[Array, 'N D']): A JAX array of test inputs at which the predictive distribution is evaluated.
    train_data (Dataset): A gpx.Dataset object containing the input and output data used for training.

Returns
GaussianDistribution: The posterior predictive distribution of the latent
    function at test_inputs.
sample_approx(num_samples: int, train_data: Dataset, key: KeyArray, num_features: Optional[int] = 100) -> FunctionalSample

Draw approximate samples from the Gaussian process posterior.

Build an approximate sample from the Gaussian process posterior. This method provides a function that returns the evaluations of a sample across any given inputs.

Unlike when building approximate samples from a Gaussian process prior, decompositions based on Fourier features alone rarely give accurate posterior samples. Therefore, we also include an additional set of features (known as canonical features) to better model the transition from Gaussian process prior to Gaussian process posterior. For more details, see Wilson et al. (2020).

In particular, we approximate the Gaussian process posterior by the finite feature approximation $\hat{f}(x) = \sum_{i=1}^m \phi_i(x)\theta_i + \sum_{j=1}^N v_j k(\cdot, x_j)$, where the $\phi_i$ are $m$ features sampled from the Fourier feature decomposition of the model's kernel and the $k(\cdot, x_j)$ are $N$ canonical features. The Fourier weights $\theta_i$ are samples from a unit Gaussian. See Wilson et al. (2020) for expressions for the canonical weights $v_j$.

A key property of such functional samples is that the same sample draw is evaluated for all queries. This consistency is prohibitively costly to ensure when sampling exactly from the GP, as the cost of exact sampling scales cubically with the number of points at which the sample is evaluated. In contrast, a finite feature representation can be evaluated at a cost that grows only linearly with the number of queries.

Parameters:
    num_samples (int): The desired number of samples.
    train_data (Dataset): The training data on which the posterior is conditioned.
    key (KeyArray): The random key used to draw the sample(s).
    num_features (int): The number of features used when approximating the kernel. Defaults to 100.

Returns
FunctionalSample: A function representing an approximate sample from the Gaussian
    process posterior.
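Following the pathwise-conditioning (Matheron's rule) construction of Wilson et al. (2020), the JAX sketch below forms one approximate posterior draw as a Fourier-feature prior sample plus canonical features $k(\cdot, x_j)$ weighted by $\mathbf{v} = (k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_n)^{-1}(\mathbf{y} - \hat{f}(\mathbf{x}) - \boldsymbol{\varepsilon})$ with $\boldsymbol{\varepsilon} \sim \mathcal{N}(\mathbf{0}, \sigma^2\mathbf{I}_n)$. The unit-lengthscale RBF kernel, the noise variance and the function name pathwise_posterior_sample are assumptions for illustration; GPJax's sample_approx may differ in its details.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

jax.config.update("jax_enable_x64", True)


def rbf(x1, x2):
    # Unit-lengthscale, unit-variance squared-exponential kernel.
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq_dist)


def pathwise_posterior_sample(key, x_train, y_train, noise_variance=0.01, num_features=500):
    """One approximate posterior draw: Fourier-feature prior sample plus canonical features."""
    k_omega, k_bias, k_theta, k_eps = jr.split(key, 4)
    n, dim = x_train.shape
    # Random Fourier features for the kernel above (its spectral density is N(0, I)).
    omega = jr.normal(k_omega, (dim, num_features))
    bias = jr.uniform(k_bias, (num_features,), maxval=2 * jnp.pi)
    theta = jr.normal(k_theta, (num_features,))

    def prior_fn(x):
        # f_hat_prior(x) = sum_i phi_i(x) theta_i
        return jnp.sqrt(2.0 / num_features) * jnp.cos(x @ omega + bias) @ theta

    # Canonical weights v = (K + sigma^2 I)^{-1} (y - f_hat_prior(X) - eps), eps ~ N(0, sigma^2 I).
    eps = jnp.sqrt(noise_variance) * jr.normal(k_eps, (n,))
    gram = rbf(x_train, x_train) + noise_variance * jnp.eye(n)
    v = jnp.linalg.solve(gram, y_train - prior_fn(x_train) - eps)

    def sample_fn(x):
        # f_hat(x) = f_hat_prior(x) + sum_j v_j k(x, x_j)
        return prior_fn(x) + rbf(x, x_train) @ v

    return sample_fn


x_train = jnp.linspace(0.0, 4.5, 4).reshape(-1, 1)
y_train = jnp.sin(x_train).ravel()
sample_fn = pathwise_posterior_sample(jr.PRNGKey(0), x_train, y_train, noise_variance=1e-4)
```

With a small noise variance the draw passes close to the training targets, while away from the data it behaves like the Fourier-feature prior sample.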
NonConjugatePosterior dataclass

Bases: AbstractPosterior[PriorType, NonGaussianLikelihood]

A non-conjugate Gaussian process posterior object.

A Gaussian process posterior object for models where the likelihood is non-Gaussian. Unlike the ConjugatePosterior object, the NonConjugatePosterior object does not provide an exact marginal log-likelihood function. Instead, it represents the posterior distribution as a function of the model's hyperparameters and the latent function. Markov chain Monte Carlo, variational inference, or Laplace approximations can then be used to sample from, or optimise an approximation to, the posterior distribution.
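As a sketch of what a posterior "as a function of the hyperparameters and the latent function" can look like, the JAX code below writes the unnormalised log posterior $\log p(\mathbf{y} \mid \mathbf{f}) + \log p(\mathbf{f})$ for a Bernoulli likelihood with a probit link and takes one gradient-ascent step on $\mathbf{f}$, a first step towards the MAP estimate on which a Laplace approximation is centred. The probit link, the unwhitened latent vector and all names are assumptions; GPJax's own parameterisation of the latent field may differ.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal, norm

jax.config.update("jax_enable_x64", True)


def rbf(x1, x2, lengthscale=1.0):
    # Unit-variance squared-exponential kernel.
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq_dist / lengthscale**2)


def log_joint(latent, lengthscale, x, y, jitter=1e-6):
    """Unnormalised log posterior log p(y | f) + log p(f) for a probit-Bernoulli model."""
    n = x.shape[0]
    # Prior term: f ~ N(0, K + jitter * I), where K depends on the lengthscale hyperparameter.
    cov = rbf(x, x, lengthscale) + jitter * jnp.eye(n)
    log_prior = multivariate_normal.logpdf(latent, jnp.zeros(n), cov)
    # Likelihood term: p(y_i = 1 | f_i) = Phi(f_i), so log p(y_i | f_i) = log Phi((2 y_i - 1) f_i).
    log_lik = jnp.sum(norm.logcdf((2.0 * y - 1.0) * latent))
    return log_lik + log_prior


x = jnp.linspace(0.0, 4.5, 4).reshape(-1, 1)
y = jnp.array([0.0, 1.0, 1.0, 0.0])
f0 = jnp.zeros(4)
grad = jax.grad(log_joint)(f0, 1.0, x, y)  # gradient with respect to the latent values
f1 = f0 + 0.1 * grad  # one gradient-ascent step towards the MAP estimate
```

Passing argnums=1 to jax.grad would instead give the gradient with respect to the lengthscale, which is how the same objective exposes the hyperparameters.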

prior: AbstractPrior[MeanFunction, Kernel] instance-attribute
likelihood: Likelihood instance-attribute
jitter: float = static_field(1e-06) class-attribute instance-attribute
latent: Union[Float[Array, 'N 1'], None] = param_field(None) class-attribute instance-attribute
key: KeyArray = static_field(PRNGKey(42)) class-attribute instance-attribute
__init_subclass__, replace, replace_meta, update_meta, replace_trainable, replace_bijector, constrain, unconstrain, stop_gradient, trainables and __call__ are inherited unchanged and are documented under AbstractPrior and AbstractPosterior.
__init__(prior: AbstractPrior[MeanFunction, Kernel], likelihood: Likelihood, jitter: float = static_field(1e-06), latent: Union[Float[Array, 'N 1'], None] = param_field(None), key: KeyArray = static_field(PRNGKey(42))) -> None
__post_init__()
predict(test_inputs: Num[Array, 'N D'], train_data: Dataset) -> GaussianDistribution

Query the predictive posterior distribution.

Conditional on a set of training data, compute the GP's posterior predictive distribution at test_inputs for the current set of parameters. Note that the returned distribution is over the latent function: to obtain predictions on the scale of the original data, it must be transformed through the likelihood function's inverse link function.

Parameters:
    test_inputs (Num[Array, 'N D']): A JAX array of test inputs at which the predictive distribution is evaluated.
    train_data (Dataset): A gpx.Dataset object containing the input and output data used for training.

Returns
GaussianDistribution: The posterior predictive distribution of the latent
    function at test_inputs.
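Because predict returns a distribution over the latent function, predictions on the data scale need one more step. For a Bernoulli likelihood with a probit inverse link (an assumption of this sketch), the predictive probability has the closed form $\Phi(\mu / \sqrt{1 + s^2})$ for a latent marginal $\mathcal{N}(\mu, s^2)$; the JAX sketch below compares it with the generic Monte Carlo route of pushing latent samples through the inverse link. The latent means and variances are made-up values, not GPJax outputs.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import norm

jax.config.update("jax_enable_x64", True)

# Hypothetical latent predictive marginals at three test points (mean and variance of f*).
latent_mean = jnp.array([-1.0, 0.0, 2.0])
latent_var = jnp.array([0.5, 1.0, 0.25])

# Closed form for a probit inverse link: E[Phi(f*)] = Phi(mu / sqrt(1 + s^2)).
p_closed = norm.cdf(latent_mean / jnp.sqrt(1.0 + latent_var))

# Generic Monte Carlo alternative: push latent samples through the inverse link, then average.
samples = latent_mean + jnp.sqrt(latent_var) * jr.normal(jr.PRNGKey(0), (20_000, 3))
p_mc = jnp.mean(norm.cdf(samples), axis=0)
```

The Monte Carlo route works for any inverse link, whereas the closed form is specific to the probit case.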
construct_posterior(prior, likelihood)

Utility function for constructing a posterior object from a prior and likelihood. The function will automatically select the correct posterior object based on the likelihood.

Parameters:
    prior (Prior): The Prior distribution.
    likelihood (AbstractLikelihood): The likelihood that represents our beliefs about the distribution of the data.

Returns
AbstractPosterior: A posterior distribution. If the likelihood is
    Gaussian, then a `ConjugatePosterior` will be returned. Otherwise,
    a `NonConjugatePosterior` will be returned.
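The selection rule described above amounts to a type check on the likelihood. The plain-Python sketch below uses hypothetical stand-in classes rather than GPJax's own types.

```python
from dataclasses import dataclass


class GaussianLik:
    # Hypothetical stand-in for a Gaussian likelihood class.
    pass


class BernoulliLik:
    # Hypothetical stand-in for a non-Gaussian likelihood class.
    pass


@dataclass
class ConjugatePosteriorSketch:
    prior: object
    likelihood: GaussianLik


@dataclass
class NonConjugatePosteriorSketch:
    prior: object
    likelihood: object


def construct_posterior_sketch(prior, likelihood):
    # Gaussian likelihood: conjugate model with closed-form predictions.
    if isinstance(likelihood, GaussianLik):
        return ConjugatePosteriorSketch(prior, likelihood)
    # Any other likelihood: approximate inference over the latent function is needed.
    return NonConjugatePosteriorSketch(prior, likelihood)


conjugate = construct_posterior_sketch("prior", GaussianLik())
non_conjugate = construct_posterior_sketch("prior", BernoulliLik())
```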