Orthogonal Instantaneous Linear Mixing Model (OILMM) for multi-output GPs.

OILMM achieves O(n³m) complexity instead of O(n³m³) by constraining the mixing
matrix to have orthogonal columns, which causes the projected noise to be
diagonal and enables inference to decompose into m independent single-output
GP problems.
Reference:
    Bruinsma et al. (2020). "Scalable Exact Inference in Multi-Output Gaussian
    Processes." ICML.
Mixing matrix H = U S^(1/2) with orthogonal columns.

Parameterizes an orthogonal mixing matrix for OILMM where:

- U ∈ ℝ^(p×m) has orthonormal columns (U^T U = I_m)
- S > 0 is a diagonal scaling matrix (m × m)
- H = U S^(1/2) is the mixing matrix
- T = S^(-1/2) U^T is the projection matrix

The orthogonality of U ensures that the projected noise is diagonal:

    Σ_T = T Σ T^T = σ² S^(-1) + D

where σ² is observation noise and D is latent noise.
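A minimal standalone sketch checking both identities above (illustrative sizes
and values; QR is used here only to obtain orthonormal columns, whereas the
class itself orthogonalizes via SVD):

    import jax.numpy as jnp
    import jax.random as jr

    p, m = 5, 2                                          # illustrative dimensions
    U, _ = jnp.linalg.qr(jr.normal(jr.key(0), (p, m)))   # orthonormal columns
    S = jnp.array([2.0, 0.5])                            # positive diagonal scaling
    sigma2, D = 0.1, jnp.array([0.01, 0.02])             # obs noise σ², latent noise D

    H = U * jnp.sqrt(S)     # H = U S^(1/2): scales column j by sqrt(S_j)
    T = (U / jnp.sqrt(S)).T # T = S^(-1/2) U^T

    # T is the left pseudo-inverse of H.
    assert jnp.allclose(T @ H, jnp.eye(m), atol=1e-6)

    # Projected noise Σ_T = T Σ T^T with Σ = σ²I + H diag(D) H^T is diagonal.
    Sigma = sigma2 * jnp.eye(p) + H @ jnp.diag(D) @ H.T
    Sigma_T = T @ Sigma @ T.T
    assert jnp.allclose(Sigma_T, jnp.diag(sigma2 / S + D), atol=1e-6)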
Attributes:
    num_outputs – Number of output dimensions (p)
    num_latent_gps – Number of latent GP functions (m)
    U_latent – Unconstrained matrix for SVD orthogonalization
    S – Positive diagonal scaling
    obs_noise_variance – Homogeneous observation noise (σ²)
    latent_noise_variance – Per-latent heterogeneous noise (D), non-negative
Parameters:
    num_outputs (int) – Number of output dimensions (p)
    num_latent_gps (int) – Number of latent GPs (m), must satisfy m ≤ p
    key (Array) – JAX PRNG key for initialization
property U: Float[Array, 'P M']
    Orthonormal columns via SVD.

    Uses SVD to project U_latent onto the Stiefel manifold (orthonormal columns).
    This ensures U^T U = I_m exactly.
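A sketch of that projection (the standard SVD/polar construction, shown as a
simplified stand-in for the property rather than the GPJax source):

    import jax.numpy as jnp

    def orthogonalize(U_latent):
        # SVD: U_latent = W diag(s) V^T with orthonormal W and V.
        W, _, Vt = jnp.linalg.svd(U_latent, full_matrices=False)
        # Discarding the singular values yields the closest matrix with
        # orthonormal columns, so (W @ Vt)^T (W @ Vt) = I_m exactly.
        return W @ Vt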
property sqrt_S: Float[Array, ' M']
    Square root of S diagonal: S^(1/2).

property inv_sqrt_S: Float[Array, ' M']
    Inverse square root of S diagonal: S^(-1/2).
property H: Float[Array, 'P M']
    Mixing matrix H = U S^(1/2).

    Maps from latent space (m dimensions) to output space (p dimensions).
    Each column is an orthogonal basis vector scaled by sqrt(S_i).
property T: Float[Array, 'M P']
    Projection matrix T = S^(-1/2) U^T.

    Projects from output space (p dimensions) to latent space (m dimensions).
    This is the left pseudo-inverse of H: T @ H = I_m.
property H_squared: Float[Array, 'P M']
    Element-wise H² for fast diagonal variance reconstruction.

    When computing marginal variances, we need H² @ latent_vars:

        var_p = sum_m H²_pm * var_m

    This property caches H² to avoid recomputation.
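A small illustration with made-up values, showing that the reconstruction is a
single matrix-vector product:

    import jax.numpy as jnp

    H = jnp.array([[0.8, 0.1], [0.2, 0.9], [0.5, 0.4]])  # (p, m) mixing matrix
    latent_vars = jnp.array([0.3, 0.05])   # per-latent predictive variances

    H_squared = H**2                       # cached element-wise square
    output_vars = H_squared @ latent_vars  # var_p = sum_m H²_pm * var_m, shape (p,)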
property projected_noise_variance: Float[Array, ' M']
    Diagonal projected noise: Σ_T = σ² S^(-1) + D.

    This is the noise variance for each independent latent GP after projection.
    The orthogonality of U ensures this is diagonal, which is what makes
    OILMM tractable.

    Returns:
        Float[Array, ' M'] – Array of shape [M] with noise variance for each latent GP.
OILMM decomposes multi-output GP inference into M independent single-output
GP problems by using an orthogonal mixing matrix. This achieves O(n³m)
complexity instead of O(n³m³).

The generative model is

    x_i ~ GP(0, K(t, t'))   for i = 1..M    (latent GPs)
    f(t) = H x(t)                           (mixing)
    y | f ~ N(f(t), Σ)                      (noise: Σ = σ²I + H D H^T)

The orthogonality constraint (U^T U = I) ensures the projected noise is diagonal:

    Σ_T = T Σ T^T = σ² S^(-1) + D

enabling independent inference for each latent GP.
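A standalone sketch of sampling from this generative model (an assumed RBF
latent kernel and illustrative sizes, for intuition rather than library use):

    import jax.numpy as jnp
    import jax.random as jr

    n, p, m = 50, 3, 2
    t = jnp.linspace(0.0, 1.0, n)
    k1, k2, k3, k4 = jr.split(jr.key(0), 4)

    # x_i ~ GP(0, K): draw m latent functions from an RBF kernel.
    K = jnp.exp(-0.5 * ((t[:, None] - t[None, :]) / 0.1) ** 2) + 1e-6 * jnp.eye(n)
    x = jnp.linalg.cholesky(K) @ jr.normal(k1, (n, m))   # (n, m)

    # f(t) = H x(t): mix with an orthogonal H = U S^(1/2).
    U, _ = jnp.linalg.qr(jr.normal(k2, (p, m)))
    H = U * jnp.sqrt(jnp.array([2.0, 0.5]))
    f = x @ H.T                                          # (n, p)

    # y | f ~ N(f, Σ), Σ = σ²I + H D H^T: latent plus observation noise.
    D, sigma2 = jnp.array([0.01, 0.02]), 0.1
    y = f + (jnp.sqrt(D) * jr.normal(k3, (n, m))) @ H.T \
          + jnp.sqrt(sigma2) * jr.normal(k4, (n, p))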
Kernel for the latent GPs. If a single kernel is given, it is deep-copied
M times so that each latent GP has independent hyperparameters. If a
list of M kernels is given, each is used directly.
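Conceptually, the kernel handling behaves like this sketch (expand_kernels is
a hypothetical helper, not part of the API):

    import copy

    def expand_kernels(kernel_or_list, m):
        # A list of m kernels is used as-is; a single kernel is deep-copied
        # so each latent GP gets its own independent hyperparameters.
        if isinstance(kernel_or_list, (list, tuple)):
            assert len(kernel_or_list) == m
            return list(kernel_or_list)
        return [copy.deepcopy(kernel_or_list) for _ in range(m)]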
This implements the core OILMM inference algorithm:
1. Project observations: y_latent = T @ y
2. Condition M independent GPs on projected data
3. Return OILMMPosterior wrapping the M posteriors
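A self-contained schematic of steps 1-2, with a plain RBF regression standing
in for GPJax's ConjugatePosterior objects (all names and the kernel choice are
illustrative):

    import jax.numpy as jnp

    def rbf(x1, x2, lengthscale=0.2):
        # Simple RBF gram matrix between (n, 1) input arrays.
        d = x1[:, None, 0] - x2[None, :, 0]
        return jnp.exp(-0.5 * (d / lengthscale) ** 2)

    def oilmm_condition(T, X, Y, noise_vars, Xs):
        # Step 1: project observations, y_latent = T @ y row-wise: (n, p) -> (n, m).
        Y_latent = Y @ T.T

        # Step 2: condition m independent single-output GPs, each with its
        # own diagonal projected noise variance.
        K, Ks = rbf(X, X), rbf(Xs, X)
        means = []
        for i in range(Y_latent.shape[1]):
            Kn = K + noise_vars[i] * jnp.eye(X.shape[0])
            means.append(Ks @ jnp.linalg.solve(Kn, Y_latent[:, i]))

        # Step 3 would wrap these posteriors and map predictions back with H.
        return jnp.stack(means, axis=-1)   # (n_test, m) latent predictive means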
Wraps M independent ConjugatePosterior objects and provides a unified
predict() interface that reconstructs predictions in output space.
This is a plain class (not nnx.Module) because it holds Dataset objects
which are not JAX pytree nodes. The latent posteriors and mixing matrix
are still nnx.Modules and participate in JAX transformations when accessed.
Attributes:
    latent_posteriors – Tuple of M independent ConjugatePosterior objects
    latent_datasets – Tuple of M projected training Datasets (one per latent GP)
    mixing_matrix – OrthogonalMixingMatrix for reconstruction
    num_latent_gps – Number of latent GPs (m)
Parameters:
    latent_posteriors (tuple) – Tuple of M ConjugatePosterior objects
    latent_datasets (tuple) – Tuple of M Dataset objects (projected training data)
>>> import gpjax as gpx
>>> import jax.random as jr
>>> model = gpx.models.create_oilmm_with_kernels(
...     latent_kernels=[gpx.kernels.RBF(), gpx.kernels.Matern52()],
...     num_outputs=6,
...     key=jr.key(42)
... )
Create OILMM with data-informed initialization of mixing matrix.
Initializes U to the top M eigenvectors and S to the top M eigenvalues of
the empirical covariance matrix of the outputs. Near-zero eigenvalues are
clamped to 1e-6 for numerical stability. This can provide better
initialization than random, especially when outputs have clear correlation
structure.
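The initialization amounts to a truncated eigendecomposition of the empirical
covariance; a minimal sketch (init_from_data is hypothetical and mirrors the
description above, not the actual function body):

    import jax.numpy as jnp

    def init_from_data(Y, m, eps=1e-6):
        Yc = Y - Y.mean(axis=0)                 # center the (n, p) outputs
        cov = Yc.T @ Yc / Y.shape[0]            # empirical covariance, (p, p)
        evals, evecs = jnp.linalg.eigh(cov)     # eigenvalues in ascending order
        U = evecs[:, -m:][:, ::-1]              # top-m eigenvectors as columns
        S = jnp.maximum(evals[-m:][::-1], eps)  # top-m eigenvalues, clamped at 1e-6
        return U, S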
>>> import gpjax as gpx
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> X = jnp.linspace(0, 1, 50).reshape(-1, 1)
>>> y = jnp.column_stack([jnp.sin(X), jnp.cos(X)])
>>> data = gpx.Dataset(X=X, y=y)
>>> model = gpx.models.create_oilmm_from_data(
...     dataset=data,
...     num_latent_gps=2,
...     key=jr.key(42)
... )