OILMM
Orthogonal Instantaneous Linear Mixing Model (OILMM) for multi-output GPs.
OILMM achieves O(n³m) complexity instead of O(n³m³) by constraining the mixing matrix to have orthogonal columns, which causes the projected noise to be diagonal and enables inference to decompose into m independent single-output GP problems.
Reference
Bruinsma et al. (2020). "Scalable Exact Inference in Multi-Output Gaussian Processes." ICML.
OrthogonalMixingMatrix
Bases: Module
Mixing matrix H = U S^(1/2) with orthogonal columns.
Parameterizes an orthogonal mixing matrix for OILMM where:

- U ∈ ℝ^(p×m) has orthonormal columns (U^T U = I_m)
- S > 0 is a diagonal scaling matrix (m × m)
- H = U S^(1/2) is the mixing matrix
- T = S^(-1/2) U^T is the projection matrix
The orthogonality of U ensures that the projected noise is diagonal
Σ_T = T Σ T^T = σ²S^(-1) + D
where σ² is observation noise and D is latent noise.
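This identity is easy to verify numerically. The NumPy sketch below uses made-up toy values (it is an illustration of the math, not the library's implementation): it builds H and T from an orthonormal U and a diagonal S, then checks that the projected noise T Σ T^T is diagonal and equals σ²S^(-1) + D.

```python
import numpy as np

rng = np.random.default_rng(0)
p, m = 5, 2  # outputs, latents

# Orthonormal U via QR of a random matrix (U^T U = I_m).
U, _ = np.linalg.qr(rng.standard_normal((p, m)))
S = np.diag([2.0, 0.5])                # positive diagonal scaling
H = U @ np.sqrt(S)                     # mixing matrix H = U S^(1/2)
T = np.linalg.inv(np.sqrt(S)) @ U.T    # projection T = S^(-1/2) U^T

# T is the left pseudo-inverse of H: T @ H = I_m.
assert np.allclose(T @ H, np.eye(m))

# Output-space noise: Sigma = sigma^2 I + H D H^T.
sigma2 = 0.1
D = np.diag([0.05, 0.2])
Sigma = sigma2 * np.eye(p) + H @ D @ H.T

# Projected noise Sigma_T = T Sigma T^T = sigma^2 S^(-1) + D (diagonal).
Sigma_T = T @ Sigma @ T.T
assert np.allclose(Sigma_T, sigma2 * np.linalg.inv(S) + D)
```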
Attributes:

- num_outputs – Number of output dimensions (p)
- num_latent_gps – Number of latent GP functions (m)
- U_latent – Unconstrained matrix for SVD orthogonalization
- S – Positive diagonal scaling
- obs_noise_variance – Homogeneous observation noise (σ²)
- latent_noise_variance – Per-latent heterogeneous noise (D), non-negative
Parameters:

- num_outputs (int) – Number of output dimensions (p)
- num_latent_gps (int) – Number of latent GPs (m), must satisfy m ≤ p
- key (Array) – JAX PRNG key for initialization
U
property
Orthonormal columns via SVD.
Uses SVD to project U_latent onto the Stiefel manifold (orthonormal columns). This ensures U^T U = I_m exactly.
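A minimal NumPy illustration of the SVD orthogonalization (the unconstrained `U_latent` here is a made-up random matrix): taking the thin SVD and dropping the singular values yields the closest matrix with orthonormal columns.

```python
import numpy as np

rng = np.random.default_rng(1)
p, m = 6, 3
U_latent = rng.standard_normal((p, m))  # unconstrained parameter

# Project onto the Stiefel manifold: keep the singular vectors,
# discard the singular values.
u, _, vt = np.linalg.svd(U_latent, full_matrices=False)
U = u @ vt

# Columns are now exactly orthonormal: U^T U = I_m.
assert np.allclose(U.T @ U, np.eye(m))
```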
H
property
Mixing matrix H = U S^(1/2).
Maps from latent space (m dimensions) to output space (p dimensions). Each column is an orthogonal basis vector scaled by sqrt(S_i).
T
property
Projection matrix T = S^(-1/2) U^T.
Projects from output space (p dimensions) to latent space (m dimensions). This is the left pseudo-inverse of H: T @ H = I_m.
H_squared
property
Element-wise H² for fast diagonal variance reconstruction.
When computing marginal variances, we need H² @ latent_vars:

var_p = Σ_m H²_pm · var_m

This property caches H² to avoid recomputation.
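The equivalence behind this shortcut can be checked with a short NumPy sketch (random toy values): the element-wise square of H applied to the per-latent variances matches the diagonal of the full reconstruction H diag(var) H^T.

```python
import numpy as np

rng = np.random.default_rng(2)
p, m = 4, 2
H = rng.standard_normal((p, m))
latent_vars = np.array([0.3, 1.2])  # per-latent marginal variances

# Fast path: var_p = sum_m H_pm^2 * var_m, i.e. (H**2) @ latent_vars.
fast = (H ** 2) @ latent_vars

# Equivalent slow path: diagonal of H diag(var) H^T.
slow = np.diag(H @ np.diag(latent_vars) @ H.T)
assert np.allclose(fast, slow)
```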
projected_noise_variance
property
Diagonal projected noise: Σ_T = σ²S^(-1) + D.
This is the noise variance for each independent latent GP after projection. The orthogonality of U ensures this is diagonal, which is what makes OILMM tractable.
Returns:

- Float[Array, ' M'] – Array of shape [M] with noise variance for each latent GP.
OILMMModel
OILMMModel(
num_outputs: int,
num_latent_gps: int,
kernel: AbstractKernel | list[AbstractKernel],
key: Array,
mean_function: Any = None,
)
Bases: Module
Orthogonal Instantaneous Linear Mixing Model.
OILMM decomposes multi-output GP inference into M independent single-output GP problems by using an orthogonal mixing matrix. This achieves O(n³m) complexity instead of O(n³m³).
The generative model is
x_i ~ GP(0, K(t, t'))   for i = 1..M   (latent GPs)
f(t) = H x(t)                           (mixing)
y | f ~ N(f(t), Σ)                      (noise: Σ = σ²I + H D H^T)
The orthogonality constraint (U^T U = I) ensures the projected noise is diagonal: Σ_T = T Σ T^T = σ²S^(-1) + D enabling independent inference for each latent GP.
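The generative model can be simulated directly. The NumPy sketch below uses made-up parameters (an RBF kernel stands in for K, and the noise values are arbitrary) to draw a multi-output sample by mixing independent latent GP draws:

```python
import numpy as np

rng = np.random.default_rng(3)
n, p, m = 50, 5, 2
t = np.linspace(0, 1, n)

# RBF kernel for the latent GPs (lengthscale 0.2), with jitter.
K = np.exp(-0.5 * (t[:, None] - t[None, :]) ** 2 / 0.2 ** 2) + 1e-6 * np.eye(n)
L = np.linalg.cholesky(K)
x = L @ rng.standard_normal((n, m))    # m independent latent draws, [n, m]

# Orthogonal mixing H = U S^(1/2).
U, _ = np.linalg.qr(rng.standard_normal((p, m)))
H = U @ np.diag(np.sqrt([1.5, 0.7]))
f = x @ H.T                            # mix into output space, [n, p]

# Observation noise Sigma = sigma^2 I + H D H^T, applied per time point.
sigma2, D = 0.05, np.diag([0.01, 0.02])
noise_cov = sigma2 * np.eye(p) + H @ D @ H.T
y = f + rng.multivariate_normal(np.zeros(p), noise_cov, size=n)
```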
Attributes:

- num_outputs – Number of output dimensions (p)
- num_latent_gps – Number of latent GPs (m)
- mixing_matrix – OrthogonalMixingMatrix containing H, T, noise params
- latent_priors – Tuple of M independent Prior objects
Parameters:

- num_outputs (int) – Number of output dimensions (p)
- num_latent_gps (int) – Number of latent GPs (m), must satisfy m ≤ p
- kernel (AbstractKernel | list[AbstractKernel]) – Kernel for latent GPs. If a single kernel, it is deep-copied M times so each latent GP has independent hyperparameters. If a list of M kernels, each is used directly.
- key (Array) – JAX PRNG key
- mean_function (Any, default: None) – Mean function for latent GPs (default: Zero)
condition_on_observations
Condition on observations to create posterior.
This implements the core OILMM inference algorithm:

1. Project observations: y_latent = T @ y
2. Condition M independent GPs on projected data
3. Return OILMMPosterior wrapping the M posteriors
Parameters:

- dataset (Dataset) – Training data with X [N, D] and y [N, P]

Returns:

- OILMMPosterior – OILMMPosterior containing M independent posteriors
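The projection in step 1 can be sketched in NumPy (toy shapes and values; steps 2-3 are then handled by ordinary single-output GP conditioning and are not shown here):

```python
import numpy as np

rng = np.random.default_rng(4)
n, p, m = 30, 4, 2

# A toy orthogonal mixing matrix, as in OrthogonalMixingMatrix.
U, _ = np.linalg.qr(rng.standard_normal((p, m)))
S = np.diag([2.0, 0.8])
T = np.linalg.inv(np.sqrt(S)) @ U.T    # projection matrix, [m, p]

y = rng.standard_normal((n, p))        # observations, [n, p]

# Step 1: project observations into latent space.
# Column i of y_latent becomes the training targets for latent GP i,
# so each GP solves an O(n^3) problem instead of one O(n^3 m^3) solve.
y_latent = y @ T.T                     # [n, m]
```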
OILMMPosterior
OILMMPosterior(
latent_posteriors: tuple,
latent_datasets: tuple,
mixing_matrix: OrthogonalMixingMatrix,
)
Posterior distribution for OILMM.
Wraps M independent ConjugatePosterior objects and provides a unified predict() interface that reconstructs predictions in output space.
This is a plain class (not nnx.Module) because it holds Dataset objects which are not JAX pytree nodes. The latent posteriors and mixing matrix are still nnx.Modules and participate in JAX transformations when accessed.
Attributes:

- latent_posteriors – Tuple of M independent ConjugatePosterior objects
- latent_datasets – Tuple of M projected training Datasets (one per latent GP)
- mixing_matrix – OrthogonalMixingMatrix for reconstruction
- num_latent_gps – Number of latent GPs (m)

Parameters:

- latent_posteriors (tuple) – Tuple of M ConjugatePosterior objects
- latent_datasets (tuple) – Tuple of M Dataset objects (projected training data)
- mixing_matrix (OrthogonalMixingMatrix) – OrthogonalMixingMatrix containing H, T
predict
Predict at test locations.
Reconstructs predictions in output space from M independent latent posteriors:

1. Predict each latent GP independently
2. Reconstruct mean: f_mean = H @ latent_means
3. Reconstruct covariance: Σ_f = (H ⊗ I) Σ_x (H ⊗ I)^T
Parameters:

- test_inputs (Float[Array, 'N D']) – Test input locations [N, D]
- return_full_cov (bool, default: True) – If True, return the full [NP, NP] covariance. If False, return a diagonal covariance matrix.

Returns:

- GaussianDistribution – GaussianDistribution with:
    - loc: [NP], flattened output-major
    - scale: Dense [NP, NP] covariance (full or diagonal)
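The mean and marginal-variance reconstruction can be sketched in NumPy (random stand-ins for the latent predictive means and variances; the full-covariance path via the Kronecker product is omitted):

```python
import numpy as np

rng = np.random.default_rng(5)
n, p, m = 10, 3, 2
U, _ = np.linalg.qr(rng.standard_normal((p, m)))
H = U @ np.diag(np.sqrt([1.5, 0.4]))   # mixing matrix, [p, m]

# Independent latent predictions: means and per-latent variances, [n, m].
latent_means = rng.standard_normal((n, m))
latent_vars = rng.uniform(0.1, 1.0, size=(n, m))

# Step 2: output-space mean, f_mean = H @ latent_means per time point.
f_mean = latent_means @ H.T            # [n, p]

# Diagonal of step 3: marginal variances via var_p = (H**2) @ latent_vars.
f_var = latent_vars @ (H ** 2).T       # [n, p]
```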
oilmm_mll
Log marginal likelihood for the OILMM.
Implements Prop. 9 from Bruinsma et al. (2020):
log p(Y) = correction_terms + Σᵢ log N((TY)ᵢ | 0, Kᵢ + noise_i Iₙ)
The correction terms prevent the projection from collapsing and account for data in the (p - m) dimensions orthogonal to the mixing matrix.
Parameters:

- model (OILMMModel) – OILMMModel with parameters to evaluate.
- data (Dataset) – Training data with X [N, D] and y [N, P].

Returns:

- ScalarFloat – Scalar log marginal likelihood.
create_oilmm
create_oilmm(
num_outputs: int,
num_latent_gps: int,
key: Array,
kernel: AbstractKernel
| list[AbstractKernel]
| None = None,
mean_function: Any = None,
) -> OILMMModel
Create OILMM model with shared kernel across latents.
Parameters:

- num_outputs (int) – Number of output dimensions (p)
- num_latent_gps (int) – Number of latent GPs (m)
- key (Array) – JAX PRNG key
- kernel (AbstractKernel | list[AbstractKernel] | None, default: None) – Kernel for latent GPs (default: RBF)
- mean_function (Any, default: None) – Mean function for latent GPs (default: Zero)

Returns:

- OILMMModel – Initialized OILMMModel
Example
>>> import gpjax as gpx
>>> import jax.random as jr
>>> model = gpx.models.create_oilmm(
...     num_outputs=5,
...     num_latent_gps=2,
...     key=jr.key(42),
...     kernel=gpx.kernels.Matern52()
... )
create_oilmm_with_kernels
create_oilmm_with_kernels(
latent_kernels: list[AbstractKernel],
num_outputs: int,
key: Array,
mean_function: Any = None,
) -> OILMMModel
Create OILMM with custom kernel per latent GP.
Parameters:

- latent_kernels (list[AbstractKernel]) – List of M kernels, one per latent GP
- num_outputs (int) – Number of output dimensions (p)
- key (Array) – JAX PRNG key
- mean_function (Any, default: None) – Mean function (shared, default: Zero)

Returns:

- OILMMModel – OILMMModel with heterogeneous latent kernels
Example
>>> import gpjax as gpx
>>> import jax.random as jr
>>> model = gpx.models.create_oilmm_with_kernels(
...     latent_kernels=[gpx.kernels.RBF(), gpx.kernels.Matern52()],
...     num_outputs=6,
...     key=jr.key(42)
... )
create_oilmm_from_data
create_oilmm_from_data(
    dataset: Dataset,
    num_latent_gps: int,
    key: Array,
    kernel: AbstractKernel | None = None,
    mean_function: Any = None,
) -> OILMMModel
Create OILMM with data-informed initialization of mixing matrix.
Initializes U to the top M eigenvectors and S to the top M eigenvalues of the empirical covariance matrix of the outputs. Near-zero eigenvalues are clamped to 1e-6 for numerical stability. This can provide better initialization than random, especially when outputs have clear correlation structure.
Parameters:

- dataset (Dataset) – Training data with y [N, P]
- num_latent_gps (int) – Number of latent GPs (m)
- key (Array) – JAX PRNG key
- kernel (AbstractKernel | None, default: None) – Kernel for latent GPs (default: RBF)
- mean_function (Any, default: None) – Mean function (default: Zero)

Returns:

- OILMMModel – OILMMModel with U initialized to the top M eigenvectors and S to the top M eigenvalues
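The data-informed initialization can be sketched in NumPy (toy data; this mirrors the description above, not the exact library code): eigendecompose the empirical output covariance, keep the top m eigenpairs, and clamp near-zero eigenvalues.

```python
import numpy as np

rng = np.random.default_rng(6)
n, p, m = 100, 5, 2
# Toy observations with correlated outputs, [n, p].
y = rng.standard_normal((n, p)) @ rng.standard_normal((p, p))

# Empirical output covariance and its eigendecomposition.
cov = np.cov(y, rowvar=False)
eigvals, eigvecs = np.linalg.eigh(cov)     # ascending order

# Keep the top m eigenpairs; clamp near-zero eigenvalues for stability.
top = np.argsort(eigvals)[::-1][:m]
U_init = eigvecs[:, top]                   # initial U, [p, m]
S_init = np.maximum(eigvals[top], 1e-6)    # initial diag(S), [m]

# Eigenvectors of a symmetric matrix are orthonormal, so U is valid as-is.
assert np.allclose(U_init.T @ U_init, np.eye(m))
```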
Example
>>> import gpjax as gpx
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> X = jnp.linspace(0, 1, 50).reshape(-1, 1)
>>> y = jnp.column_stack([jnp.sin(X), jnp.cos(X)])
>>> data = gpx.Dataset(X=X, y=y)
>>> model = gpx.models.create_oilmm_from_data(
...     dataset=data,
...     num_latent_gps=2,
...     key=jr.key(42)
... )