Distributions
GaussianDistribution
GaussianDistribution(
loc: Optional[Float[Array, " N"]],
scale: Optional[LinearOperator],
validate_args=None,
)
Bases: Distribution
Multivariate Gaussian distribution for GP predictions.
This is the return type of all predict() methods in GPJax. It wraps a
mean vector and a covariance :class:`~gpjax.linalg.operators.LinearOperator`,
providing methods for sampling, computing log-probabilities, and evaluating
KL divergences.
The distribution is parameterised as
.. math::

    p(\mathbf{x}) = \mathcal{N}(\mathbf{x}; \boldsymbol{\mu}, \mathbf{\Sigma})

where :math:`\boldsymbol{\mu}` is the ``loc`` (mean) vector and
:math:`\mathbf{\Sigma}` is represented by the ``scale``
:class:`~gpjax.linalg.operators.LinearOperator`. The scale is
automatically annotated as positive semi-definite on construction.
Parameters
loc : Float[Array, " N"]
    Mean vector of the distribution.
scale : LinearOperator
    Covariance matrix represented as a
    :class:`~gpjax.linalg.operators.LinearOperator` (e.g.
    :class:`~gpjax.linalg.operators.Dense` or
    :class:`~gpjax.linalg.operators.Diagonal`).
Examples
>>> import jax.numpy as jnp
>>> from gpjax.distributions import GaussianDistribution
>>> from gpjax.linalg.operators import Dense
>>> mu = jnp.array([0.0, 1.0])
>>> cov = Dense(jnp.eye(2))
>>> dist = GaussianDistribution(loc=mu, scale=cov)
>>> dist.mean
Array([0., 1.], dtype=float32)
>>> dist.variance
Array([1., 1.], dtype=float32)
variance
property
Calculates the marginal variance (diagonal of the covariance).
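A plain-JAX sketch of the same computation, assuming a dense covariance array in place of GPJax's ``LinearOperator`` (an illustration of the definition, not the library's implementation):

```python
import jax.numpy as jnp

# Assumed dense covariance; GPJax stores this as a LinearOperator instead.
cov = jnp.array([[2.0, 0.5],
                 [0.5, 1.0]])

# Marginal variance Var[x_i] = Sigma_ii: keep the diagonal and drop the
# off-diagonal cross-covariances.
marginal_var = jnp.diag(cov)
```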
covariance_matrix
property
Property alias for :meth:`covariance`.
sample
Draw samples from the distribution.
Generates samples via the reparameterisation trick:
.. math::

    \mathbf{x} = \boldsymbol{\mu} + \mathbf{L}\mathbf{z},
    \quad \mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})

where :math:`\mathbf{L}` is the lower Cholesky factor of the covariance.
Parameters
key : KeyArray
    JAX PRNG key.
sample_shape : tuple of int, optional
    Leading batch dimensions for the samples. Defaults to ``()``,
    returning a single sample.
Returns
Float[Array, "... N"]
    Array of samples with shape ``(*sample_shape, N)``.
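The reparameterisation above can be sketched in plain JAX. The helper ``sample_gaussian`` is hypothetical and takes a dense covariance array rather than a ``LinearOperator``; it illustrates the formula, not GPJax's own code path.

```python
import jax
import jax.numpy as jnp


def sample_gaussian(key, mu, cov, sample_shape=()):
    """Hypothetical helper: draws x = mu + L z with z ~ N(0, I)."""
    n = mu.shape[-1]
    # Lower Cholesky factor, cov = L @ L.T
    L = jnp.linalg.cholesky(cov)
    # Standard normal noise with the requested leading batch dimensions.
    z = jax.random.normal(key, shape=(*sample_shape, n))
    # Row-vector form of L @ z, applied to every draw at once.
    return mu + z @ L.T


key = jax.random.PRNGKey(0)
mu = jnp.array([0.0, 1.0])
cov = jnp.array([[1.0, 0.8],
                 [0.8, 1.0]])

single = sample_gaussian(key, mu, cov)            # shape (2,)
draws = sample_gaussian(key, mu, cov, (20000,))   # shape (20000, 2)
```

Because the noise is drawn with shape ``(*sample_shape, N)``, one matrix product handles any number of leading batch dimensions, and the draws are differentiable with respect to ``mu`` and ``cov``.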
entropy
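For reference, the differential entropy of the Gaussian defined above has the standard closed form below; that ``entropy`` evaluates exactly this quantity is an assumption, not something stated on this page:

.. math::

    \mathcal{H}[p] = \tfrac{1}{2}\bigl[N(1 + \ln 2\pi) + \ln|\mathbf{\Sigma}|\bigr]
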
covariance
log_prob
Calculates the log pdf of the multivariate Gaussian.
.. math::

    \log p(\mathbf{y}) = -\tfrac{1}{2}\bigl[
    N\ln 2\pi + \ln|\mathbf{\Sigma}|
    + (\mathbf{y} - \boldsymbol{\mu})^\top
    \mathbf{\Sigma}^{-1}
    (\mathbf{y} - \boldsymbol{\mu})
    \bigr]

Parameters
y : Float[Array, " N"]
    Point at which to evaluate the log-density.
Returns
ScalarFloat
    Log-density evaluated at ``y``.
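A plain-JAX sketch of the log-density formula, assuming a dense covariance array; the helper ``gaussian_log_prob`` is hypothetical. A single Cholesky factorisation supplies both the quadratic term and the log-determinant, and the result is checked against ``jax.scipy.stats.multivariate_normal.logpdf``:

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal


def gaussian_log_prob(y, mu, cov):
    """Hypothetical helper: log N(y; mu, cov) for a dense covariance."""
    n = mu.shape[-1]
    # Cholesky factor: cov = L @ L.T
    L = jnp.linalg.cholesky(cov)
    # Whitened residual alpha = L^{-1}(y - mu), so that
    # alpha.T @ alpha = (y - mu)^T Sigma^{-1} (y - mu).
    alpha = solve_triangular(L, y - mu, lower=True)
    quad = jnp.sum(alpha**2)
    # ln|Sigma| = 2 * sum_i ln L_ii
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
    return -0.5 * (n * jnp.log(2.0 * jnp.pi) + log_det + quad)


y = jnp.array([0.3, -0.2])
mu = jnp.array([0.0, 1.0])
cov = jnp.array([[1.0, 0.8],
                 [0.8, 1.0]])

ours = gaussian_log_prob(y, mu, cov)
reference = multivariate_normal.logpdf(y, mu, cov)
```

Solving against the triangular factor avoids forming :math:`\mathbf{\Sigma}^{-1}` explicitly, which is both cheaper and more numerically stable than a direct inverse.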