
Distributions

GaussianDistribution

GaussianDistribution(
    loc: Optional[Float[Array, " N"]],
    scale: Optional[LinearOperator],
    validate_args=None,
)

Bases: Distribution

Multivariate Gaussian distribution for GP predictions.

This is the return type of all predict() methods in GPJax. It wraps a mean vector and a covariance :class:`~gpjax.linalg.operators.LinearOperator`, providing methods for sampling, computing log-probabilities, and evaluating KL divergences.

The distribution is parameterised as

.. math::

    p(\mathbf{x}) = \mathcal{N}(\mathbf{x}; \boldsymbol{\mu}, \mathbf{\Sigma})

where :math:`\boldsymbol{\mu}` is the loc (mean) vector and :math:`\mathbf{\Sigma}` is the covariance held in the scale :class:`~gpjax.linalg.operators.LinearOperator`. Despite its name, scale stores the covariance itself, not a square-root (e.g. Cholesky) factor of it. The scale is automatically annotated as positive semi-definite on construction.

Parameters

loc : Float[Array, " N"]
    Mean vector of the distribution.
scale : LinearOperator
    Covariance matrix represented as a :class:`~gpjax.linalg.operators.LinearOperator` (e.g. :class:`~gpjax.linalg.operators.Dense` or :class:`~gpjax.linalg.operators.Diagonal`).

Examples

```python
>>> import jax.numpy as jnp
>>> from gpjax.distributions import GaussianDistribution
>>> from gpjax.linalg.operators import Dense
>>> mu = jnp.array([0.0, 1.0])
>>> cov = Dense(jnp.eye(2))
>>> dist = GaussianDistribution(loc=mu, scale=cov)
>>> dist.mean
Array([0., 1.], dtype=float32)
>>> dist.variance
Array([1., 1.], dtype=float32)
```

mean property

mean: Float[Array, ' N']

Calculates the mean.

variance property

variance: Float[Array, ' N']

Calculates the marginal variance (diagonal of the covariance).

covariance_matrix property

covariance_matrix: Float[Array, 'N N']

Property alias for :meth:`covariance`.

sample

sample(key, sample_shape=())

Draw samples from the distribution.

Generates samples via the reparameterisation trick:

.. math::

    \mathbf{x} = \boldsymbol{\mu} + \mathbf{L}\mathbf{z},
    \quad \mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})

where :math:`\mathbf{L}` is the lower Cholesky factor of the covariance.

Parameters

key : KeyArray
    JAX PRNG key.
sample_shape : tuple of int, optional
    Leading batch dimensions for the samples. Defaults to (), returning a single sample.

Returns

Float[Array, "... N"]
    Array of samples with shape (*sample_shape, N).
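The reparameterisation step can be sketched in plain NumPy. This is an illustration of the formula above, not GPJax's implementation. The helper name sample_gaussian is hypothetical, and a NumPy Generator stands in for the JAX PRNG key.

```python
import numpy as np


def sample_gaussian(mu, sigma, rng, sample_shape=()):
    """Hypothetical helper: draw x = mu + L z with z ~ N(0, I).

    `rng` is a numpy.random.Generator used in place of a JAX PRNG key.
    Returns an array of shape (*sample_shape, N).
    """
    mu = np.asarray(mu, dtype=float)
    L = np.linalg.cholesky(np.asarray(sigma, dtype=float))  # sigma = L @ L.T
    # Standard-normal noise with the requested leading batch dimensions.
    z = rng.standard_normal(size=tuple(sample_shape) + mu.shape)
    # z @ L.T applies L to every trailing length-N vector of z.
    return mu + z @ L.T


rng = np.random.default_rng(0)
mu = np.array([0.0, 1.0])
sigma = np.array([[2.0, 0.5], [0.5, 1.0]])
single = sample_gaussian(mu, sigma, rng)  # shape (2,)
samples = sample_gaussian(mu, sigma, rng, sample_shape=(100_000,))  # shape (100000, 2)
```

The randomness enters only through z, so each sample is a differentiable function of the mean and the Cholesky factor. This property is what makes the trick useful for gradient-based training.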

entropy

entropy() -> ScalarFloat

Calculates the differential entropy of the distribution.

.. math::

    H[p] = \tfrac{1}{2}\bigl(N(1 + \ln 2\pi) + \ln|\mathbf{\Sigma}|\bigr)

Returns

ScalarFloat
    Entropy in nats.
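Below is a minimal NumPy sketch of the entropy formula. It assumes a dense, positive-definite covariance, and gaussian_entropy is a hypothetical helper that is not part of GPJax. The sketch uses np.linalg.slogdet to compute :math:`\ln|\mathbf{\Sigma}|` without forming the determinant, which can under- or overflow for large N.

```python
import numpy as np


def gaussian_entropy(sigma):
    """Hypothetical helper: 0.5 * (N * (1 + ln 2*pi) + ln|Sigma|), in nats."""
    sigma = np.asarray(sigma, dtype=float)
    n = sigma.shape[0]
    sign, logdet = np.linalg.slogdet(sigma)  # numerically stable log-determinant
    if sign <= 0:
        raise ValueError("covariance must be positive definite")
    return 0.5 * (n * (1.0 + np.log(2.0 * np.pi)) + logdet)


h = gaussian_entropy(np.diag([1.0, 4.0]))
```

For N = 1 this reduces to the familiar :math:`\tfrac{1}{2}\ln(2\pi e\,\sigma^2)`. For a diagonal covariance the result is the sum of the per-dimension entropies.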

median

median() -> Float[Array, ' N']

Calculates the median (equal to the mean for a Gaussian).

mode

mode() -> Float[Array, ' N']

Calculates the mode (equal to the mean for a Gaussian).

covariance

covariance() -> Float[Array, 'N N']

Materialises the full covariance matrix as a dense array.

Returns

Float[Array, "N N"]
    Dense covariance matrix.

stddev

stddev() -> Float[Array, ' N']

Calculates the marginal standard deviation.
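For a dense covariance, the marginal variance and standard deviation can be sketched in NumPy as follows. This illustrates the definitions and is not GPJax's code. Off-diagonal covariances play no part in the marginal quantities.

```python
import numpy as np

sigma = np.array([[4.0, 1.2],
                  [1.2, 9.0]])

variance = np.diag(sigma)   # marginal variances: diagonal of the covariance
stddev = np.sqrt(variance)  # marginal standard deviations
```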

log_prob

log_prob(y: Float[Array, ' N']) -> ScalarFloat

Calculates the log pdf of the multivariate Gaussian.

.. math::

    \log p(\mathbf{y}) = -\tfrac{1}{2}\bigl[
        N\ln 2\pi + \ln|\mathbf{\Sigma}|
        + (\mathbf{y} - \boldsymbol{\mu})^\top
          \mathbf{\Sigma}^{-1}
          (\mathbf{y} - \boldsymbol{\mu})
    \bigr]

Parameters

y : Float[Array, " N"]
    Point at which to evaluate the log-density.

Returns

ScalarFloat
    Log probability.
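The same density can be evaluated through a Cholesky factor :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`, which avoids forming :math:`\mathbf{\Sigma}^{-1}` explicitly. With this factor, :math:`\ln|\mathbf{\Sigma}| = 2\sum_i \ln L_{ii}` and the quadratic form equals :math:`\|\mathbf{L}^{-1}(\mathbf{y} - \boldsymbol{\mu})\|^2`. The sketch below is written in NumPy and assumes a dense positive-definite covariance. The name gaussian_log_prob is hypothetical, and GPJax may organise the computation differently.

```python
import numpy as np


def gaussian_log_prob(y, mu, sigma):
    """Hypothetical helper: log N(y; mu, sigma) via a Cholesky factor."""
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    L = np.linalg.cholesky(np.asarray(sigma, dtype=float))  # sigma = L @ L.T
    # alpha = L^{-1} (y - mu), so alpha @ alpha = (y - mu)^T sigma^{-1} (y - mu).
    # A triangular solver (e.g. scipy.linalg.solve_triangular) would be cheaper.
    alpha = np.linalg.solve(L, y - mu)
    quad = alpha @ alpha
    logdet = 2.0 * np.sum(np.log(np.diag(L)))  # ln|sigma| from the Cholesky diagonal
    n = mu.shape[0]
    return -0.5 * (n * np.log(2.0 * np.pi) + logdet + quad)


lp = gaussian_log_prob([0.3, -1.2], [0.0, 1.0], [[2.0, 0.5], [0.5, 1.0]])
```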

kl_divergence

kl_divergence(other: GaussianDistribution) -> ScalarFloat

KL divergence from self to other.

Computes :math:`\operatorname{KL}[q \,\|\, p]`, where self is :math:`q` and other is :math:`p`.

Parameters

other : GaussianDistribution
    The reference distribution p.

Returns

ScalarFloat
    KL divergence in nats.
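Between two N-dimensional Gaussians the divergence has the standard closed form

.. math::

    \operatorname{KL}[q \,\|\, p] = \tfrac{1}{2}\bigl[
        \operatorname{tr}(\mathbf{\Sigma}_p^{-1}\mathbf{\Sigma}_q)
        + (\boldsymbol{\mu}_p - \boldsymbol{\mu}_q)^\top \mathbf{\Sigma}_p^{-1} (\boldsymbol{\mu}_p - \boldsymbol{\mu}_q)
        - N + \ln|\mathbf{\Sigma}_p| - \ln|\mathbf{\Sigma}_q|
    \bigr]

Below is a plain NumPy sketch of this expression for dense, positive-definite covariances. It is not GPJax's implementation, and gaussian_kl is a hypothetical name.

```python
import numpy as np


def gaussian_kl(mu_q, sigma_q, mu_p, sigma_p):
    """Hypothetical helper: closed-form KL[q || p] for two Gaussians, in nats."""
    mu_q, mu_p = np.asarray(mu_q, dtype=float), np.asarray(mu_p, dtype=float)
    sigma_q = np.asarray(sigma_q, dtype=float)
    sigma_p = np.asarray(sigma_p, dtype=float)
    n = mu_q.shape[0]
    diff = mu_p - mu_q
    trace_term = np.trace(np.linalg.solve(sigma_p, sigma_q))  # tr(Sigma_p^{-1} Sigma_q)
    quad = diff @ np.linalg.solve(sigma_p, diff)              # Mahalanobis term under p
    _, logdet_p = np.linalg.slogdet(sigma_p)
    _, logdet_q = np.linalg.slogdet(sigma_q)
    return 0.5 * (trace_term + quad - n + logdet_p - logdet_q)


kl_qp = gaussian_kl([0.0], [[1.0]], [1.0], [[4.0]])
kl_pq = gaussian_kl([1.0], [[4.0]], [0.0], [[1.0]])
```

The divergence is zero only when the two distributions coincide, and it is not symmetric: in this example kl_qp and kl_pq differ. The ordering (self as q, other as p) therefore matters.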