Inference
Square-root Kalman filter and RTS smoother for state-space GPs.
See plans/2026-04-21-state-space-gps-design.md §Stage 2 and §Stage 3.
kalman_filter
Square-root Kalman filter marginal log-likelihood.
Pure-JAX, double-scan with checkpointed inner chunks. Caller is responsible
for centring targets (subtracting the mean function), sorting by time, and
computing sigma_eff = sqrt(σ_y² + jitter).
The first entry of time_steps must be 0.0 so the first predict step
is the identity and the filter starts from the prior at t = 0. The
initial filter state is (mean=0, L=sde.stationary_state_cov_sqrt);
the zero prior mean follows from the targets being centred.
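The caller-side preparation described above (sorting, centring, differencing, and computing sigma_eff) can be sketched as follows. prepare_inputs is a hypothetical helper, not part of gpjax, and mean_fn is assumed to be any vectorised callable evaluated on the sorted times.

```python
import jax.numpy as jnp


def prepare_inputs(t, y, mean_fn, obs_noise_var, jitter=1e-6):
    """Hypothetical helper producing the arguments kalman_filter expects."""
    # Sort by time: the filter walks the data in increasing-time order.
    order = jnp.argsort(t)
    t_sorted, y_sorted = t[order], y[order]
    # Centre the targets so that the zero prior mean is appropriate.
    centred_targets = y_sorted - mean_fn(t_sorted)
    # time_steps[0] = 0 (identity first predict); time_steps[i] = t_i - t_{i-1}.
    time_steps = jnp.concatenate(
        [jnp.zeros(1, dtype=t_sorted.dtype), jnp.diff(t_sorted)])
    # sigma_eff = sqrt(sigma_y^2 + jitter).
    sigma_eff = jnp.sqrt(obs_noise_var + jitter)
    return centred_targets, time_steps, sigma_eff
```

For example, times [0.5, 0.1, 0.3] with targets [2.0, 1.0, 1.5] and a constant mean of 1.0 give centred targets [0.0, 0.5, 1.0] and time_steps [0.0, 0.2, 0.2], up to floating-point rounding.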
Memory characteristics: an outer lax.scan iterates over chunks of
chunk_size steps; an inner lax.scan (wrapped in
:func:`jax.checkpoint`) runs the per-step filter inside a chunk. Under
reverse-mode autodiff, only the chunk-boundary carries are saved on the
tape; interior carries are recomputed on the backward pass. With
N = num_steps and state dimension d, at most N / chunk_size boundary
carries plus chunk_size recomputed carries are held at once, giving an
AD memory footprint of O(sqrt(N) * d^2) for the default
chunk_size = round(sqrt(N)).
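A generic sketch of the chunked, checkpointed double scan, assuming a step function with the lax.scan signature step(carry, x) -> (carry, y) and a single array of per-step inputs (the real filter scans over several arrays). chunked_scan is a hypothetical name; it requires xs.shape[0] to be a multiple of chunk_size and does not reproduce the filter itself.

```python
import jax
import jax.numpy as jnp


def chunked_scan(step, init, xs, chunk_size):
    """Hypothetical outer/inner scan with a checkpointed inner chunk.

    Assumes xs.shape[0] is a multiple of chunk_size (pad beforehand otherwise).
    """
    num_steps = xs.shape[0]
    xs_chunks = xs.reshape(num_steps // chunk_size, chunk_size, *xs.shape[1:])

    # jax.checkpoint: for the backward pass only this function's inputs
    # (the chunk-boundary carry and the chunk's inputs) are saved; the
    # interior carries of the inner scan are recomputed.
    @jax.checkpoint
    def run_chunk(carry, chunk):
        return jax.lax.scan(step, carry, chunk)

    carry, ys = jax.lax.scan(run_chunk, init, xs_chunks)
    # Flatten (num_chunks, chunk_size, ...) back to (num_steps, ...).
    return carry, ys.reshape(num_steps, *ys.shape[2:])
```

The values and gradients match a single lax.scan over xs; only the memory versus recompute trade-off under jax.grad changes.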
When num_steps is not a multiple of the resolved chunk size, the inputs
are padded with no-op steps (time_step = 0 and is_observed = False).
Padding steps are exact algebraic no-ops: every SDE returns
(A, L_Q) = (I, 0) at dt = 0, so the predict step leaves the state
unchanged, and the masked update contributes zero log-likelihood.
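Chunk-size resolution and padding can be sketched as below. resolve_chunk_size and pad_to_multiple are hypothetical helpers that follow the clamping rules documented for the chunk_size parameter and the padding convention above (time_step = 0, is_observed = False); the padded target value is arbitrary because the update at those steps is masked.

```python
import math

import jax.numpy as jnp


def resolve_chunk_size(num_steps, chunk_size=None):
    """Hypothetical helper mirroring the documented chunk_size rules."""
    if chunk_size is None:
        # Default: round(sqrt(num_steps)), clamped to >= 1.
        chunk_size = max(1, round(math.sqrt(num_steps)))
    # Never longer than the data (single chunk, no padding).
    return min(chunk_size, num_steps)


def pad_to_multiple(centred_targets, time_steps, is_observed, chunk_size):
    """Append no-op steps (dt = 0, unobserved) so the length divides chunk_size."""
    num_pad = (-time_steps.shape[0]) % chunk_size
    targets = jnp.pad(centred_targets, (0, num_pad))  # value unused: update is masked
    dts = jnp.pad(time_steps, (0, num_pad))           # dt = 0 -> (A, L_Q) = (I, 0)
    mask = jnp.pad(is_observed, (0, num_pad), constant_values=False)
    return targets, dts, mask
```

With num_steps = 10, the default chunk size is round(sqrt(10)) = 3, so two padding steps are appended and the outer scan runs over four chunks.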
Parameters
sde : LinearSDE
State-space SDE; sde.discretise(dt) is called per scan step.
centred_targets : Float[Array, "num_train"]
Targets with the mean function subtracted.
time_steps : Float[Array, "num_train"]
time_steps[0] = 0 and time_steps[i] = t_i - t_{i-1} for i > 0.
is_observed : Bool[Array, "num_train"]
At indices where this is False, the update step is skipped (predict
only) and contributes no log-likelihood. Useful for masked observations.
sigma_eff : Float[Array, ""]
Effective observation standard deviation.
chunk_size : int | None, keyword-only
Static chunk length for the inner scan. None resolves to
round(sqrt(num_steps)) (clamped to >= 1). Values larger than
num_steps are clamped to num_steps (single chunk, no padding).
Returns
Float[Array, ""]
Scalar marginal log-likelihood Σ_i log p(y_i | y_{<i}).
See plans/2026-04-21-state-space-gps-design.md §Stage 2.
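To make the masked update and the decomposition Σ_i log p(y_i | y_{<i}) concrete, here is a minimal sketch for a one-dimensional Ornstein-Uhlenbeck (Matérn-1/2) SDE, assumed purely for illustration, with kernel variance·exp(-|τ|/lengthscale), for which A = exp(-dt/lengthscale) and Q = variance·(1 - A²). It propagates the variance P directly rather than its square root, so it illustrates the recursion, not the square-root implementation. ou_kalman_loglik is a hypothetical name.

```python
import jax
import jax.numpy as jnp

# Use float64 so the sketch can be compared tightly against a dense GP;
# note this changes JAX's global default dtype.
jax.config.update("jax_enable_x64", True)


def ou_kalman_loglik(centred_targets, time_steps, is_observed, sigma_eff,
                     variance, lengthscale):
    """Hypothetical covariance-form Kalman filter for a scalar OU SDE.

    Returns sum_i log p(y_i | y_{<i}) over the observed indices.
    """
    dtype = centred_targets.dtype

    def step(carry, inputs):
        m, P, ll = carry
        y, dt, observed = inputs
        # Predict: exact OU discretisation; dt = 0 gives A = 1 and Q = 0.
        A = jnp.exp(-dt / lengthscale)
        m_pred = A * m
        P_pred = A * P * A + variance * (1.0 - A**2)
        # Update with y = x + noise, noise standard deviation sigma_eff.
        S = P_pred + sigma_eff**2          # innovation variance
        K = P_pred / S                     # Kalman gain
        r = y - m_pred                     # innovation
        ll_i = -0.5 * (jnp.log(2.0 * jnp.pi * S) + r**2 / S)
        # Masked steps keep the predicted state and add nothing to ll.
        m_new = jnp.where(observed, m_pred + K * r, m_pred)
        P_new = jnp.where(observed, (1.0 - K) * P_pred, P_pred)
        ll_new = ll + jnp.where(observed, ll_i, 0.0)
        return (m_new, P_new, ll_new), None

    # Zero prior mean (centred targets) and stationary prior variance.
    init = (jnp.zeros((), dtype), jnp.full((), variance, dtype),
            jnp.zeros((), dtype))
    (_, _, ll), _ = jax.lax.scan(
        step, init, (centred_targets, time_steps, is_observed))
    return ll
```

Because the OU transition composes exactly across skipped steps, masking an observation gives the same value as dropping that point from the dense GP log marginal likelihood, which makes a convenient correctness check.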
rts_smoother
Square-root RTS smoother.
Runs the standard Särkkä & Solin (2019) §10.7 backward recursion on the
forward filter trajectory. Internally it materialises P = L @ Lᵀ for the
smoother gain computation and re-roots the smoothed covariance after each
backward step with :func:`gpjax.state_space.sde._psd_sqrt`. The returned
Ls are therefore non-triangular V·Λ^½ square roots (neither Cholesky
nor symmetric); consumers must rely only on L @ Lᵀ. The recursion is:
.. math::
G_i &= P^{\text{filt}}_i A_{i+1}^\top (P^{\text{pred}}_{i+1})^{-1} \\
m^{\text{smooth}}_i &= m^{\text{filt}}_i + G_i (m^{\text{smooth}}_{i+1} - m^{\text{pred}}_{i+1}) \\
P^{\text{smooth}}_i &= P^{\text{filt}}_i + G_i (P^{\text{smooth}}_{i+1} - P^{\text{pred}}_{i+1}) G_i^\top
The last step has no future, so its smoothed state equals its filtered state.
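The backward recursion can be sketched for a scalar Ornstein-Uhlenbeck SDE (kernel variance·exp(-|τ|/lengthscale)), assumed here only so that G_i, the smoothed mean and the smoothed variance reduce to scalar arithmetic. ou_filter and ou_rts are hypothetical names; they work with plain variances instead of square roots, and ou_rts uses time_steps[i+1] for the transition from index i to i + 1, matching the time_steps convention of rts_smoother.

```python
import jax
import jax.numpy as jnp

# float64 for a tight check against the dense GP posterior (global setting).
jax.config.update("jax_enable_x64", True)


def ou_filter(y, time_steps, noise_var, variance, lengthscale):
    """Hypothetical covariance-form filter for a scalar OU SDE (all observed).

    Returns (means_updated, P_updated, means_predicted, P_predicted).
    """
    def step(carry, inputs):
        m, P = carry
        yi, dt = inputs
        A = jnp.exp(-dt / lengthscale)
        m_p = A * m
        P_p = A * P * A + variance * (1.0 - A**2)
        K = P_p / (P_p + noise_var)
        m_f = m_p + K * (yi - m_p)
        P_f = (1.0 - K) * P_p
        return (m_f, P_f), (m_f, P_f, m_p, P_p)

    init = (jnp.zeros((), y.dtype), jnp.full((), variance, y.dtype))
    _, outputs = jax.lax.scan(step, init, (y, time_steps))
    return outputs


def ou_rts(forward_outputs, time_steps, lengthscale):
    """Hypothetical scalar RTS backward pass over the filter outputs."""
    m_f, P_f, m_p, P_p = forward_outputs

    def back(carry, inputs):
        m_s_next, P_s_next = carry
        mf, Pf, mp_next, Pp_next, dt_next = inputs
        A = jnp.exp(-dt_next / lengthscale)   # transition i -> i + 1
        G = Pf * A / Pp_next                  # smoother gain G_i
        m_s = mf + G * (m_s_next - mp_next)
        P_s = Pf + G * (P_s_next - Pp_next) * G
        return (m_s, P_s), (m_s, P_s)

    # The last step has no future: its smoothed state is its filtered state.
    last = (m_f[-1], P_f[-1])
    xs = (m_f[:-1], P_f[:-1], m_p[1:], P_p[1:], time_steps[1:])
    # reverse=True walks i = N-2, ..., 0 but stores outputs at index i.
    _, (m_s, P_s) = jax.lax.scan(back, last, xs, reverse=True)
    return jnp.append(m_s, m_f[-1]), jnp.append(P_s, P_f[-1])
```

For this model the smoothed marginals at the training inputs coincide with the dense GP posterior marginals, which gives a simple way to validate the recursion.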
Parameters
sde : LinearSDE
State-space SDE used in the forward pass; sde.discretise(dt) is
called once per backward step.
forward_outputs : tuple
Quadruple (means_updated, Ls_updated, means_predicted, Ls_predicted)
as returned by :func:`_sqrt_filter_forward`.
time_steps : Float[Array, "num_train"]
Same time_steps that drove the forward pass; time_steps[i+1] is
the inter-step dt between filtered index i and predicted index
i + 1.
Returns
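smoothed_means : Float[Array, "num_train state_dim"]
Smoothed state means.
smoothed_Ls : Float[Array, "num_train state_dim state_dim"]
Square roots of the smoothed state covariances (V·Λ^½ form, not
triangular); only L @ Lᵀ is meaningful.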
See plans/2026-04-21-state-space-gps-design.md §Stage 3.
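The V·Λ^½ re-rooting described for rts_smoother can be sketched with an eigendecomposition. psd_sqrt is a hypothetical stand-in; the actual gpjax.state_space.sde._psd_sqrt may differ in details such as symmetrisation or eigenvalue clipping.

```python
import jax.numpy as jnp


def psd_sqrt(P):
    """Hypothetical V @ diag(sqrt(lambda)) square root with L @ L.T = P."""
    P = 0.5 * (P + P.T)                    # symmetrise away round-off
    eigvals, eigvecs = jnp.linalg.eigh(P)
    eigvals = jnp.maximum(eigvals, 0.0)    # clip tiny negative eigenvalues
    return eigvecs * jnp.sqrt(eigvals)     # scales column k of V by sqrt(lambda_k)
```

The result is generally neither lower-triangular nor symmetric, which is why consumers of smoothed_Ls should only form L @ Lᵀ. Clipping the eigenvalues also lets it handle singular (positive semi-definite) covariances without producing NaNs.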