Objectives
Marginal log-likelihood objective for state-space GPs.
state_space_mll
Marginal log-likelihood via the square-root Kalman filter.
Internally:
- Unwraps the posterior (resolving paramax-wrapped parameters).
- Builds the SDE via to_sde(kernel).
- Centres targets with y - mean_function(X).
- Computes sigma_eff = sqrt(obs_stddev² + prior.jitter).
- Delegates to kalman_filter.
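The steps listed above can be sketched end to end for the simplest case. This is a minimal, hypothetical illustration, not GPJax's state_space_mll: the names `matern32_sde`, `matern32_transition` and `kalman_mll` are invented here. It assumes 1-D sorted times, a Matérn-3/2 kernel (whose SDE has a closed-form transition matrix), targets that are already centred (zero mean function), and `jitter` standing in for prior.jitter. It also swaps in a plain covariance-form Kalman filter instead of the square-root form the objective uses; the two are algebraically equivalent, but the square-root form is more numerically robust.

```python
import jax
import jax.numpy as jnp

# Double precision so the sketch can be checked tightly against a dense GP.
jax.config.update("jax_enable_x64", True)


def matern32_sde(lengthscale, variance):
    """Hypothetical helper: decay rate and stationary covariance of the
    2-state (f, f') Matérn-3/2 SDE."""
    lam = jnp.sqrt(3.0) / lengthscale
    Pinf = jnp.diag(jnp.array([variance, lam**2 * variance]))
    return lam, Pinf


def matern32_transition(lam, dt):
    """Closed-form expm(F * dt) for F = [[0, 1], [-lam^2, -2 lam]]."""
    e = jnp.exp(-lam * dt)
    return e * jnp.array([[1.0 + lam * dt, dt],
                          [-lam**2 * dt, 1.0 - lam * dt]])


def kalman_mll(t, y, lengthscale, variance, obs_stddev, jitter=1e-6):
    """Hypothetical sketch: log p(y) for 1-D, time-sorted t and centred y."""
    lam, Pinf = matern32_sde(lengthscale, variance)
    R = obs_stddev**2 + jitter            # sigma_eff squared
    H = jnp.array([1.0, 0.0])             # observe f, not f'
    # First step uses dt = 0, so A = I and Q = 0: the prior at t[0] is N(0, Pinf).
    dts = jnp.concatenate([jnp.zeros(1), jnp.diff(t)])

    def step(carry, inputs):
        m, P = carry
        dt, yk = inputs
        A = matern32_transition(lam, dt)
        Q = Pinf - A @ Pinf @ A.T         # process noise of the stationary SDE
        m = A @ m                         # predict
        P = A @ P @ A.T + Q
        v = yk - H @ m                    # innovation
        S = H @ P @ H + R                 # innovation variance
        K = P @ H / S                     # Kalman gain
        m = m + K * v                     # update
        P = P - jnp.outer(K, K) * S
        ll = -0.5 * (jnp.log(2.0 * jnp.pi * S) + v**2 / S)
        return (m, P), ll

    _, lls = jax.lax.scan(step, (jnp.zeros(2), Pinf), (dts, y))
    return jnp.sum(lls)
```

Because the whole filter is a single lax.scan over precomputed time steps, the function stays traceable under jax.jit and jax.grad, which is the property the objective is designed to preserve.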
Pure JAX. Assumes time-sorted input: unsorted times yield negative Δt and
silently incorrect (NaN or garbage) results. There is no internal sort, because
a data-dependent reorder is deliberately kept out of this traced objective so
that it stays clean under jit, grad and MCMC. Sorting and validation are the
responsibility of the eager state_space.fit* wrappers (sort_state_space_data
warns and reorders; validate_state_space_data checks finiteness and shape).
Callers invoking this objective directly (e.g. from a custom MCMC sampler or
optimiser) must pre-sort by time, e.g. via sort_state_space_data(X, y, mask).
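A direct caller's eager pre-sort step might look like the sketch below. `presort_by_time` is a hypothetical helper written only for illustration; it is not sort_state_space_data, whose exact signature and return values are not documented here. It assumes time is the first column of `X` and that `mask`, when given, is indexed per row like `y`.

```python
import warnings

import jax.numpy as jnp


def presort_by_time(X, y, mask=None):
    """Hypothetical eager helper: reorder (X, y, mask) so X[:, 0] is ascending.

    Meant to run on concrete arrays outside any jit/grad trace, so the
    data-dependent check and argsort never enter the traced objective.
    """
    t = X[:, 0]
    if bool(jnp.all(jnp.diff(t) >= 0)):
        return X, y, mask  # already sorted: return inputs unchanged
    warnings.warn("inputs were not time-sorted; reordering by X[:, 0]")
    order = jnp.argsort(t)  # permutation that sorts times ascending
    X, y = X[order], y[order]
    if mask is not None:
        mask = mask[order]
    return X, y, mask
```

Call it once on the raw data before building the Dataset passed to state_space_mll; every later evaluation inside an optimiser or sampler then sees non-negative time steps.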
See plans/2026-04-21-state-space-gps-design.md §Stage 1.
Example:
>>> import jax.numpy as jnp
>>> import gpjax as gpx
>>> from gpjax.state_space import StateSpacePrior, state_space_mll
>>> X = jnp.linspace(0.0, 5.0, 20).reshape(-1, 1)
>>> y = jnp.sin(X)
>>> prior = StateSpacePrior(
... mean_function=gpx.mean_functions.Zero(),
... kernel=gpx.kernels.Matern32(lengthscale=1.0, variance=1.0),
... )
>>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=20, obs_stddev=0.1)
>>> posterior = prior * likelihood
>>> train_data = gpx.Dataset(X=X, y=y)
>>> mll = state_space_mll(posterior, train_data)
>>> bool(jnp.isfinite(mll))
True