Variational Families
gpjax.variational_families
__all__ = ['AbstractVariationalFamily', 'AbstractVariationalGaussian', 'VariationalGaussian', 'WhitenedVariationalGaussian', 'NaturalVariationalGaussian', 'ExpectationVariationalGaussian', 'CollapsedVariationalGaussian']
module-attribute
AbstractVariationalFamily
dataclass
Bases: Module
Abstract base class used to represent families of distributions that can be used within variational inference.
posterior: AbstractPosterior
instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution
Evaluate the variational family's density.
For a given set of parameters, compute the latent function's prediction under the variational approximation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Any` | Arguments of the variational family's `predict` method. | `()` |
| `**kwargs` | `Any` | Keyword arguments of the variational family's `predict` method. | `{}` |
Returns
GaussianDistribution: The output of the variational family's `predict` method.
predict(*args: Any, **kwargs: Any) -> GaussianDistribution
abstractmethod
Predict the GP's output given the input.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Any` | Arguments of the variational family's `predict` method. | `()` |
| `**kwargs` | `Any` | Keyword arguments of the variational family's `predict` method. | `{}` |
Returns
GaussianDistribution: The output of the variational family's ``predict`` method.
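The calling convention here, with `__call__` delegating to the subclass's `predict`, can be sketched in plain Python. The class names below are illustrative stand-ins, not part of GPJax:

```python
from abc import ABC, abstractmethod

class VariationalFamilySketch(ABC):
    """Illustrative sketch of the abstract-family pattern (not GPJax code)."""

    @abstractmethod
    def predict(self, *args, **kwargs):
        """Return the (approximate) predictive distribution."""

    def __call__(self, *args, **kwargs):
        # q(x) is shorthand for q.predict(x).
        return self.predict(*args, **kwargs)

class ZeroFamily(VariationalFamilySketch):
    def predict(self, x):
        # Stand-in for a GaussianDistribution: predict zero everywhere.
        return [0.0 for _ in x]

q = ZeroFamily()
assert q([1.0, 2.0]) == q.predict([1.0, 2.0]) == [0.0, 0.0]
```

Concrete families below implement `predict`; calling the family object is then equivalent to calling `predict` directly.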
AbstractVariationalGaussian
dataclass
Bases: AbstractVariationalFamily
The variational Gaussian family of probability distributions.
posterior: AbstractPosterior
instance-attribute
inducing_inputs: Float[Array, 'N D']
instance-attribute
jitter: ScalarFloat = static_field(1e-06)
class-attribute
instance-attribute
num_inducing: int
property
The number of inducing inputs.
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution
Evaluate the variational family's density.
For a given set of parameters, compute the latent function's prediction under the variational approximation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Any` | Arguments of the variational family's `predict` method. | `()` |
| `**kwargs` | `Any` | Keyword arguments of the variational family's `predict` method. | `{}` |
Returns
GaussianDistribution: The output of the variational family's `predict` method.
predict(*args: Any, **kwargs: Any) -> GaussianDistribution
abstractmethod
Predict the GP's output given the input.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Any` | Arguments of the variational family's `predict` method. | `()` |
| `**kwargs` | `Any` | Keyword arguments of the variational family's `predict` method. | `{}` |
Returns
GaussianDistribution: The output of the variational family's ``predict`` method.
VariationalGaussian
dataclass
Bases: AbstractVariationalGaussian
The variational Gaussian family of probability distributions.
The variational family is $q(f(\cdot)) = \int p(f(\cdot)\mid u)\, q(u)\, \mathrm{d}u$, where $u = f(z)$ are the function values at the inducing inputs $z$ and the distribution over the inducing inputs is $q(u) = \mathcal{N}(\mu, S)$. We parameterise this over $\mu$ and $\sqrt{S}$ with $S = \sqrt{S}\sqrt{S}^\top$.
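A minimal NumPy illustration of why the family is parameterised through a lower-triangular root (this shows the underlying linear algebra, not the GPJax API): any such root yields a symmetric, positive-definite covariance, so the root can be optimised freely.

```python
import numpy as np

# Lower-triangular root of the variational covariance (the free parameter).
root = np.array([[1.0, 0.0],
                 [0.5, 2.0]])

# Reconstruct the covariance S = root @ root.T.
S = root @ root.T

# Symmetric and positive definite by construction.
assert np.allclose(S, S.T)
assert np.all(np.linalg.eigvalsh(S) > 0)
```

This is why `variational_root_covariance` carries a `tfb.FillTriangular` bijector: the unconstrained parameters are packed into a triangular root rather than into $S$ itself.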
posterior: AbstractPosterior
instance-attribute
inducing_inputs: Float[Array, 'N D']
instance-attribute
jitter: ScalarFloat = static_field(1e-06)
class-attribute
instance-attribute
num_inducing: int
property
The number of inducing inputs.
variational_mean: Union[Float[Array, 'N 1'], None] = param_field(None)
class-attribute
instance-attribute
variational_root_covariance: Float[Array, 'N N'] = param_field(None, bijector=tfb.FillTriangular())
class-attribute
instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution
Evaluate the variational family's density.
For a given set of parameters, compute the latent function's prediction under the variational approximation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Any` | Arguments of the variational family's `predict` method. | `()` |
| `**kwargs` | `Any` | Keyword arguments of the variational family's `predict` method. | `{}` |
Returns
GaussianDistribution: The output of the variational family's `predict` method.
__post_init__() -> None
prior_kl() -> ScalarFloat
predict(test_inputs: Float[Array, 'N D']) -> GaussianDistribution
Compute the predictive distribution of the GP at the test inputs $t$.
This is the integral $q(f(t)) = \int p(f(t)\mid u)\, q(u)\, \mathrm{d}u$, which can be computed in closed form as

$$\mathcal{N}\left(f(t);\; m_t + K_{tz} K_{zz}^{-1} (\mu - m_z),\; K_{tt} - K_{tz} K_{zz}^{-1} K_{zt} + K_{tz} K_{zz}^{-1} S K_{zz}^{-1} K_{zt}\right).$$
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `test_inputs` | `Float[Array, 'N D']` | The test inputs at which we wish to make a prediction. | *required* |
Returns
GaussianDistribution: The predictive distribution of the low-rank GP at
the test inputs.
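The closed-form prediction amounts to a few matrix products. The NumPy sketch below illustrates the equations for a zero prior mean and an RBF kernel; it is an illustration of the math, not the GPJax implementation:

```python
import numpy as np

def rbf(a, b, lengthscale=1.0):
    # Squared-exponential kernel on 1-D inputs.
    d = a[:, None] - b[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)

z = np.array([-1.0, 0.0, 1.0])       # inducing inputs
t = np.array([-0.5, 0.5])            # test inputs
mu = np.array([0.1, -0.2, 0.3])      # variational mean (prior mean taken as zero)
S = 0.1 * np.eye(3)                  # variational covariance

Kzz = rbf(z, z) + 1e-6 * np.eye(3)   # jitter for numerical stability
Ktz = rbf(t, z)
Ktt = rbf(t, t)

A = Ktz @ np.linalg.inv(Kzz)         # K_tz K_zz^{-1}
pred_mean = A @ mu                   # mean: K_tz K_zz^{-1} (mu - m_z), means zero here
pred_cov = Ktt - A @ Ktz.T + A @ S @ A.T

assert pred_cov.shape == (2, 2) and np.allclose(pred_cov, pred_cov.T)
```

GPJax uses Cholesky-based solves rather than an explicit inverse; `np.linalg.inv` is used here only to keep the correspondence with the written equations direct.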
WhitenedVariationalGaussian
dataclass
Bases: VariationalGaussian
The whitened variational Gaussian family of probability distributions.
The variational family is $q(f(\cdot)) = \int p(f(\cdot)\mid u)\, q(u)\, \mathrm{d}u$, where $u = f(z)$ are the function values at the inducing inputs $z$ and the distribution over the inducing inputs is $q(u) = \mathcal{N}(L_z \mu + m_z, L_z S L_z^\top)$, with $L_z L_z^\top = K_{zz}$. We parameterise this over $\mu$ and $\sqrt{S}$ with $S = \sqrt{S}\sqrt{S}^\top$.
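Numerically, whitening corresponds to the de-whitening map $u = m_z + L_z v$ with $v \sim \mathcal{N}(\mu, S)$ and $L_z L_z^\top = K_{zz}$. A short NumPy sketch with illustrative values (not GPJax internals):

```python
import numpy as np

Kzz = np.array([[1.0, 0.3],
                [0.3, 1.0]])
Lz = np.linalg.cholesky(Kzz)   # Kzz = Lz @ Lz.T
mz = np.array([0.0, 0.0])
mu = np.array([0.5, -0.5])     # whitened variational mean
S = 0.2 * np.eye(2)            # whitened variational covariance

# De-whitening u = mz + Lz v with v ~ N(mu, S) implies:
u_mean = mz + Lz @ mu          # Lz mu + mz
u_cov = Lz @ S @ Lz.T          # Lz S Lz^T

# With S proportional to the identity, the de-whitened covariance is a
# scaled copy of the prior covariance Kzz.
assert np.allclose(u_cov, 0.2 * Kzz)
```

Because the whitened variable $v$ has an $\mathcal{N}(0, I)$ prior, the KL term decouples from $K_{zz}$, which tends to make optimisation better conditioned.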
posterior: AbstractPosterior
instance-attribute
inducing_inputs: Float[Array, 'N D']
instance-attribute
jitter: ScalarFloat = static_field(1e-06)
class-attribute
instance-attribute
num_inducing: int
property
The number of inducing inputs.
variational_mean: Union[Float[Array, 'N 1'], None] = param_field(None)
class-attribute
instance-attribute
variational_root_covariance: Float[Array, 'N N'] = param_field(None, bijector=tfb.FillTriangular())
class-attribute
instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution
Evaluate the variational family's density.
For a given set of parameters, compute the latent function's prediction under the variational approximation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Any` | Arguments of the variational family's `predict` method. | `()` |
| `**kwargs` | `Any` | Keyword arguments of the variational family's `predict` method. | `{}` |
|
Returns
GaussianDistribution: The output of the variational family's `predict` method.
__post_init__() -> None
prior_kl() -> ScalarFloat
predict(test_inputs: Float[Array, 'N D']) -> GaussianDistribution
Compute the predictive distribution of the GP at the test inputs $t$.
This is the integral $q(f(t)) = \int p(f(t)\mid u)\, q(u)\, \mathrm{d}u$, which can be computed in closed form as

$$\mathcal{N}\left(f(t);\; m_t + K_{tz} L_z^{-\top} \mu,\; K_{tt} - K_{tz} K_{zz}^{-1} K_{zt} + K_{tz} L_z^{-\top} S L_z^{-1} K_{zt}\right).$$
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `test_inputs` | `Float[Array, 'N D']` | The test inputs at which we wish to make a prediction. | *required* |
Returns
GaussianDistribution: The predictive distribution of the low-rank GP at
the test inputs.
NaturalVariationalGaussian
dataclass
Bases: AbstractVariationalGaussian
The natural variational Gaussian family of probability distributions.
The variational family is $q(f(\cdot)) = \int p(f(\cdot)\mid u)\, q(u)\, \mathrm{d}u$, where $u = f(z)$ are the function values at the inducing inputs $z$ and the distribution over the inducing inputs is $q(u) = \mathcal{N}(\mu, S)$. Expressing the variational distribution in the form of the exponential family, $q(u) = \exp\{\theta^\top T(u) - a(\theta)\}$, gives rise to the natural parameterisation $\theta = (\theta_1, \theta_2) = (S^{-1}\mu, -S^{-1}/2)$, to perform model inference, where $T(u) = [u, uu^\top]$ are the sufficient statistics.
posterior: AbstractPosterior
instance-attribute
inducing_inputs: Float[Array, 'N D']
instance-attribute
jitter: ScalarFloat = static_field(1e-06)
class-attribute
instance-attribute
num_inducing: int
property
The number of inducing inputs.
natural_vector: Float[Array, 'M 1'] = None
class-attribute
instance-attribute
natural_matrix: Float[Array, 'M M'] = None
class-attribute
instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution
Evaluate the variational family's density.
For a given set of parameters, compute the latent function's prediction under the variational approximation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Any` | Arguments of the variational family's `predict` method. | `()` |
| `**kwargs` | `Any` | Keyword arguments of the variational family's `predict` method. | `{}` |
|
Returns
GaussianDistribution: The output of the variational family's `predict` method.
__post_init__()
prior_kl() -> ScalarFloat
Compute the KL-divergence between our current variational approximation and the Gaussian process prior.
For this variational family, we have

$$\mathrm{KL}\left[q(u)\,\|\,p(u)\right] = \mathrm{KL}\left[\mathcal{N}(\mu, S)\,\|\,\mathcal{N}(m_z, K_{zz})\right],$$

with $\mu$ and $S$ computed from the natural parameterisation $\theta = (S^{-1}\mu, -S^{-1}/2)$.
Returns
ScalarFloat: The KL-divergence between our variational approximation and
the GP prior.
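The natural parameterisation $\theta = (S^{-1}\mu, -S^{-1}/2)$ is invertible, so the moment parameters needed for the KL term can always be recovered. A NumPy sketch of the round trip (illustrative values, not GPJax internals):

```python
import numpy as np

mu = np.array([[0.5], [-1.0]])
S = np.array([[2.0, 0.3],
              [0.3, 1.0]])

# Natural parameterisation: theta_1 = S^{-1} mu, theta_2 = -S^{-1} / 2.
S_inv = np.linalg.inv(S)
theta1 = S_inv @ mu
theta2 = -0.5 * S_inv

# Invert the map: S = (-2 theta_2)^{-1}, mu = S theta_1.
S_rec = np.linalg.inv(-2.0 * theta2)
mu_rec = S_rec @ theta1

assert np.allclose(S_rec, S) and np.allclose(mu_rec, mu)
```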
predict(test_inputs: Float[Array, 'N D']) -> GaussianDistribution
Compute the predictive distribution of the GP at the test inputs $t$.
This is the integral $q(f(t)) = \int p(f(t)\mid u)\, q(u)\, \mathrm{d}u$, which can be computed in closed form as

$$\mathcal{N}\left(f(t);\; m_t + K_{tz} K_{zz}^{-1} (\mu - m_z),\; K_{tt} - K_{tz} K_{zz}^{-1} K_{zt} + K_{tz} K_{zz}^{-1} S K_{zz}^{-1} K_{zt}\right),$$

with $\mu$ and $S$ computed from the natural parameterisation $\theta = (S^{-1}\mu, -S^{-1}/2)$.
Returns
GaussianDistribution: A function that accepts a set of test points and will
return the predictive distribution at those points.
ExpectationVariationalGaussian
dataclass
Bases: AbstractVariationalGaussian
The expectation variational Gaussian family of probability distributions.
The variational family is $q(f(\cdot)) = \int p(f(\cdot)\mid u)\, q(u)\, \mathrm{d}u$, where $u = f(z)$ are the function values at the inducing inputs $z$ and the distribution over the inducing inputs is $q(u) = \mathcal{N}(\mu, S)$. Expressing the variational distribution in the form of the exponential family, $q(u) = \exp\{\theta^\top T(u) - a(\theta)\}$, gives rise to the natural parameterisation $\theta = (S^{-1}\mu, -S^{-1}/2)$ and sufficient statistics $T(u) = [u, uu^\top]$. The expectation parameters are given by $\eta = \int T(u)\, q(u)\, \mathrm{d}u$. This gives the parameterisation $\eta = (\eta_1, \eta_2) = (\mu, S + \mu\mu^\top)$ to perform model inference over.
posterior: AbstractPosterior
instance-attribute
inducing_inputs: Float[Array, 'N D']
instance-attribute
jitter: ScalarFloat = static_field(1e-06)
class-attribute
instance-attribute
num_inducing: int
property
The number of inducing inputs.
expectation_vector: Float[Array, 'M 1'] = None
class-attribute
instance-attribute
expectation_matrix: Float[Array, 'M M'] = None
class-attribute
instance-attribute
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution
Evaluate the variational family's density.
For a given set of parameters, compute the latent function's prediction under the variational approximation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Any` | Arguments of the variational family's `predict` method. | `()` |
| `**kwargs` | `Any` | Keyword arguments of the variational family's `predict` method. | `{}` |
|
Returns
GaussianDistribution: The output of the variational family's `predict` method.
__post_init__()
prior_kl() -> ScalarFloat
Evaluate the prior KL-divergence.
Compute the KL-divergence between our current variational approximation and the Gaussian process prior.
For this variational family, we have

$$\mathrm{KL}\left[q(u)\,\|\,p(u)\right] = \mathrm{KL}\left[\mathcal{N}(\mu, S)\,\|\,\mathcal{N}(m_z, K_{zz})\right],$$

where $\mu$ and $S$ are recovered from the expectation parameters of the variational distribution and $m_z$ and $K_{zz}$ are the mean and covariance of the prior distribution.
Returns
ScalarFloat: The KL-divergence between our variational approximation and
the GP prior.
predict(test_inputs: Float[Array, 'N D']) -> GaussianDistribution
Evaluate the predictive distribution.
Compute the predictive distribution of the GP at the test inputs $t$.
This is the integral $q(f(t)) = \int p(f(t)\mid u)\, q(u)\, \mathrm{d}u$, which can be computed in closed form as

$$\mathcal{N}\left(f(t);\; m_t + K_{tz} K_{zz}^{-1} (\mu - m_z),\; K_{tt} - K_{tz} K_{zz}^{-1} K_{zt} + K_{tz} K_{zz}^{-1} S K_{zz}^{-1} K_{zt}\right),$$

with $\mu$ and $S$ computed from the expectation parameterisation $\eta = (\mu, S + \mu\mu^\top)$.
Returns
GaussianDistribution: The predictive distribution of the GP at the
test inputs $t$.
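The expectation parameterisation $\eta = (\mu, S + \mu\mu^\top)$ is likewise invertible. A NumPy sketch of the round trip (illustrative values, not GPJax internals):

```python
import numpy as np

mu = np.array([[0.5], [-1.0]])
S = np.array([[2.0, 0.3],
              [0.3, 1.0]])

# Expectation parameterisation: eta_1 = mu, eta_2 = S + mu mu^T.
eta1 = mu
eta2 = S + mu @ mu.T

# Invert the map: mu = eta_1, S = eta_2 - eta_1 eta_1^T.
mu_rec = eta1
S_rec = eta2 - eta1 @ eta1.T

assert np.allclose(S_rec, S) and np.allclose(mu_rec, mu)
```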
CollapsedVariationalGaussian
dataclass
Bases: AbstractVariationalGaussian
Collapsed variational Gaussian.
Collapsed variational Gaussian family of probability distributions. The key reference is Titsias (2009), Variational Learning of Inducing Variables in Sparse Gaussian Processes.
posterior: AbstractPosterior
instance-attribute
inducing_inputs: Float[Array, 'N D']
instance-attribute
jitter: ScalarFloat = static_field(1e-06)
class-attribute
instance-attribute
num_inducing: int
property
The number of inducing inputs.
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
__call__(*args: Any, **kwargs: Any) -> GaussianDistribution
Evaluate the variational family's density.
For a given set of parameters, compute the latent function's prediction under the variational approximation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any
|
Arguments of the variational family's |
()
|
**kwargs |
Any
|
Keyword arguments of the variational family's |
{}
|
Returns
GaussianDistribution: The output of the variational family's `predict` method.
__post_init__()
predict(test_inputs: Float[Array, 'N D'], train_data: Dataset) -> GaussianDistribution
Compute the predictive distribution of the GP at the test inputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `test_inputs` | `Float[Array, 'N D']` | The test inputs $t$ at which to make predictions. | *required* |
| `train_data` | `Dataset` | The training data that was used to fit the GP. | *required* |
Returns
GaussianDistribution: The predictive distribution of the collapsed
variational Gaussian process at the test inputs $t$.