
GPs

StateSpacePrior and StateSpaceConjugatePosterior classes.

See plans/2026-04-21-state-space-gps-design.md.

StateSpacePrior

StateSpacePrior(
    kernel: K, mean_function: M, jitter: float = 1e-06
)

Bases: Prior

Prior for a state-space (Markovian) GP.

Identical to gpjax.gps.Prior except predictions are diagonal-only (the prior is stationary in time, so off-diagonal covariance carries no extra information for v1's diagonal-only predictive contract).
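To make "state-space (Markovian)" concrete, here is a minimal NumPy sketch of the standard textbook state-space form of the Matern-3/2 kernel used in the examples on this page. The state is \([f(t), f'(t)]\), with drift matrix \(F\) and stationary covariance \(P_\infty\). This is an assumption for illustration, not code taken from gpjax.state_space, and all function names are hypothetical. The sketch checks that \(H e^{F\tau} P_\infty H^\top\) reproduces the kernel. That identity is what lets a time series be processed as a Markov chain with transition \(A = e^{F\Delta t}\) and process noise \(Q = P_\infty - A P_\infty A^\top\).

```python
import numpy as np


def matern32_kernel(tau, lengthscale=1.0, variance=1.0):
    """Matern-3/2 covariance as a function of the time lag tau."""
    r = np.sqrt(3.0) * np.abs(tau) / lengthscale
    return variance * (1.0 + r) * np.exp(-r)


def matern32_state_space(lengthscale=1.0, variance=1.0):
    """Continuous-time state-space form of Matern-3/2 (hypothetical helper).

    The state is [f(t), f'(t)]. Returns the drift matrix F, the stationary
    state covariance Pinf and the observation row H, which picks out f.
    """
    lam = np.sqrt(3.0) / lengthscale
    F = np.array([[0.0, 1.0], [-(lam**2), -2.0 * lam]])
    Pinf = np.diag([variance, lam**2 * variance])
    H = np.array([[1.0, 0.0]])
    return F, Pinf, H


def transition_matrix(dt, lengthscale=1.0):
    """Closed-form expm(F * dt) for the Matern-3/2 drift matrix."""
    lam = np.sqrt(3.0) / lengthscale
    return np.exp(-lam * dt) * np.array(
        [[1.0 + lam * dt, dt], [-(lam**2) * dt, 1.0 - lam * dt]]
    )


def state_space_covariance(tau, lengthscale=1.0, variance=1.0):
    """Recover k(tau) = H expm(F tau) Pinf H^T for tau >= 0."""
    _, Pinf, H = matern32_state_space(lengthscale, variance)
    A = transition_matrix(tau, lengthscale)
    return (H @ A @ Pinf @ H.T)[0, 0]


# One-step transition of the discrete-time Markov chain between two time stamps.
dt = 0.3
F, Pinf, _ = matern32_state_space(lengthscale=1.5, variance=2.0)
A = transition_matrix(dt, lengthscale=1.5)
Q = Pinf - A @ Pinf @ A.T  # process-noise covariance that keeps the chain stationary
```

Because each step only needs the previous state, inference over \(n\) sorted time points can proceed sequentially instead of through an \(n \times n\) covariance matrix.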

Predictive contract (v1): prediction returns diagonal (marginal) covariance only; the marginals are exact. A dense joint predictive is not implemented in v1 and is tracked as a follow-up. This predictive is therefore not Liskov-substitutable for a dense gpjax.gps.ConjugatePosterior predictive.
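The practical consequence of the diagonal-only contract can be seen in a small NumPy sketch. Here a dense Matern-3/2 covariance stands in for a dense predictive covariance; this is an assumption for illustration, and the helper name is hypothetical. The marginal variances agree exactly. However, any quantity that depends on the joint, such as the variance of \(f(x_0) + f(x_1)\), is wrong if a caller treats the diagonal as the full covariance.

```python
import numpy as np


def matern32(x1, x2, lengthscale=1.0, variance=1.0):
    """Dense Matern-3/2 covariance matrix between two 1-D input arrays."""
    r = np.sqrt(3.0) * np.abs(x1[:, None] - x2[None, :]) / lengthscale
    return variance * (1.0 + r) * np.exp(-r)


x = np.array([0.0, 0.1, 2.0])
dense_cov = matern32(x, x)   # stand-in for a dense predictive covariance
diag_cov = np.diag(dense_cov)  # what a diagonal-only predictive returns

# Marginal variances agree exactly ...
marginal_vars_match = np.allclose(np.diag(dense_cov), diag_cov)

# ... but Var[f(x0) + f(x1)] needs the cross term, which the diagonal drops.
w = np.array([1.0, 1.0, 0.0])
var_sum_dense = w @ dense_cov @ w                # K00 + K11 + 2 K01
var_sum_diag_only = w @ np.diag(diag_cov) @ w    # K00 + K11 only
```

Two nearby inputs are strongly correlated, so the diagonal-only answer underestimates this variance by almost a factor of two.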

Example:

    >>> import gpjax as gpx
    >>> from gpjax.state_space import StateSpacePrior
    >>> prior = StateSpacePrior(
    ...     mean_function=gpx.mean_functions.Zero(),
    ...     kernel=gpx.kernels.Matern32(lengthscale=1.0, variance=1.0),
    ... )
    >>> isinstance(prior.kernel, gpx.kernels.Matern32)
    True

Parameters:

  • kernel (K) –

    kernel object inheriting from AbstractKernel.

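  • mean_function (M) –

    mean function object inheriting from AbstractMeanFunction.

  • jitter (float, default: 1e-06 ) –

    A small constant added to the diagonal of the covariance matrix to ensure numerical stability.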

sample_approx

sample_approx(
    num_samples: int,
    key: KeyArray,
    num_features: Optional[int] = 100,
) -> FunctionalSample

Approximate samples from the Gaussian process prior.

Build an approximate sample from the Gaussian process prior. This method provides a function that returns the evaluations of a sample across any given inputs.

In particular, we approximate the Gaussian process prior as the finite feature approximation \(\hat{f}(x) = \sum_{i=1}^m\phi_i(x)\theta_i\), where \(\phi_i\) are \(m\) features sampled from the Fourier feature decomposition of the model's kernel and \(\theta_i\) are samples from a unit Gaussian.

A key property of such functional samples is that the same sample draw is evaluated for all queries. This consistency is prohibitively costly to ensure when sampling exactly from the GP prior, because the cost of exact sampling scales cubically with the number of points at which the sample is evaluated. In contrast, finite feature representations can be evaluated at a constant cost per query, however many queries are required.
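A minimal NumPy sketch of the finite feature approximation follows. It assumes an RBF kernel \(k(x, x') = \sigma^2 \exp(-(x - x')^2 / 2\ell^2)\) on 1-D inputs and the common random Fourier feature construction, in which each of the \(m\) sampled frequencies contributes one cosine and one sine feature. It is not gpjax's internal code, and the helper names are hypothetical. Because the frequencies \(\omega\) and weights \(\theta\) are drawn once, every call to the returned function evaluates the same draws.

```python
import numpy as np


def rbf(x1, x2, lengthscale=1.0, variance=1.0):
    """Exact RBF covariance matrix between two 1-D input arrays."""
    sq = (x1[:, None] - x2[None, :]) ** 2
    return variance * np.exp(-0.5 * sq / lengthscale**2)


def make_rff_prior_sampler(num_samples, num_features, lengthscale=1.0, variance=1.0, seed=0):
    """Return (sample_fn, features) for `num_samples` fixed RFF prior draws."""
    rng = np.random.default_rng(seed)
    # Spectral frequencies of the RBF kernel: omega ~ N(0, 1 / lengthscale^2).
    omega = rng.normal(0.0, 1.0 / lengthscale, size=num_features)
    # Unit-Gaussian weights, one column per sample; cos and sin features each get m rows.
    theta = rng.normal(size=(2 * num_features, num_samples))

    def features(x):
        proj = np.outer(x, omega)  # [n, m]
        scale = np.sqrt(variance / num_features)
        return scale * np.concatenate([np.cos(proj), np.sin(proj)], axis=1)  # [n, 2m]

    def sample_fn(x):
        # Same theta on every call, so repeated queries see the same functions.
        return features(x) @ theta  # [n, num_samples]

    return sample_fn, features


sample_fn, features = make_rff_prior_sampler(
    num_samples=10, num_features=5000, lengthscale=0.5, seed=1
)
x = np.linspace(0.0, 1.0, 100)
f = sample_fn(x)
```

Evaluating at more points only adds rows to the feature matrix, which is why the per-query cost stays constant.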

In the following example, we build 10 approximate samples from a prior distribution and evaluate them at 100 points on the interval \([0, 1]\).

Example:

    >>> import gpjax as gpx
    >>> import jax.numpy as jnp
    >>> import jax.random as jr
    >>> key = jr.key(123)
    >>>
    >>> meanf = gpx.mean_functions.Zero()
    >>> kernel = gpx.kernels.RBF(n_dims=1)
    >>> prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
    >>>
    >>> sample_fn = prior.sample_approx(10, key)
    >>> samples = sample_fn(jnp.linspace(0, 1, 100).reshape(-1, 1))

Parameters:

  • num_samples (int) –

    The desired number of samples.

  • key (KeyArray) –

    The random seed used for the sample(s).

  • num_features (int, default: 100 ) –

    The number of features used when approximating the kernel.

Returns:

  • FunctionalSample ( FunctionalSample ) –

    A function representing an approximate sample from the Gaussian process prior.

StateSpaceConjugatePosterior

StateSpaceConjugatePosterior(
    prior: AbstractPrior[M, K],
    likelihood: L,
    jitter: float = 1e-06,
)

Bases: ConjugatePosterior

Conjugate posterior for a state-space (Markovian) GP.

v1 prediction surface
  • predict : smoothed-latent prediction (Phase 10)
  • predict_filter : causal filtered prediction (Phase 10)
  • __call__ : delegates to predict

Both predict and predict_filter reject return_covariance_type="dense" before any further dispatch, enforcing v1's diagonal-only contract.
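To illustrate the difference between smoothed and causal filtered prediction, here is a minimal NumPy sketch of a Kalman filter followed by a Rauch-Tung-Striebel (RTS) smoother for a Matern-3/2 GP observed with Gaussian noise. It assumes that predict and predict_filter carry their usual meanings: smoothed marginals are conditioned on all data, and filtered marginals are conditioned only on data up to each time. It is not the gpjax implementation, and filter_and_smooth is a hypothetical name. Only marginal (diagonal) quantities are returned, mirroring the v1 contract, and the smoothed marginals are checked against a dense GP solve.

```python
import numpy as np


def matern32_cov(t1, t2, lengthscale, variance):
    r = np.sqrt(3.0) * np.abs(t1[:, None] - t2[None, :]) / lengthscale
    return variance * (1.0 + r) * np.exp(-r)


def filter_and_smooth(t, y, lengthscale=1.0, variance=1.0, noise_var=0.01):
    """Kalman filter + RTS smoother for a zero-mean Matern-3/2 GP (t sorted).

    Returns filtered and smoothed latent means and variances at the times t.
    """
    lam = np.sqrt(3.0) / lengthscale
    Pinf = np.diag([variance, lam**2 * variance])
    H = np.array([1.0, 0.0])  # observe f, not f'
    n = len(t)
    m_pred, P_pred = np.zeros((n, 2)), np.zeros((n, 2, 2))
    m_filt, P_filt = np.zeros((n, 2)), np.zeros((n, 2, 2))
    A_all = np.zeros((n, 2, 2))
    m, P = np.zeros(2), Pinf.copy()  # stationary initial state
    for k in range(n):
        if k > 0:
            dt = t[k] - t[k - 1]
            A = np.exp(-lam * dt) * np.array([[1 + lam * dt, dt], [-lam**2 * dt, 1 - lam * dt]])
            m, P = A @ m, A @ P @ A.T + (Pinf - A @ Pinf @ A.T)
            A_all[k] = A
        m_pred[k], P_pred[k] = m, P
        # Measurement update with the scalar observation y[k] = f(t[k]) + noise.
        s = H @ P @ H + noise_var
        gain = P @ H / s
        m = m + gain * (y[k] - H @ m)
        P = P - np.outer(gain, gain) * s
        m_filt[k], P_filt[k] = m, P
    # Backward RTS pass: fold in information from later observations.
    m_s, P_s = m_filt.copy(), P_filt.copy()
    for k in range(n - 2, -1, -1):
        A = A_all[k + 1]
        G = P_filt[k] @ A.T @ np.linalg.inv(P_pred[k + 1])
        m_s[k] = m_filt[k] + G @ (m_s[k + 1] - m_pred[k + 1])
        P_s[k] = P_filt[k] + G @ (P_s[k + 1] - P_pred[k + 1]) @ G.T
    return m_filt[:, 0], P_filt[:, 0, 0], m_s[:, 0], P_s[:, 0, 0]


rng = np.random.default_rng(0)
t = np.sort(rng.uniform(0.0, 5.0, 20))
y = np.sin(t) + 0.1 * rng.normal(size=20)
ls, var, nv = 1.2, 1.0, 0.01
mf, vf, ms, vs = filter_and_smooth(t, y, ls, var, nv)

# Dense reference: exact GP posterior marginals at the training times.
K = matern32_cov(t, t, ls, var)
Ky = K + nv * np.eye(len(t))
mean_dense = K @ np.linalg.solve(Ky, y)
var_dense = np.diag(K - K @ np.linalg.solve(Ky, K))
```

The filtered and smoothed marginals coincide at the last time point, and smoothing can only reduce the variance elsewhere. The sequential passes cost \(O(n)\) in the number of time points, compared with \(O(n^3)\) for the dense solve.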

Predictive contract (v1): prediction returns diagonal (marginal) covariance only; the marginals are exact. A dense joint predictive is not implemented in v1 and is tracked as a follow-up. This predictive is therefore not Liskov-substitutable for a dense gpjax.gps.ConjugatePosterior predictive.

Example:

    >>> import gpjax as gpx
    >>> from gpjax.state_space import StateSpacePrior
    >>> prior = StateSpacePrior(
    ...     mean_function=gpx.mean_functions.Zero(),
    ...     kernel=gpx.kernels.Matern32(lengthscale=1.0, variance=1.0),
    ... )
    >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=20, obs_stddev=0.1)
    >>> posterior = prior * likelihood
    >>> posterior.__class__.__name__
    'StateSpaceConjugatePosterior'

Parameters:

  • prior (AbstractPrior) –

    The prior distribution.

  • likelihood (AbstractLikelihood) –

    The likelihood distribution.

  • jitter (float, default: 1e-06 ) –

    A small constant added to the diagonal of the covariance matrix to ensure numerical stability.
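The role of jitter can be seen in a small NumPy sketch. Duplicated inputs make an RBF covariance matrix exactly singular, so a plain Cholesky factorisation fails. Adding \(10^{-6}\) to the diagonal lets the factorisation succeed, at the cost of a perturbation of that size. This page does not state where the state-space posterior applies its jitter, so treat this as a generic illustration. rbf and cholesky_with_jitter are hypothetical helpers.

```python
import numpy as np


def rbf(x1, x2, lengthscale=1.0, variance=1.0):
    sq = (x1[:, None] - x2[None, :]) ** 2
    return variance * np.exp(-0.5 * sq / lengthscale**2)


def cholesky_with_jitter(K, jitter=1e-6):
    """Cholesky factor of K + jitter * I."""
    return np.linalg.cholesky(K + jitter * np.eye(K.shape[0]))


x = np.array([0.0, 0.0, 1.0])  # duplicated input makes K exactly singular
K = rbf(x, x)

try:
    np.linalg.cholesky(K)
    plain_ok = True
except np.linalg.LinAlgError:  # "Matrix is not positive definite"
    plain_ok = False

L = cholesky_with_jitter(K, jitter=1e-6)
```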

sample_approx

sample_approx(
    num_samples: int,
    train_data: Dataset,
    key: KeyArray,
    num_features: int | None = 100,
) -> FunctionalSample

Draw approximate samples from the Gaussian process posterior.

Build an approximate sample from the Gaussian process posterior. This method provides a function that returns the evaluations of a sample across any given inputs.

Unlike when building approximate samples from a Gaussian process prior, decompositions based on Fourier features alone rarely give accurate samples. Therefore, we must also include an additional set of features (known as canonical features) to better model the transition from Gaussian process prior to Gaussian process posterior. For more details see Wilson et al. (2020).

In particular, we approximate the Gaussian process posterior as the finite feature approximation \(\hat{f}(x) = \sum_{i=1}^m \phi_i(x)\theta_i + \sum_{j=1}^N v_j k(x, x_j)\), where \(\phi_i\) are \(m\) features sampled from the Fourier feature decomposition of the model's kernel and \(k(\cdot, x_j)\) are \(N\) canonical features, one per training input \(x_j\). The Fourier weights \(\theta_i\) are samples from a unit Gaussian. See Wilson et al. (2020) for expressions for the canonical weights \(v_j\).

A key property of such functional samples is that the same sample draw is evaluated for all queries. This consistency is prohibitively costly to ensure when sampling exactly from the GP posterior, because the cost of exact sampling scales cubically with the number of points at which the sample is evaluated. In contrast, finite feature representations can be evaluated at a constant cost per query, however many queries are required.
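The decoupled construction can be sketched in NumPy. The sketch assumes a zero-mean GP with an RBF kernel on 1-D inputs and Gaussian observation noise with variance \(\sigma_n^2\). The canonical weights follow the pathwise update \(v = (K + \sigma_n^2 I)^{-1}(y - \hat{f}(X) - \varepsilon)\) with \(\varepsilon \sim \mathcal{N}(0, \sigma_n^2 I)\), in the spirit of Wilson et al. (2020). This is an illustration rather than the gpjax implementation, and pathwise_posterior_sampler is a hypothetical name.

```python
import numpy as np


def rbf(x1, x2, lengthscale=1.0, variance=1.0):
    sq = (x1[:, None] - x2[None, :]) ** 2
    return variance * np.exp(-0.5 * sq / lengthscale**2)


def pathwise_posterior_sampler(x_train, y_train, num_samples, num_features=1000,
                               lengthscale=1.0, variance=1.0, noise_var=0.01, seed=0):
    """Decoupled (pathwise) posterior draws for a zero-mean RBF GP on 1-D inputs."""
    rng = np.random.default_rng(seed)
    # Fourier part: fixed frequencies and unit-Gaussian weights (one column per sample).
    omega = rng.normal(0.0, 1.0 / lengthscale, size=num_features)
    theta = rng.normal(size=(2 * num_features, num_samples))

    def features(x):
        proj = np.outer(x, omega)
        return np.sqrt(variance / num_features) * np.concatenate(
            [np.cos(proj), np.sin(proj)], axis=1
        )

    # Prior draws at the training inputs plus simulated observation noise.
    f_prior_train = features(x_train) @ theta  # [N, S]
    eps = np.sqrt(noise_var) * rng.normal(size=f_prior_train.shape)
    Ky = rbf(x_train, x_train, lengthscale, variance) + noise_var * np.eye(len(x_train))
    # Canonical weights v, one column per sample.
    v = np.linalg.solve(Ky, y_train[:, None] - f_prior_train - eps)  # [N, S]

    def sample_fn(x):
        # Fourier (prior) part + canonical part sum_j v_j k(x, x_j).
        return features(x) @ theta + rbf(x, x_train, lengthscale, variance) @ v

    return sample_fn


x_train = np.linspace(0.0, 5.0, 10)
y_train = np.sin(x_train)
fn = pathwise_posterior_sampler(x_train, y_train, num_samples=4000, num_features=500, seed=1)
x_test = np.linspace(0.5, 4.5, 30)
draws = fn(x_test)  # [30, 4000]

Ky = rbf(x_train, x_train) + 0.01 * np.eye(len(x_train))
exact_mean = rbf(x_test, x_train) @ np.linalg.solve(Ky, y_train)
```

Since the Fourier weights and noise have zero mean, the average of many draws approaches the exact posterior mean, while each individual draw remains a consistent function of its inputs.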

Parameters:

  • num_samples (int) –

    The desired number of samples.

  • train_data (Dataset) –

    The training data used to condition the posterior.

  • key (KeyArray) –

    The random seed used for the sample(s).

  • num_features (int, default: 100 ) –

    The number of features used when approximating the kernel.

Returns:

  • FunctionalSample ( FunctionalSample ) –

    A function representing an approximate sample from the Gaussian process posterior.