
Inference

Square-root Kalman filter and RTS smoother for state-space GPs.

See plans/2026-04-21-state-space-gps-design.md §Stage 2 and §Stage 3.

kalman_filter

kalman_filter(
    sde,
    centred_targets,
    time_steps,
    is_observed,
    sigma_eff,
    *,
    chunk_size=None,
)

Square-root Kalman filter marginal log-likelihood.

Pure-JAX, double-scan with checkpointed inner chunks. Caller is responsible for centring targets (subtracting the mean function), sorting by time, and computing sigma_eff = sqrt(σ_y² + jitter).

The first entry of time_steps must be 0.0 so the first predict step is the identity and the filter starts from the prior at t = 0. The initial filter state is (mean=0, L=sde.stationary_state_cov_sqrt) — centred targets imply a zero prior mean.
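The caller-side preparation described above can be sketched as follows. This is a minimal illustration, not library code: the helper name prepare_inputs, the constant mean_value, and the jitter default are assumptions made for the example.

```python
import jax.numpy as jnp


def prepare_inputs(times, targets, mean_value, obs_noise_var, jitter=1e-6):
    """Hypothetical helper: sort by time, centre targets, build filter inputs."""
    order = jnp.argsort(times)
    sorted_times = times[order]
    # A constant mean function is assumed here purely for illustration.
    centred_targets = targets[order] - mean_value
    # time_steps[0] = 0 and time_steps[i] = t_i - t_{i-1} for i > 0.
    first = jnp.zeros(1, dtype=sorted_times.dtype)
    time_steps = jnp.concatenate([first, jnp.diff(sorted_times)])
    # sigma_eff = sqrt(sigma_y^2 + jitter).
    sigma_eff = jnp.sqrt(obs_noise_var + jitter)
    is_observed = jnp.ones(times.shape, dtype=bool)
    return centred_targets, time_steps, is_observed, sigma_eff


times = jnp.array([2.0, 0.5, 1.0])
targets = jnp.array([3.0, 1.0, 2.0])
centred, time_steps, is_observed, sigma_eff = prepare_inputs(
    times, targets, mean_value=2.0, obs_noise_var=0.25, jitter=0.0
)
```

The resulting arrays are in the form the parameter descriptions ask for: sorted, centred, and with a leading zero in time_steps.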

Memory characteristics: an outer lax.scan iterates over chunks of chunk_size steps; an inner lax.scan (wrapped in jax.checkpoint) runs the per-step filter inside a chunk. Under reverse-mode autodiff, only the chunk-boundary carries are saved on the tape, and interior carries are recomputed on the backward pass. With the default chunk_size = round(sqrt(N)), this gives an AD memory footprint of O(sqrt(N) * d^2), where N is num_steps and d is the state dimension.
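The double-scan pattern can be illustrated with a toy recurrence in place of the filter step. The sketch below is an assumption about the general shape of such a loop, not the library's implementation; chunked_scan and toy_step are hypothetical names, and the input length is assumed to be a multiple of chunk_size.

```python
import math

import jax
import jax.numpy as jnp
from jax import lax


def chunked_scan(step, init_carry, xs, chunk_size):
    """Outer scan over chunks, checkpointed inner scan over steps in a chunk."""
    num_chunks = xs.shape[0] // chunk_size
    chunked_xs = xs.reshape((num_chunks, chunk_size) + xs.shape[1:])

    @jax.checkpoint
    def run_chunk(carry, chunk):
        # Interior carries of this inner scan are recomputed on the backward pass.
        return lax.scan(step, carry, chunk)

    carry, ys = lax.scan(run_chunk, init_carry, chunked_xs)
    # Flatten (num_chunks, chunk_size, ...) back to (num_steps, ...).
    return carry, ys.reshape((-1,) + ys.shape[2:])


def toy_step(carry, x):
    # Stand-in for one predict/update step of the filter.
    new_carry = 0.9 * carry + x
    return new_carry, new_carry


xs = jnp.arange(16.0)
chunk_size = max(1, round(math.sqrt(xs.shape[0])))
final_chunked, ys_chunked = chunked_scan(toy_step, jnp.zeros(()), xs, chunk_size)
final_flat, ys_flat = lax.scan(toy_step, jnp.zeros(()), xs)

# Gradients agree with the single-scan version; only the memory profile differs.
grad_chunked = jax.grad(
    lambda v: chunked_scan(toy_step, jnp.zeros(()), v, chunk_size)[0]
)(xs)
grad_flat = jax.grad(lambda v: lax.scan(toy_step, jnp.zeros(()), v)[0])(xs)
```

Checkpointing changes what is stored for the backward pass, not the values computed, so the chunked and flat scans should agree in both outputs and gradients.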

To accommodate num_steps that is not a multiple of the resolved chunk size, the inputs are padded with no-op steps (time_step = 0 and is_observed = False). Padding steps are exact algebraic no-ops because every SDE returns (A, L_Q) = (I, 0) at dt = 0 and the masked update contributes zero log-likelihood.
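The chunk-size resolution and padding rules above can be sketched as below. The function names resolve_chunk_size and pad_to_chunks are hypothetical; only the rules they encode (default round(sqrt(num_steps)), clamping, padding with time_step = 0 and is_observed = False) come from the description.

```python
import math

import jax.numpy as jnp


def resolve_chunk_size(num_steps, chunk_size=None):
    """None -> round(sqrt(num_steps)); result clamped to [1, num_steps]."""
    if chunk_size is None:
        chunk_size = round(math.sqrt(num_steps))
    return max(1, min(chunk_size, num_steps))


def pad_to_chunks(time_steps, is_observed, centred_targets, chunk_size):
    """Append no-op steps so the length is a multiple of chunk_size."""
    num_steps = time_steps.shape[0]
    pad = (-num_steps) % chunk_size
    time_steps = jnp.pad(time_steps, (0, pad))  # dt = 0: identity predict
    is_observed = jnp.pad(is_observed, (0, pad), constant_values=False)
    centred_targets = jnp.pad(centred_targets, (0, pad))  # values are ignored
    return time_steps, is_observed, centred_targets


num_steps = 10
chunk_size = resolve_chunk_size(num_steps)
time_steps = jnp.concatenate([jnp.zeros(1), 0.1 * jnp.ones(num_steps - 1)])
padded_steps, padded_mask, padded_targets = pad_to_chunks(
    time_steps,
    jnp.ones(num_steps, dtype=bool),
    jnp.arange(float(num_steps)),
    chunk_size,
)
```

Because the padded entries are masked out, the target values placed there do not matter; zeros are used only for convenience.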

Parameters

sde : LinearSDE
    State-space SDE; sde.discretise(dt) is called per scan step.
centred_targets : Float[Array, "num_train"]
    Targets with the mean function subtracted.
time_steps : Float[Array, "num_train"]
    time_steps[0] = 0 and time_steps[i] = t_i - t_{i-1} for i > 0.
is_observed : Bool[Array, "num_train"]
    At indices where this is False the update step is a no-op (predict only). Useful for masked observations.
sigma_eff : Float[Array, ""]
    Effective observation standard deviation.
chunk_size : int | None, keyword-only
    Static chunk length for the inner scan. None resolves to round(sqrt(num_steps)) (clamped to >= 1). Values larger than num_steps are clamped to num_steps (single chunk, no padding).

Returns

Float[Array, ""] Scalar marginal log-likelihood Σ_i log p(y_i | y_{<i}).
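To make the returned quantity concrete, the sketch below accumulates Σ_i log p(y_i | y_{<i}) for a scalar Ornstein-Uhlenbeck (Matérn-1/2) prior and cross-checks it against the dense Gaussian log-density. It is a covariance-form illustration rather than the square-root filter documented here, and ou_filter_loglik, variance, and lengthscale are names assumed for the example.

```python
import jax.numpy as jnp
from jax import lax
from jax.scipy.stats import multivariate_normal


def ou_filter_loglik(centred_targets, time_steps, is_observed, sigma_eff,
                     variance=1.0, lengthscale=1.0):
    """Covariance-form scalar Kalman filter for an OU prior (illustrative)."""
    dtype = centred_targets.dtype

    def step(carry, inputs):
        mean, cov, loglik = carry
        y, dt, observed = inputs
        # Exact OU discretisation: A = exp(-dt / l), Q = s^2 (1 - A^2).
        a = jnp.exp(-dt / lengthscale)
        q = variance * (1.0 - a**2)
        mean_pred, cov_pred = a * mean, a**2 * cov + q
        innovation = y - mean_pred
        s = cov_pred + sigma_eff**2
        gain = cov_pred / s
        step_ll = -0.5 * (jnp.log(2.0 * jnp.pi * s) + innovation**2 / s)
        # Masked update: unobserved steps keep the prediction and add zero.
        mean_new = jnp.where(observed, mean_pred + gain * innovation, mean_pred)
        cov_new = jnp.where(observed, cov_pred - gain * cov_pred, cov_pred)
        loglik = loglik + jnp.where(observed, step_ll, 0.0)
        return (mean_new, cov_new, loglik), None

    # Start from the stationary prior: zero mean, variance s^2.
    init = (jnp.zeros((), dtype), jnp.full((), variance, dtype), jnp.zeros((), dtype))
    (_, _, loglik), _ = lax.scan(
        step, init, (centred_targets, time_steps, is_observed)
    )
    return loglik


times = jnp.array([0.0, 0.3, 1.0, 1.4])
y = jnp.array([0.5, -0.2, 0.1, 0.7])
time_steps = jnp.concatenate([jnp.zeros(1), jnp.diff(times)])
sigma_eff = 0.3
filter_loglik = ou_filter_loglik(y, time_steps, jnp.ones(4, dtype=bool), sigma_eff)

# Dense cross-check: y ~ N(0, K + sigma_eff^2 I) with the OU kernel.
K = jnp.exp(-jnp.abs(times[:, None] - times[None, :]))
dense_loglik = multivariate_normal.logpdf(
    y, jnp.zeros(4), K + sigma_eff**2 * jnp.eye(4)
)
```

The two values should agree up to floating-point error, since the chain-rule factorisation used by the filter is an exact rewriting of the joint density.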

See plans/2026-04-21-state-space-gps-design.md §Stage 2.

rts_smoother

rts_smoother(sde, forward_outputs, time_steps)

Square-root RTS smoother.

Runs the standard Särkkä & Solin (2019) §10.7 backward recursion on the forward filter trajectory. Internally materialises P = L @ L.T for the smoother gain computation and re-roots the smoothed covariance after each backward step with gpjax.state_space.sde._psd_sqrt. The returned Ls are therefore non-triangular V·Λ^½ square roots (neither Cholesky nor symmetric); consumers must rely only on L @ L.T. The recursion is

.. math::

G_i &= P^{\text{filt}}_i A_{i+1}^\top (P^{\text{pred}}_{i+1})^{-1} \\
m^{\text{smooth}}_i &= m^{\text{filt}}_i + G_i (m^{\text{smooth}}_{i+1} - m^{\text{pred}}_{i+1}) \\
P^{\text{smooth}}_i &= P^{\text{filt}}_i + G_i (P^{\text{smooth}}_{i+1} - P^{\text{pred}}_{i+1}) G_i^\top

The last step has no future, so its smoothed state equals its filtered state.
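The recursion above can be sketched in covariance form with a reverse lax.scan. This is an illustration under stated assumptions, not the library's square-root implementation: rts_backward and toy_forward are hypothetical names, the transition matrices are passed in explicitly instead of being obtained from sde.discretise, and the forward pass is a plain Python loop for a scalar Ornstein-Uhlenbeck model.

```python
import jax.numpy as jnp
from jax import lax


def rts_backward(means_filt, covs_filt, means_pred, covs_pred, transitions):
    """transitions[i] maps step i - 1 to step i; transitions[0] is unused."""

    def step(carry, inputs):
        mean_next, cov_next = carry
        m_f, P_f, m_p_next, P_p_next, A_next = inputs
        # G_i = P_filt_i A_{i+1}^T (P_pred_{i+1})^{-1}, via a solve.
        G = jnp.linalg.solve(P_p_next, A_next @ P_f).T
        mean = m_f + G @ (mean_next - m_p_next)
        cov = P_f + G @ (cov_next - P_p_next) @ G.T
        return (mean, cov), (mean, cov)

    # The last step has no future: smoothed state equals filtered state.
    init = (means_filt[-1], covs_filt[-1])
    xs = (means_filt[:-1], covs_filt[:-1],
          means_pred[1:], covs_pred[1:], transitions[1:])
    _, (means_s, covs_s) = lax.scan(step, init, xs, reverse=True)
    means = jnp.concatenate([means_s, means_filt[-1:]])
    covs = jnp.concatenate([covs_s, covs_filt[-1:]])
    return means, covs


def toy_forward(y, dts, noise_var, variance=1.0, lengthscale=1.0):
    """Scalar OU Kalman filter (Python loop) producing smoother inputs."""
    m, P = jnp.zeros((1,)), variance * jnp.eye(1)
    out = []
    for y_i, dt in zip(y, dts):
        A = jnp.exp(-dt / lengthscale) * jnp.eye(1)
        Q = variance * (1.0 - A**2)  # valid because the state is 1-D
        m_p, P_p = A @ m, A @ P @ A.T + Q
        S = P_p[0, 0] + noise_var
        K = P_p[:, 0] / S
        m = m_p + K * (y_i - m_p[0])
        P = P_p - jnp.outer(K, K) * S
        out.append((m, P, m_p, P_p, A))
    return tuple(jnp.stack(x) for x in zip(*out))


times = jnp.array([0.0, 0.4, 1.0, 1.5])
y = jnp.array([0.3, -0.1, 0.4, 0.9])
dts = jnp.concatenate([jnp.zeros(1), jnp.diff(times)])
noise_var = 0.09
forward = toy_forward(y, dts, noise_var)
smoothed_means, smoothed_covs = rts_backward(*forward)
```

For this toy model the smoothed marginals should match the dense GP posterior at the training times, which gives a convenient correctness check for any smoother implementation.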

Parameters

sde : LinearSDE
    State-space SDE used in the forward pass; sde.discretise(dt) is called once per backward step.
forward_outputs : tuple
    Quadruple (means_updated, Ls_updated, means_predicted, Ls_predicted) as returned by _sqrt_filter_forward.
time_steps : Float[Array, "num_train"]
    Same time_steps that drove the forward pass; time_steps[i+1] is the inter-step dt between filtered index i and predicted index i + 1.

Returns

smoothed_means : Float[Array, "num_train state_dim"]
smoothed_Ls : Float[Array, "num_train state_dim state_dim"]
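The note that smoothed_Ls are V·Λ^½ roots, usable only through L @ L.T, can be illustrated with an eigendecomposition-based square root. The function psd_sqrt_eigh below is a hypothetical stand-in written for this example; it is not the library's _psd_sqrt, whose details are not shown here.

```python
import jax.numpy as jnp


def psd_sqrt_eigh(P, floor=0.0):
    """Return L = V @ diag(sqrt(lam)) so that L @ L.T reproduces P."""
    P = 0.5 * (P + P.T)  # symmetrise against round-off
    lam, V = jnp.linalg.eigh(P)
    lam = jnp.maximum(lam, floor)  # clip tiny negative eigenvalues
    return V * jnp.sqrt(lam)[None, :]


P = jnp.array([[2.0, 0.6], [0.6, 1.0]])
L = psd_sqrt_eigh(P)
reconstructed = L @ L.T
```

A root of this kind is in general neither lower-triangular nor symmetric, so downstream code that assumes a Cholesky factor (for example a triangular solve) would give wrong results; forming L @ L.T first is the safe route.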

See plans/2026-04-21-state-space-gps-design.md §Stage 3.