Multivariate Gaussian distribution for GP predictions.
This is the return type of all predict() methods in GPJax. It wraps a
mean vector and a covariance :class:`~gpjax.linalg.operators.LinearOperator`,
providing methods for sampling, computing log-probabilities, and evaluating
KL divergences.
The distribution is

.. math::

    p(\mathbf{x}) = \mathcal{N}(\mathbf{x} \mid \boldsymbol{\mu}, \mathbf{\Sigma}),

where :math:`\boldsymbol{\mu}` is the loc (mean) vector and
:math:`\mathbf{\Sigma}` is represented by the scale
:class:`~gpjax.linalg.operators.LinearOperator`. The scale is
automatically annotated as positive semi-definite on construction.
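As an illustration of how a log-probability under this density can be evaluated, the sketch below works through a Cholesky factor of the covariance. It uses plain jax.numpy on a dense covariance array rather than a GPJax LinearOperator, and gaussian_log_prob is a hypothetical helper written for this example, not part of the library.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal


def gaussian_log_prob(x, loc, cov):
    """Log-density of N(loc, cov) at x (hypothetical helper, dense cov assumed)."""
    n = loc.shape[0]
    chol = jnp.linalg.cholesky(cov)  # lower-triangular L with L @ L.T == cov
    # Solve L a = (x - loc); a @ a is then the Mahalanobis term.
    alpha = solve_triangular(chol, x - loc, lower=True)
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))  # log |cov|
    return -0.5 * (n * jnp.log(2.0 * jnp.pi) + log_det + alpha @ alpha)


mu = jnp.array([0.0, 1.0])
sigma = jnp.array([[2.0, 0.5], [0.5, 1.0]])
x = jnp.array([0.3, 0.7])

lp = gaussian_log_prob(x, mu, sigma)
# Cross-check against JAX's built-in multivariate normal log-density.
reference = multivariate_normal.logpdf(x, mu, sigma)
```

The two values should agree to floating-point tolerance. Working with the Cholesky factor avoids forming an explicit inverse or determinant of the covariance.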
Parameters
loc : Float[Array, " N"]
Mean vector of the distribution.
scale : LinearOperator
Covariance matrix represented as a
:class:`~gpjax.linalg.operators.LinearOperator` (e.g.
:class:`~gpjax.linalg.operators.Dense` or
:class:`~gpjax.linalg.operators.Diagonal`).
Examples
>>> import jax.numpy as jnp
>>> from gpjax.distributions import GaussianDistribution
>>> from gpjax.linalg.operators import Dense
>>> mu = jnp.array([0.0, 1.0])
>>> cov = Dense(jnp.eye(2))
>>> dist = GaussianDistribution(loc=mu, scale=cov)
>>> dist.mean
Array([0., 1.], dtype=float32)
>>> dist.variance
Array([1., 1.], dtype=float32)
mean property
mean: Float[Array, ' N']
Returns the mean vector (loc) of the distribution.
variance property
variance: Float[Array, ' N']
Calculates the marginal variance (diagonal of the covariance).
covariance_matrix property
covariance_matrix: Float[Array, 'N N']
Property alias for :meth:`covariance`.
sample
sample(key, sample_shape=())
Draw samples from the distribution.
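Generates samples via the reparameterisation trick:

.. math::

    \mathbf{x} = \boldsymbol{\mu} + \mathbf{L}\mathbf{z}, \qquad \mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}),

where :math:`\mathbf{L}` is the lower Cholesky factor of the covariance,
so that :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^{\top}`.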
Parameters
key : KeyArray
JAX PRNG key.
sample_shape : tuple of int, optional
Leading batch dimensions for the samples. Defaults to (),
returning a single sample.
Returns
Float[Array, "... N"]
Array of samples with shape (*sample_shape, N).
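The reparameterisation trick used by sample can be sketched in a few lines of plain JAX. This is a minimal illustration under the assumption of a dense covariance array; sample_gaussian is a hypothetical stand-in, not GPJax's implementation, and it ignores the LinearOperator abstraction.

```python
import jax
import jax.numpy as jnp


def sample_gaussian(key, loc, cov, sample_shape=()):
    """Draw x = loc + L z with z ~ N(0, I) (hypothetical helper, dense cov assumed)."""
    chol = jnp.linalg.cholesky(cov)  # lower Cholesky factor L of the covariance
    z = jax.random.normal(key, shape=(*sample_shape, loc.shape[0]))
    # z @ chol.T applies L to every sample along the trailing axis.
    return loc + z @ chol.T


key = jax.random.PRNGKey(0)
mu = jnp.array([0.0, 1.0])
sigma = jnp.array([[2.0, 0.5], [0.5, 1.0]])

single = sample_gaussian(key, mu, sigma)            # shape (2,)
batch = sample_gaussian(key, mu, sigma, (1000, 3))  # shape (1000, 3, 2)
```

With the default sample_shape of () the result has shape (N,); a non-empty sample_shape is prepended as leading batch dimensions, matching the (*sample_shape, N) shape described above. Reusing the same key reproduces the same draw.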
entropy
entropy() -> ScalarFloat
Calculates the differential entropy of the distribution.
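For an N-dimensional Gaussian the differential entropy has the standard closed form H = (N / 2)(1 + log 2π) + (1 / 2) log |Σ|. The sketch below evaluates it with plain jax.numpy on a dense covariance array; gaussian_entropy is a hypothetical helper for illustration, and how GPJax computes the log-determinant internally is not assumed here.

```python
import jax.numpy as jnp


def gaussian_entropy(cov):
    """Differential entropy of N(loc, cov) in nats (hypothetical helper)."""
    n = cov.shape[0]
    chol = jnp.linalg.cholesky(cov)
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))  # log |cov|
    return 0.5 * (n * (1.0 + jnp.log(2.0 * jnp.pi)) + log_det)


h_unit = gaussian_entropy(jnp.eye(2))                      # unit covariance
h_scaled = gaussian_entropy(jnp.diag(jnp.array([4.0, 1.0])))  # one axis stretched
```

The entropy does not depend on the mean, only on the covariance: scaling one variance by 4 raises it by (1 / 2) log 4 = log 2 nats.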