
Variational Families

gpjax.variational_families

__all__ = ['AbstractVariationalFamily', 'AbstractVariationalGaussian', 'VariationalGaussian', 'WhitenedVariationalGaussian', 'NaturalVariationalGaussian', 'ExpectationVariationalGaussian', 'CollapsedVariationalGaussian'] module-attribute
AbstractVariationalFamily dataclass

Bases: Module

Abstract base class used to represent families of distributions that can be used within variational inference.

posterior: AbstractPosterior instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self

Replace the values of the fields of the object.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the fields of the object.

{}
Returns
Module: with the fields replaced.
replace_meta(**kwargs: Any) -> Self

Replace the metadata of the fields.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields replaced.
update_meta(**kwargs: Any) -> Self

Update the metadata of the fields. The metadata must already exist.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to update the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields updated.
replace_trainable(**kwargs: Dict[str, bool]) -> Self

Replace the trainability status of local nodes of the Module.

replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self

Replace the bijectors of local nodes of the Module.

constrain() -> Self

Transform model parameters to the constrained space according to their defined bijectors.

Returns
Module: transformed to the constrained space.
unconstrain() -> Self

Transform model parameters to the unconstrained space according to their defined bijectors.

Returns
Module: transformed to the unconstrained space.
stop_gradient() -> Self

Stop gradients flowing through the Module.

Returns
Module: with gradients stopped.
trainables() -> Self
__init__(posterior: AbstractPosterior) -> None
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution

Evaluate the variational family's density.

For a given set of parameters, compute the latent function's prediction under the variational approximation.

Parameters:

Name Type Description Default
*args Any

Arguments of the variational family's predict method.

()
**kwargs Any

Keyword arguments of the variational family's predict method.

{}
Returns
GaussianDistribution: The output of the variational family's `predict` method.
predict(*args: Any, **kwargs: Any) -> GaussianDistribution abstractmethod

Predict the GP's output given the input.

Parameters:

Name Type Description Default
*args Any

Arguments of the variational family's predict method.

()
**kwargs Any

Keyword arguments of the variational family's predict method.

{}
Returns
GaussianDistribution: The output of the variational family's ``predict`` method.
AbstractVariationalGaussian dataclass

Bases: AbstractVariationalFamily

The variational Gaussian family of probability distributions.

posterior: AbstractPosterior instance-attribute
inducing_inputs: Float[Array, 'N D'] instance-attribute
jitter: ScalarFloat = static_field(1e-06) class-attribute instance-attribute
num_inducing: int property

The number of inducing inputs.

__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self

Replace the values of the fields of the object.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the fields of the object.

{}
Returns
Module: with the fields replaced.
replace_meta(**kwargs: Any) -> Self

Replace the metadata of the fields.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields replaced.
update_meta(**kwargs: Any) -> Self

Update the metadata of the fields. The metadata must already exist.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to update the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields updated.
replace_trainable(**kwargs: Dict[str, bool]) -> Self

Replace the trainability status of local nodes of the Module.

replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self

Replace the bijectors of local nodes of the Module.

constrain() -> Self

Transform model parameters to the constrained space according to their defined bijectors.

Returns
Module: transformed to the constrained space.
unconstrain() -> Self

Transform model parameters to the unconstrained space according to their defined bijectors.

Returns
Module: transformed to the unconstrained space.
stop_gradient() -> Self

Stop gradients flowing through the Module.

Returns
Module: with gradients stopped.
trainables() -> Self
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution

Evaluate the variational family's density.

For a given set of parameters, compute the latent function's prediction under the variational approximation.

Parameters:

Name Type Description Default
*args Any

Arguments of the variational family's predict method.

()
**kwargs Any

Keyword arguments of the variational family's predict method.

{}
Returns
GaussianDistribution: The output of the variational family's `predict` method.
predict(*args: Any, **kwargs: Any) -> GaussianDistribution abstractmethod

Predict the GP's output given the input.

Parameters:

Name Type Description Default
*args Any

Arguments of the variational family's predict method.

()
**kwargs Any

Keyword arguments of the variational family's predict method.

{}
Returns
GaussianDistribution: The output of the variational family's ``predict`` method.
__init__(posterior: AbstractPosterior, inducing_inputs: Float[Array, 'N D'], jitter: ScalarFloat = static_field(1e-06)) -> None
VariationalGaussian dataclass

Bases: AbstractVariationalGaussian

The variational Gaussian family of probability distributions.

The variational family is $q(f(\cdot)) = \int p(f(\cdot)\mid u) q(u) \mathrm{d}u$, where $u = f(z)$ are the function values at the inducing inputs $z$ and the distribution over the inducing inputs is $q(u) = \mathcal{N}(\mu, S)$. We parameterise this over $\mu$ and $sqrt$ with $S = sqrt\,sqrt^{\top}$.
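Since the root is kept lower triangular (via the `tfb.FillTriangular` bijector on `variational_root_covariance`), the covariance $S = sqrt\,sqrt^{\top}$ is positive definite by construction. A minimal NumPy sketch of this root parameterisation (illustrative only, not the GPJax implementation):

```python
import numpy as np

# Illustration: the variational root covariance is a lower-triangular
# matrix, and the covariance is formed as S = sqrt @ sqrt.T.
sqrt = np.array([[1.0, 0.0, 0.0],
                 [0.5, 2.0, 0.0],
                 [0.1, 0.3, 1.5]])
S = sqrt @ sqrt.T

# S is symmetric positive definite by construction; because the root has
# a positive diagonal, it is exactly the Cholesky factor of S.
```

This is why the root, rather than $S$ itself, is the free parameter: any unconstrained lower-triangular matrix yields a valid covariance.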

posterior: AbstractPosterior instance-attribute
inducing_inputs: Float[Array, 'N D'] instance-attribute
jitter: ScalarFloat = static_field(1e-06) class-attribute instance-attribute
num_inducing: int property

The number of inducing inputs.

variational_mean: Union[Float[Array, 'N 1'], None] = param_field(None) class-attribute instance-attribute
variational_root_covariance: Float[Array, 'N N'] = param_field(None, bijector=tfb.FillTriangular()) class-attribute instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self

Replace the values of the fields of the object.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the fields of the object.

{}
Returns
Module: with the fields replaced.
replace_meta(**kwargs: Any) -> Self

Replace the metadata of the fields.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields replaced.
update_meta(**kwargs: Any) -> Self

Update the metadata of the fields. The metadata must already exist.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to update the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields updated.
replace_trainable(**kwargs: Dict[str, bool]) -> Self

Replace the trainability status of local nodes of the Module.

replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self

Replace the bijectors of local nodes of the Module.

constrain() -> Self

Transform model parameters to the constrained space according to their defined bijectors.

Returns
Module: transformed to the constrained space.
unconstrain() -> Self

Transform model parameters to the unconstrained space according to their defined bijectors.

Returns
Module: transformed to the unconstrained space.
stop_gradient() -> Self

Stop gradients flowing through the Module.

Returns
Module: with gradients stopped.
trainables() -> Self
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution

Evaluate the variational family's density.

For a given set of parameters, compute the latent function's prediction under the variational approximation.

Parameters:

Name Type Description Default
*args Any

Arguments of the variational family's predict method.

()
**kwargs Any

Keyword arguments of the variational family's predict method.

{}
Returns
GaussianDistribution: The output of the variational family's `predict` method.
__init__(posterior: AbstractPosterior, inducing_inputs: Float[Array, 'N D'], jitter: ScalarFloat = static_field(1e-06), variational_mean: Union[Float[Array, 'N 1'], None] = param_field(None), variational_root_covariance: Float[Array, 'N N'] = param_field(None, bijector=tfb.FillTriangular())) -> None
__post_init__() -> None
prior_kl() -> ScalarFloat

Compute the prior KL divergence.

Compute the KL-divergence between our variational approximation and the Gaussian process prior.

For this variational family, we have

$$\begin{align} \operatorname{KL}[q(f(\cdot)) \mid\mid p(\cdot)] &= \operatorname{KL}[q(u) \mid\mid p(u)] \\ &= \operatorname{KL}[\mathcal{N}(\mu, S) \mid\mid \mathcal{N}(\mu_z, \mathbf{K}_{zz})], \end{align}$$

where $u = f(z)$ and $z$ are the inducing inputs.

Returns
 ScalarFloat: The KL-divergence between our variational
    approximation and the GP prior.
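The KL term above is the standard closed-form divergence between two multivariate Gaussians. A self-contained NumPy sketch (the helper name `gaussian_kl` is ours, not part of GPJax):

```python
import numpy as np

def gaussian_kl(mu0, S0, mu1, S1):
    """KL[N(mu0, S0) || N(mu1, S1)] via the standard closed form:
    0.5 * (tr(S1^-1 S0) + (mu1-mu0)^T S1^-1 (mu1-mu0) - k + ln|S1|/|S0|)."""
    k = mu0.shape[0]
    S1_inv = np.linalg.inv(S1)
    diff = mu1 - mu0
    return 0.5 * (
        np.trace(S1_inv @ S0)
        + diff @ S1_inv @ diff
        - k
        + np.log(np.linalg.det(S1) / np.linalg.det(S0))
    )

mu = np.zeros(2)
S = np.array([[2.0, 0.3], [0.3, 1.0]])
kl_same = gaussian_kl(mu, S, mu, S)  # KL of a distribution from itself: zero
```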
predict(test_inputs: Float[Array, 'N D']) -> GaussianDistribution

Compute the predictive distribution of the GP at the test inputs $t$.

This is the integral $q(f(t)) = \int p(f(t)\mid u) q(u) \mathrm{d}u$, which can be computed in closed form as:

$$\mathcal{N}\left(f(t); \mu_t + \mathbf{K}_{tz} \mathbf{K}_{zz}^{-1} (\mu - \mu_z),\ \mathbf{K}_{tt} - \mathbf{K}_{tz} \mathbf{K}_{zz}^{-1} \mathbf{K}_{zt} + \mathbf{K}_{tz} \mathbf{K}_{zz}^{-1} S \mathbf{K}_{zz}^{-1} \mathbf{K}_{zt}\right).$$

Parameters:

Name Type Description Default
test_inputs Float[Array, 'N D']

The test inputs at which we wish to make a prediction.

required
Returns
GaussianDistribution: The predictive distribution of the low-rank GP at
    the test inputs.
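The predictive equations can be evaluated term by term in plain NumPy. Here `rbf` and `sparse_predict` are hypothetical helpers for illustration, assuming an RBF kernel and a zero mean function; setting $q(u)$ equal to the prior $p(u)$ should recover the prior at the test points, a useful sanity check:

```python
import numpy as np

def rbf(x1, x2):
    """Squared-exponential kernel on 1-D inputs (an assumed kernel choice)."""
    d = x1[:, None] - x2[None, :]
    return np.exp(-0.5 * d**2)

def sparse_predict(t, z, mu, S):
    """Predictive mean/covariance of q(f(t)), computed directly from the
    closed-form expression above (zero prior mean assumed)."""
    Kzz = rbf(z, z) + 1e-6 * np.eye(len(z))  # jitter for stability
    Ktz = rbf(t, z)
    Ktt = rbf(t, t)
    Kzz_inv = np.linalg.inv(Kzz)
    mean = Ktz @ Kzz_inv @ mu  # mu_t = mu_z = 0 under a zero mean function
    cov = Ktt - Ktz @ Kzz_inv @ Ktz.T + Ktz @ Kzz_inv @ S @ Kzz_inv @ Ktz.T
    return mean, cov

z = np.linspace(-1.0, 1.0, 5)
t = np.linspace(-2.0, 2.0, 7)
# Sanity check: with q(u) = p(u) (mu = 0, S = K_zz), the predictive
# collapses back to the prior N(0, K_tt).
Kzz = rbf(z, z) + 1e-6 * np.eye(5)
mean, cov = sparse_predict(t, z, mu=np.zeros(5), S=Kzz)
```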
WhitenedVariationalGaussian dataclass

Bases: VariationalGaussian

The whitened variational Gaussian family of probability distributions.

The variational family is $q(f(\cdot)) = \int p(f(\cdot)\mid u) q(u) \mathrm{d}u$, where $u = f(z)$ are the function values at the inducing inputs $z$ and the distribution over the inducing inputs is $q(u) = \mathcal{N}(L_z \mu + m_z, L_z S L_z^{\top})$. We parameterise this over $\mu$ and $sqrt$ with $S = sqrt\,sqrt^{\top}$.

posterior: AbstractPosterior instance-attribute
inducing_inputs: Float[Array, 'N D'] instance-attribute
jitter: ScalarFloat = static_field(1e-06) class-attribute instance-attribute
num_inducing: int property

The number of inducing inputs.

variational_mean: Union[Float[Array, 'N 1'], None] = param_field(None) class-attribute instance-attribute
variational_root_covariance: Float[Array, 'N N'] = param_field(None, bijector=tfb.FillTriangular()) class-attribute instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self

Replace the values of the fields of the object.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the fields of the object.

{}
Returns
Module: with the fields replaced.
replace_meta(**kwargs: Any) -> Self

Replace the metadata of the fields.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields replaced.
update_meta(**kwargs: Any) -> Self

Update the metadata of the fields. The metadata must already exist.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to update the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields updated.
replace_trainable(**kwargs: Dict[str, bool]) -> Self

Replace the trainability status of local nodes of the Module.

replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self

Replace the bijectors of local nodes of the Module.

constrain() -> Self

Transform model parameters to the constrained space according to their defined bijectors.

Returns
Module: transformed to the constrained space.
unconstrain() -> Self

Transform model parameters to the unconstrained space according to their defined bijectors.

Returns
Module: transformed to the unconstrained space.
stop_gradient() -> Self

Stop gradients flowing through the Module.

Returns
Module: with gradients stopped.
trainables() -> Self
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution

Evaluate the variational family's density.

For a given set of parameters, compute the latent function's prediction under the variational approximation.

Parameters:

Name Type Description Default
*args Any

Arguments of the variational family's predict method.

()
**kwargs Any

Keyword arguments of the variational family's predict method.

{}
Returns
GaussianDistribution: The output of the variational family's `predict` method.
__post_init__() -> None
__init__(posterior: AbstractPosterior, inducing_inputs: Float[Array, 'N D'], jitter: ScalarFloat = static_field(1e-06), variational_mean: Union[Float[Array, 'N 1'], None] = param_field(None), variational_root_covariance: Float[Array, 'N N'] = param_field(None, bijector=tfb.FillTriangular())) -> None
prior_kl() -> ScalarFloat

Compute the KL-divergence between our variational approximation and the Gaussian process prior.

For this variational family, we have

$$\begin{align} \operatorname{KL}[q(f(\cdot)) \mid\mid p(\cdot)] &= \operatorname{KL}[q(u) \mid\mid p(u)] \\ &= \operatorname{KL}[\mathcal{N}(\mu, S) \mid\mid \mathcal{N}(0, I)]. \end{align}$$

Returns
ScalarFloat: The KL-divergence between our variational
    approximation and the GP prior.
predict(test_inputs: Float[Array, 'N D']) -> GaussianDistribution

Compute the predictive distribution of the GP at the test inputs t.

This is the integral $q(f(t)) = \int p(f(t)\mid u) q(u) \mathrm{d}u$, which can be computed in closed form as

$$\mathcal{N}\left(f(t); \mu_t + \mathbf{K}_{tz} \mathbf{L}_z^{-\top} \mu,\ \mathbf{K}_{tt} - \mathbf{K}_{tz} \mathbf{K}_{zz}^{-1} \mathbf{K}_{zt} + \mathbf{K}_{tz} \mathbf{L}_z^{-\top} S \mathbf{L}_z^{-1} \mathbf{K}_{zt} \right).$$

Parameters:

Name Type Description Default
test_inputs Float[Array, 'N D']

The test inputs at which we wish to make a prediction.

required
Returns
GaussianDistribution: The predictive distribution of the low-rank GP at
    the test inputs.
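The whitened formulas follow from the unwhitened ones via the identity $\mathbf{K}_{zz}^{-1}\mathbf{L}_z = \mathbf{L}_z^{-\top}$, where $\mathbf{K}_{zz} = \mathbf{L}_z\mathbf{L}_z^{\top}$ is the Cholesky factorisation. A quick NumPy check (an illustrative sketch, not GPJax code):

```python
import numpy as np

rng = np.random.default_rng(0)

# A positive-definite K_zz and its lower Cholesky factor L_z.
A = rng.standard_normal((4, 4))
Kzz = A @ A.T + 4.0 * np.eye(4)
Lz = np.linalg.cholesky(Kzz)

# Since K_zz^{-1} = L_z^{-T} L_z^{-1}, we have K_zz^{-1} L_z = L_z^{-T}.
lhs = np.linalg.inv(Kzz) @ Lz
rhs = np.linalg.inv(Lz).T

# Consequently the unwhitened mean shift K_tz K_zz^{-1} (L_z mu) equals
# the whitened form K_tz L_z^{-T} mu (taking m_z = 0 here).
Ktz = rng.standard_normal((3, 4))
mu = rng.standard_normal(4)
mean_unwhitened = Ktz @ np.linalg.inv(Kzz) @ (Lz @ mu)
mean_whitened = Ktz @ np.linalg.inv(Lz).T @ mu
```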
NaturalVariationalGaussian dataclass

Bases: AbstractVariationalGaussian

The natural variational Gaussian family of probability distributions.

The variational family is $q(f(\cdot)) = \int p(f(\cdot)\mid u) q(u) \mathrm{d}u$, where $u = f(z)$ are the function values at the inducing inputs $z$ and the distribution over the inducing inputs is $q(u) = \mathcal{N}(\mu, S)$. Writing the variational distribution in exponential-family form, $q(u) = \exp(\theta^{\top} T(u) - a(\theta))$, gives rise to the natural parameterisation $\theta = (\theta_{1}, \theta_{2}) = (S^{-1}\mu, -S^{-1}/2)$ over which model inference is performed, where $T(u) = [u, uu^{\top}]$ are the sufficient statistics.
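The natural parameterisation is invertible: $S = (-2\theta_2)^{-1}$ and $\mu = S\theta_1$. A NumPy round-trip sketch (helper names `to_natural`/`from_natural` are ours, not GPJax's):

```python
import numpy as np

def to_natural(mu, S):
    """(mu, S) -> natural parameters (theta1, theta2) = (S^-1 mu, -S^-1 / 2)."""
    S_inv = np.linalg.inv(S)
    return S_inv @ mu, -0.5 * S_inv

def from_natural(theta1, theta2):
    """Invert the map: S = (-2 theta2)^-1 and mu = S theta1."""
    S = np.linalg.inv(-2.0 * theta2)
    return S @ theta1, S

mu = np.array([0.5, -1.0])
S = np.array([[1.5, 0.2], [0.2, 0.8]])
theta1, theta2 = to_natural(mu, S)
mu_back, S_back = from_natural(theta1, theta2)  # recovers (mu, S)
```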

posterior: AbstractPosterior instance-attribute
inducing_inputs: Float[Array, 'N D'] instance-attribute
jitter: ScalarFloat = static_field(1e-06) class-attribute instance-attribute
num_inducing: int property

The number of inducing inputs.

natural_vector: Float[Array, 'M 1'] = None class-attribute instance-attribute
natural_matrix: Float[Array, 'M M'] = None class-attribute instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self

Replace the values of the fields of the object.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the fields of the object.

{}
Returns
Module: with the fields replaced.
replace_meta(**kwargs: Any) -> Self

Replace the metadata of the fields.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields replaced.
update_meta(**kwargs: Any) -> Self

Update the metadata of the fields. The metadata must already exist.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to update the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields updated.
replace_trainable(**kwargs: Dict[str, bool]) -> Self

Replace the trainability status of local nodes of the Module.

replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self

Replace the bijectors of local nodes of the Module.

constrain() -> Self

Transform model parameters to the constrained space according to their defined bijectors.

Returns
Module: transformed to the constrained space.
unconstrain() -> Self

Transform model parameters to the unconstrained space according to their defined bijectors.

Returns
Module: transformed to the unconstrained space.
stop_gradient() -> Self

Stop gradients flowing through the Module.

Returns
Module: with gradients stopped.
trainables() -> Self
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution

Evaluate the variational family's density.

For a given set of parameters, compute the latent function's prediction under the variational approximation.

Parameters:

Name Type Description Default
*args Any

Arguments of the variational family's predict method.

()
**kwargs Any

Keyword arguments of the variational family's predict method.

{}
Returns
GaussianDistribution: The output of the variational family's `predict` method.
__init__(posterior: AbstractPosterior, inducing_inputs: Float[Array, 'N D'], jitter: ScalarFloat = static_field(1e-06), natural_vector: Float[Array, 'M 1'] = None, natural_matrix: Float[Array, 'M M'] = None) -> None
__post_init__()
prior_kl() -> ScalarFloat

Compute the KL-divergence between our current variational approximation and the Gaussian process prior.

For this variational family, we have

$$\begin{align} \operatorname{KL}[q(f(\cdot)) \mid\mid p(\cdot)] &= \operatorname{KL}[q(u) \mid\mid p(u)] \\ &= \operatorname{KL}[\mathcal{N}(\mu, S) \mid\mid \mathcal{N}(m_z, \mathbf{K}_{zz})], \end{align}$$

with $\mu$ and $S$ computed from the natural parameterisation $\theta = (S^{-1}\mu , -S^{-1}/2)$.

Returns
ScalarFloat: The KL-divergence between our variational approximation and
    the GP prior.
predict(test_inputs: Float[Array, 'N D']) -> GaussianDistribution

Compute the predictive distribution of the GP at the test inputs $t$.

This is the integral $q(f(t)) = \int p(f(t)\mid u) q(u) \mathrm{d}u$, which can be computed in closed form as

$$\mathcal{N}\left(f(t); \mu_t + \mathbf{K}_{tz} \mathbf{K}_{zz}^{-1} (\mu - \mu_z),\ \mathbf{K}_{tt} - \mathbf{K}_{tz} \mathbf{K}_{zz}^{-1} \mathbf{K}_{zt} + \mathbf{K}_{tz} \mathbf{K}_{zz}^{-1} S \mathbf{K}_{zz}^{-1} \mathbf{K}_{zt} \right),$$

with $\mu$ and $S$ computed from the natural parameterisation $\theta = (S^{-1}\mu, -S^{-1}/2)$.

Returns
GaussianDistribution: A function that accepts a set of test points and will
    return the predictive distribution at those points.
ExpectationVariationalGaussian dataclass

Bases: AbstractVariationalGaussian

The expectation variational Gaussian family of probability distributions.

The variational family is $q(f(\cdot)) = \int p(f(\cdot)\mid u) q(u) \mathrm{d}u$, where $u = f(z)$ are the function values at the inducing inputs $z$ and the distribution over the inducing inputs is $q(u) = \mathcal{N}(\mu, S)$. Writing the variational distribution in exponential-family form, $q(u) = \exp(\theta^{\top} T(u) - a(\theta))$, gives rise to the natural parameterisation $\theta = (\theta_{1}, \theta_{2}) = (S^{-1}\mu, -S^{-1}/2)$ and sufficient statistics $T(u) = [u, uu^{\top}]$. The expectation parameters are given by $\nu = \int T(u) q(u) \mathrm{d}u$. This gives the parameterisation $\nu = (\nu_{1}, \nu_{2}) = (\mu, S + \mu\mu^{\top})$ over which model inference is performed.
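Since $\nu_2 = \mathbb{E}[uu^{\top}] = S + \mu\mu^{\top}$, the expectation parameterisation is invertible: $\mu = \nu_1$ and $S = \nu_2 - \nu_1\nu_1^{\top}$. A NumPy round-trip sketch (helper names are ours, not GPJax's):

```python
import numpy as np

def to_expectation(mu, S):
    """(mu, S) -> expectation parameters (nu1, nu2) = (mu, S + mu mu^T)."""
    return mu, S + np.outer(mu, mu)

def from_expectation(nu1, nu2):
    """Invert the map: mu = nu1 and S = nu2 - nu1 nu1^T."""
    return nu1, nu2 - np.outer(nu1, nu1)

mu = np.array([1.0, 2.0])
S = np.array([[2.0, 0.5], [0.5, 1.0]])
nu1, nu2 = to_expectation(mu, S)
mu_back, S_back = from_expectation(nu1, nu2)  # recovers (mu, S)
```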

posterior: AbstractPosterior instance-attribute
inducing_inputs: Float[Array, 'N D'] instance-attribute
jitter: ScalarFloat = static_field(1e-06) class-attribute instance-attribute
num_inducing: int property

The number of inducing inputs.

expectation_vector: Float[Array, 'M 1'] = None class-attribute instance-attribute
expectation_matrix: Float[Array, 'M M'] = None class-attribute instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self

Replace the values of the fields of the object.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the fields of the object.

{}
Returns
Module: with the fields replaced.
replace_meta(**kwargs: Any) -> Self

Replace the metadata of the fields.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields replaced.
update_meta(**kwargs: Any) -> Self

Update the metadata of the fields. The metadata must already exist.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to update the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields updated.
replace_trainable(**kwargs: Dict[str, bool]) -> Self

Replace the trainability status of local nodes of the Module.

replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self

Replace the bijectors of local nodes of the Module.

constrain() -> Self

Transform model parameters to the constrained space according to their defined bijectors.

Returns
Module: transformed to the constrained space.
unconstrain() -> Self

Transform model parameters to the unconstrained space according to their defined bijectors.

Returns
Module: transformed to the unconstrained space.
stop_gradient() -> Self

Stop gradients flowing through the Module.

Returns
Module: with gradients stopped.
trainables() -> Self
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution

Evaluate the variational family's density.

For a given set of parameters, compute the latent function's prediction under the variational approximation.

Parameters:

Name Type Description Default
*args Any

Arguments of the variational family's predict method.

()
**kwargs Any

Keyword arguments of the variational family's predict method.

{}
Returns
GaussianDistribution: The output of the variational family's `predict` method.
__init__(posterior: AbstractPosterior, inducing_inputs: Float[Array, 'N D'], jitter: ScalarFloat = static_field(1e-06), expectation_vector: Float[Array, 'M 1'] = None, expectation_matrix: Float[Array, 'M M'] = None) -> None
__post_init__()
prior_kl() -> ScalarFloat

Evaluate the prior KL-divergence.

Compute the KL-divergence between our current variational approximation and the Gaussian process prior.

For this variational family, we have

$$\begin{align} \operatorname{KL}(q(f(\cdot)) \mid\mid p(\cdot)) &= \operatorname{KL}(q(u) \mid\mid p(u)) \\ &= \operatorname{KL}(\mathcal{N}(\mu, S) \mid\mid \mathcal{N}(m_z, \mathbf{K}_{zz})), \end{align}$$

where $\mu$ and $S$ are the expectation parameters of the variational distribution and $m_z$ and $K_{zz}$ are the mean and covariance of the prior distribution.

Returns
ScalarFloat: The KL-divergence between our variational approximation and
    the GP prior.
predict(test_inputs: Float[Array, 'N D']) -> GaussianDistribution

Evaluate the predictive distribution.

Compute the predictive distribution of the GP at the test inputs $t$.

This is the integral $q(f(t)) = \int p(f(t)\mid u) q(u) \mathrm{d}u$, which can be computed in closed form as

$$\mathcal{N}\left(f(t); \mu_t + \mathbf{K}_{tz}\mathbf{K}_{zz}^{-1}(\mu - \mu_z),\ \mathbf{K}_{tt} - \mathbf{K}_{tz}\mathbf{K}_{zz}^{-1}\mathbf{K}_{zt} + \mathbf{K}_{tz}\mathbf{K}_{zz}^{-1} S \mathbf{K}_{zz}^{-1}\mathbf{K}_{zt}\right),$$

with $\mu$ and $S$ computed from the expectation parameterisation $\nu = (\mu, S + \mu\mu^{\top})$.

Returns
GaussianDistribution: The predictive distribution of the GP at the
    test inputs $t$.
CollapsedVariationalGaussian dataclass

Bases: AbstractVariationalGaussian

Collapsed variational Gaussian.

Collapsed variational Gaussian family of probability distributions. The key reference is Titsias (2009), Variational Learning of Inducing Variables in Sparse Gaussian Processes.
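For a Gaussian likelihood with noise variance $\sigma^2$, Titsias (2009) gives the optimal inducing distribution in closed form: $q^{\star}(u) = \mathcal{N}(\sigma^{-2}\mathbf{K}_{zz}\Sigma^{-1}\mathbf{K}_{zn}y,\ \mathbf{K}_{zz}\Sigma^{-1}\mathbf{K}_{zz})$ with $\Sigma = \mathbf{K}_{zz} + \sigma^{-2}\mathbf{K}_{zn}\mathbf{K}_{nz}$ (zero prior mean assumed). A NumPy sketch (not GPJax's implementation; `collapsed_predict` and `rbf` are hypothetical helpers). When the inducing points coincide with the training inputs, the result reduces to exact GP regression:

```python
import numpy as np

def rbf(x1, x2):
    """Squared-exponential kernel on 1-D inputs (an assumed kernel choice)."""
    d = x1[:, None] - x2[None, :]
    return np.exp(-0.5 * d**2)

def collapsed_predict(t, x, y, z, noise_var):
    """Predictive distribution using Titsias' optimal q(u), assuming a
    Gaussian likelihood with variance noise_var and a zero prior mean."""
    Kzz = rbf(z, z) + 1e-8 * np.eye(len(z))
    Kzx = rbf(z, x)
    Ktz = rbf(t, z)
    Ktt = rbf(t, t)
    Sigma_inv = np.linalg.inv(Kzz + Kzx @ Kzx.T / noise_var)
    # Optimal q(u) = N(m_u, S_u).
    m_u = Kzz @ Sigma_inv @ Kzx @ y / noise_var
    S_u = Kzz @ Sigma_inv @ Kzz
    # Plug q(u) into the usual sparse predictive equations.
    Kzz_inv = np.linalg.inv(Kzz)
    mean = Ktz @ Kzz_inv @ m_u
    cov = Ktt - Ktz @ Kzz_inv @ Ktz.T + Ktz @ Kzz_inv @ S_u @ Kzz_inv @ Ktz.T
    return mean, cov

x = np.linspace(-2.0, 2.0, 6)
y = np.sin(x)
t = np.linspace(-2.0, 2.0, 4)
# With inducing points equal to the training inputs, the collapsed
# approximation recovers exact GP regression.
mean, cov = collapsed_predict(t, x, y, z=x.copy(), noise_var=0.1)
```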

posterior: AbstractPosterior instance-attribute
inducing_inputs: Float[Array, 'N D'] instance-attribute
jitter: ScalarFloat = static_field(1e-06) class-attribute instance-attribute
num_inducing: int property

The number of inducing inputs.

__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self

Replace the values of the fields of the object.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the fields of the object.

{}
Returns
Module: with the fields replaced.
replace_meta(**kwargs: Any) -> Self

Replace the metadata of the fields.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to replace the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields replaced.
update_meta(**kwargs: Any) -> Self

Update the metadata of the fields. The metadata must already exist.

Parameters:

Name Type Description Default
**kwargs Any

keyword arguments to update the metadata of the fields of the object.

{}
Returns
Module: with the metadata of the fields updated.
replace_trainable(**kwargs: Dict[str, bool]) -> Self

Replace the trainability status of local nodes of the Module.

replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self

Replace the bijectors of local nodes of the Module.

constrain() -> Self

Transform model parameters to the constrained space according to their defined bijectors.

Returns
Module: transformed to the constrained space.
unconstrain() -> Self

Transform model parameters to the unconstrained space according to their defined bijectors.

Returns
Module: transformed to the unconstrained space.
stop_gradient() -> Self

Stop gradients flowing through the Module.

Returns
Module: with gradients stopped.
trainables() -> Self
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution

Evaluate the variational family's density.

For a given set of parameters, compute the latent function's prediction under the variational approximation.

Parameters:

Name Type Description Default
*args Any

Arguments of the variational family's predict method.

()
**kwargs Any

Keyword arguments of the variational family's predict method.

{}
Returns
GaussianDistribution: The output of the variational family's `predict` method.
__init__(posterior: AbstractPosterior, inducing_inputs: Float[Array, 'N D'], jitter: ScalarFloat = static_field(1e-06)) -> None
__post_init__()
predict(test_inputs: Float[Array, 'N D'], train_data: Dataset) -> GaussianDistribution

Compute the predictive distribution of the GP at the test inputs.

Parameters:

Name Type Description Default
test_inputs Float[Array, 'N D']

The test inputs $t$ at which to make predictions.

required
train_data Dataset

The training data that was used to fit the GP.

required
Returns
GaussianDistribution: The predictive distribution of the collapsed
    variational Gaussian process at the test inputs $t$.