Oilmm

Orthogonal Instantaneous Linear Mixing Model (OILMM) for multi-output GPs.

OILMM achieves O(n³m) complexity instead of O(n³m³) by constraining the mixing matrix to have orthogonal columns, which causes the projected noise to be diagonal and enables inference to decompose into m independent single-output GP problems.

Reference

Bruinsma et al. (2020). "Scalable Exact Inference in Multi-Output Gaussian Processes." ICML.

OrthogonalMixingMatrix

OrthogonalMixingMatrix(
    num_outputs: int, num_latent_gps: int, key: Array
)

Bases: Module

Mixing matrix H = U S^(1/2) with orthogonal columns.

Parameterizes an orthogonal mixing matrix for OILMM, where:

  • U ∈ ℝ^(p×m) has orthonormal columns (U^T U = I_m)
  • S > 0 is a diagonal m × m scaling matrix
  • H = U S^(1/2) is the mixing matrix
  • T = S^(-1/2) U^T is the projection matrix

The orthogonality of U ensures that the projected noise is diagonal:

Σ_T = T Σ T^T = σ²S^(-1) + D

where σ² is the observation noise and D is the latent noise.

Attributes:

  • num_outputs –

    Number of output dimensions (p)

  • num_latent_gps –

    Number of latent GP functions (m)

  • U_latent –

    Unconstrained matrix for SVD orthogonalization

  • S –

    Positive diagonal scaling

  • obs_noise_variance –

    Homogeneous observation noise (σ²)

  • latent_noise_variance –

    Per-latent heterogeneous noise (D), non-negative

Parameters:

  • num_outputs (int) –

    Number of output dimensions (p)

  • num_latent_gps (int) –

Number of latent GPs (m), must satisfy m ≤ p

  • key (Array) –

    JAX PRNG key for initialization

U property

U: Float[Array, 'P M']

Orthonormal columns via SVD.

Uses SVD to project U_latent onto the Stiefel manifold (orthonormal columns). This ensures U^T U = I_m exactly.
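The SVD projection can be sketched as follows — an illustrative snippet, not the class's actual implementation; the `orthogonalize` helper and toy shapes are hypothetical:

```python
import jax
import jax.numpy as jnp

def orthogonalize(A):
    # Nearest matrix with orthonormal columns (orthogonal Procrustes):
    # discard the singular values and keep the orthogonal factors.
    u, _, vt = jnp.linalg.svd(A, full_matrices=False)
    return u @ vt

# Unconstrained parameter [p, m], here with p=5, m=2.
U_latent = jax.random.normal(jax.random.key(0), (5, 2))
U = orthogonalize(U_latent)

print(jnp.allclose(U.T @ U, jnp.eye(2), atol=1e-5))  # orthonormal columns
```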

sqrt_S property

sqrt_S: Float[Array, ' M']

Square root of S diagonal: S^(1/2).

inv_sqrt_S property

inv_sqrt_S: Float[Array, ' M']

Inverse square root of S diagonal: S^(-1/2).

H property

H: Float[Array, 'P M']

Mixing matrix H = U S^(1/2).

Maps from latent space (m dimensions) to output space (p dimensions). Each column is an orthogonal basis vector scaled by sqrt(S_i).

T property

T: Float[Array, 'M P']

Projection matrix T = S^(-1/2) U^T.

Projects from output space (p dimensions) to latent space (m dimensions). This is the left pseudo-inverse of H: T @ H = I_m.
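The left-pseudo-inverse identity T @ H = I_m follows directly from the parameterization. A minimal numerical check, using toy sizes and QR in place of the class's SVD orthogonalization:

```python
import jax
import jax.numpy as jnp

p, m = 5, 2
# Orthonormal U via QR (any point on the Stiefel manifold works here).
U, _ = jnp.linalg.qr(jax.random.normal(jax.random.key(0), (p, m)))
S = jnp.array([2.0, 0.5])        # positive diagonal scaling

H = U * jnp.sqrt(S)              # H = U S^(1/2), shape [p, m]
T = (U / jnp.sqrt(S)).T          # T = S^(-1/2) U^T, shape [m, p]

print(jnp.allclose(T @ H, jnp.eye(m), atol=1e-5))  # T is a left inverse of H
```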

H_squared property

H_squared: Float[Array, 'P M']

Element-wise H² for fast diagonal variance reconstruction.

When computing marginal variances we need H² @ latent_vars, i.e. var_p = Σ_m H²_pm · var_m. This property caches H² to avoid recomputation.
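A quick numerical check of this identity — a toy sketch with hypothetical shapes, not library code:

```python
import jax
import jax.numpy as jnp

H = jax.random.normal(jax.random.key(1), (4, 2))   # toy mixing matrix [p, m]
latent_vars = jnp.array([0.3, 1.2])                # marginal variance per latent GP

# Full route: diag(H diag(v) H^T).  Fast route: H² @ v.
full = jnp.diag(H @ jnp.diag(latent_vars) @ H.T)
fast = (H**2) @ latent_vars

print(jnp.allclose(full, fast, atol=1e-5))
```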

projected_noise_variance property

projected_noise_variance: Float[Array, ' M']

Diagonal projected noise: Σ_T = σ²S^(-1) + D.

This is the noise variance for each independent latent GP after projection. The orthogonality of U ensures this is diagonal, which is what makes OILMM tractable.

Returns:

  • Float[Array, ' M'] –

    Array of shape [M] with noise variance for each latent GP.
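The diagonality claim can be verified numerically — a sketch assuming toy values for σ², D, and S, not the module's actual code:

```python
import jax
import jax.numpy as jnp

p, m = 5, 2
U, _ = jnp.linalg.qr(jax.random.normal(jax.random.key(0), (p, m)))
S = jnp.array([2.0, 0.5])
H = U * jnp.sqrt(S)
T = (U / jnp.sqrt(S)).T

obs_var = 0.1                     # σ²
D = jnp.array([0.05, 0.2])        # latent noise diagonal

Sigma = obs_var * jnp.eye(p) + H @ jnp.diag(D) @ H.T   # output-space noise Σ
Sigma_T = T @ Sigma @ T.T                              # projected noise Σ_T

# Orthogonality of U makes Σ_T exactly diagonal, with entries σ²/S_i + D_i.
print(jnp.allclose(Sigma_T, jnp.diag(obs_var / S + D), atol=1e-5))
```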

OILMMModel

OILMMModel(
    num_outputs: int,
    num_latent_gps: int,
    kernel: AbstractKernel | list[AbstractKernel],
    key: Array,
    mean_function: Any = None,
)

Bases: Module

Orthogonal Instantaneous Linear Mixing Model.

OILMM decomposes multi-output GP inference into M independent single-output GP problems by using an orthogonal mixing matrix. This achieves O(n³m) complexity instead of O(n³m³).

The generative model is

x_i ~ GP(0, K(t, t'))   for i = 1..M   (latent GPs)
f(t) = H x(t)                          (mixing)
y | f ~ N(f(t), Σ)                     (noise: Σ = σ²I + HDH^T)

The orthogonality constraint (U^T U = I) ensures the projected noise is diagonal, Σ_T = T Σ T^T = σ²S^(-1) + D, enabling independent inference for each latent GP.
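The mixing step of the generative model can be sketched with hand-picked latent functions standing in for GP draws — illustrative only, with D omitted for brevity:

```python
import jax
import jax.numpy as jnp

p, m, n = 4, 2, 50
k_u, k_eps = jax.random.split(jax.random.key(0))

U, _ = jnp.linalg.qr(jax.random.normal(k_u, (p, m)))
H = U * jnp.sqrt(jnp.array([1.5, 0.7]))        # H = U S^(1/2)

t = jnp.linspace(0.0, 1.0, n)
# Deterministic stand-ins for latent GP draws, shape [n, m].
x = jnp.stack([jnp.sin(2 * jnp.pi * t), jnp.cos(2 * jnp.pi * t)], axis=-1)

f = x @ H.T                                      # mixing: f(t) = H x(t), shape [n, p]
y = f + 0.1 * jax.random.normal(k_eps, f.shape)  # add observation noise

print(y.shape)  # (50, 4)
```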

Attributes:

  • num_outputs –

    Number of output dimensions (p)

  • num_latent_gps –

    Number of latent GPs (m)

  • mixing_matrix –

    OrthogonalMixingMatrix containing H, T, noise params

  • latent_priors –

    Tuple of M independent Prior objects

Parameters:

  • num_outputs (int) –

    Number of output dimensions (p)

  • num_latent_gps (int) –

    Number of latent GPs (m), must satisfy m ≀ p

  • kernel (AbstractKernel | list[AbstractKernel]) –

    Kernel for latent GPs. If a single kernel, it is deep-copied M times so each latent GP has independent hyperparameters. If a list of M kernels, each is used directly.

  • key (Array) –

    JAX PRNG key

  • mean_function (Any, default: None ) –

    Mean function for latent GPs (default: Zero)

condition_on_observations

condition_on_observations(
    dataset: Dataset,
) -> OILMMPosterior

Condition on observations to create posterior.

This implements the core OILMM inference algorithm:

  1. Project observations: y_latent = T @ y
  2. Condition M independent GPs on the projected data
  3. Return an OILMMPosterior wrapping the M posteriors
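The projection in step 1 can be sketched as a row-wise product; with noiseless data the latents are recovered exactly. A toy example, not library code:

```python
import jax
import jax.numpy as jnp

p, m, n = 4, 2, 30
U, _ = jnp.linalg.qr(jax.random.normal(jax.random.key(0), (p, m)))
S = jnp.array([1.5, 0.7])
H = U * jnp.sqrt(S)
T = (U / jnp.sqrt(S)).T

t = jnp.linspace(0.0, 1.0, n)
x = jnp.stack([jnp.sin(t), jnp.cos(t)], axis=-1)   # true latents [n, m]
Y = x @ H.T                                         # noiseless observations [n, p]

# Step 1: project each observation row, y_latent = Y T^T.
y_latent = Y @ T.T
print(jnp.allclose(y_latent, x, atol=1e-5))  # exact recovery without noise
```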

Parameters:

  • dataset (Dataset) –

    Training data with X [N, D] and y [N, P]

Returns:

  • OILMMPosterior –

    OILMMPosterior containing M independent posteriors

OILMMPosterior

OILMMPosterior(
    latent_posteriors: tuple,
    latent_datasets: tuple,
    mixing_matrix: OrthogonalMixingMatrix,
)

Posterior distribution for OILMM.

Wraps M independent ConjugatePosterior objects and provides a unified predict() interface that reconstructs predictions in output space.

This is a plain class (not nnx.Module) because it holds Dataset objects which are not JAX pytree nodes. The latent posteriors and mixing matrix are still nnx.Modules and participate in JAX transformations when accessed.

Attributes:

  • latent_posteriors –

    Tuple of M independent ConjugatePosterior objects

  • latent_datasets –

    Tuple of M projected training Datasets (one per latent GP)

  • mixing_matrix –

    OrthogonalMixingMatrix for reconstruction

  • num_latent_gps –

    Number of latent GPs (m)

Parameters:

  • latent_posteriors (tuple) –

    Tuple of M ConjugatePosterior objects

  • latent_datasets (tuple) –

    Tuple of M Dataset objects (projected training data)

  • mixing_matrix (OrthogonalMixingMatrix) –

    OrthogonalMixingMatrix containing H, T

predict

predict(
    test_inputs: Float[Array, "N D"],
    return_full_cov: bool = True,
) -> GaussianDistribution

Predict at test locations.

Reconstructs predictions in output space from M independent latent posteriors:

  1. Predict each latent GP independently
  2. Reconstruct the mean: f_mean = H @ latent_means
  3. Reconstruct the covariance: Σ_f = (H ⊗ I) Σ_x (H ⊗ I)^T
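The mean and marginal-variance reconstructions can be sketched and checked against the direct transform — toy shapes, not the method's actual code:

```python
import jax
import jax.numpy as jnp

p, m, n = 3, 2, 4
H = jax.random.normal(jax.random.key(0), (p, m))
latent_means = jax.random.normal(jax.random.key(1), (n, m))
latent_vars = jnp.abs(jax.random.normal(jax.random.key(2), (n, m))) + 0.1

# Mean reconstruction: each row of latent means mapped through H.
f_mean = latent_means @ H.T                    # [n, p]

# Marginal variances: latents are independent, so var_p = Σ_m H²_pm var_m.
f_var = latent_vars @ (H**2).T                 # [n, p]

# Cross-check against diag(H diag(v) H^T) for each test point.
check = jnp.stack([jnp.diag(H @ jnp.diag(v) @ H.T) for v in latent_vars])
print(jnp.allclose(f_var, check, atol=1e-5))
```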

Parameters:

  • test_inputs (Float[Array, 'N D']) –

    Test input locations [N, D]

  • return_full_cov (bool, default: True ) –

If True, return the full [NP, NP] covariance. If False, return a diagonal covariance matrix.

Returns:

  • GaussianDistribution –

GaussianDistribution with:
      - loc: [NP] mean, flattened output-major
      - scale: dense [NP, NP] covariance (full or diagonal)

oilmm_mll

oilmm_mll(model: OILMMModel, data: Dataset) -> ScalarFloat

Log marginal likelihood for the OILMM.

Implements Prop. 9 from Bruinsma et al. (2020):

log p(Y) = correction_terms + Σᵢ log N((TY)ᵢ | 0, Kᵢ + noiseᵢ Iₙ)

The correction terms prevent the projection from collapsing and account for the data in the (p - m) dimensions orthogonal to the column space of the mixing matrix.
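The summation term can be sketched by projecting the data and accumulating M independent Gaussian log-likelihoods. The correction terms are omitted here, and the hand-rolled RBF kernel and `gaussian_mll` helper are illustrative stand-ins, not GPJax API:

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

def rbf(x1, x2, lengthscale=0.2):
    # Illustrative RBF kernel; the model itself uses GPJax kernel objects.
    d2 = (x1[:, None, 0] - x2[None, :, 0]) ** 2
    return jnp.exp(-0.5 * d2 / lengthscale**2)

def gaussian_mll(y, K, noise):
    # log N(y | 0, K + noise I) via a Cholesky factorization.
    n = y.shape[0]
    L = jnp.linalg.cholesky(K + noise * jnp.eye(n))
    alpha = cho_solve((L, True), y)
    return -0.5 * y @ alpha - jnp.sum(jnp.log(jnp.diag(L))) - 0.5 * n * jnp.log(2 * jnp.pi)

p, m, n = 4, 2, 25
X = jnp.linspace(0.0, 1.0, n)[:, None]
U, _ = jnp.linalg.qr(jax.random.normal(jax.random.key(0), (p, m)))
S = jnp.array([1.5, 0.7])
T = (U / jnp.sqrt(S)).T
Y = jax.random.normal(jax.random.key(1), (n, p))

Y_latent = Y @ T.T                             # projected data (TY), shape [n, m]
noise = 0.1 / S + jnp.array([0.05, 0.2])       # σ²/S_i + D_i per latent GP
K = rbf(X, X)
sum_term = sum(gaussian_mll(Y_latent[:, i], K, noise[i]) for i in range(m))
```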

Parameters:

  • model (OILMMModel) –

    OILMMModel with parameters to evaluate.

  • data (Dataset) –

    Training data with X [N, D] and y [N, P].

Returns:

  • ScalarFloat –

    Scalar log marginal likelihood.

create_oilmm

create_oilmm(
    num_outputs: int,
    num_latent_gps: int,
    key: Array,
    kernel: AbstractKernel
    | list[AbstractKernel]
    | None = None,
    mean_function: Any = None,
) -> OILMMModel

Create OILMM model with shared kernel across latents.

Parameters:

  • num_outputs (int) –

    Number of output dimensions (p)

  • num_latent_gps (int) –

    Number of latent GPs (m)

  • key (Array) –

    JAX PRNG key

  • kernel (AbstractKernel | list[AbstractKernel] | None, default: None ) –

    Kernel for latent GPs (default: RBF)

  • mean_function (Any, default: None ) –

    Mean function for latent GPs (default: Zero)

Returns:

  • OILMMModel –

    OILMMModel with the given kernel (RBF by default) shared across the M latent GPs

Example

>>> import gpjax as gpx
>>> import jax.random as jr
>>> model = gpx.models.create_oilmm(
...     num_outputs=5,
...     num_latent_gps=2,
...     key=jr.key(42),
...     kernel=gpx.kernels.Matern52(),
... )

create_oilmm_with_kernels

create_oilmm_with_kernels(
    latent_kernels: list[AbstractKernel],
    num_outputs: int,
    key: Array,
    mean_function: Any = None,
) -> OILMMModel

Create OILMM with custom kernel per latent GP.

Parameters:

  • latent_kernels (list[AbstractKernel]) –

    List of M kernels, one per latent GP

  • num_outputs (int) –

    Number of output dimensions (p)

  • key (Array) –

    JAX PRNG key

  • mean_function (Any, default: None ) –

    Mean function (shared, default: Zero)

Returns:

  • OILMMModel –

    OILMMModel with heterogeneous latent kernels

Example

>>> import gpjax as gpx
>>> import jax.random as jr
>>> model = gpx.models.create_oilmm_with_kernels(
...     latent_kernels=[gpx.kernels.RBF(), gpx.kernels.Matern52()],
...     num_outputs=6,
...     key=jr.key(42),
... )

create_oilmm_from_data

create_oilmm_from_data(
    dataset: Dataset,
    num_latent_gps: int,
    key: Array,
kernel: AbstractKernel | None = None,
    mean_function: Any = None,
) -> OILMMModel

Create OILMM with data-informed initialization of mixing matrix.

Initializes U to the top M eigenvectors and S to the top M eigenvalues of the empirical covariance matrix of the outputs. Near-zero eigenvalues are clamped to 1e-6 for numerical stability. This can provide better initialization than random, especially when outputs have clear correlation structure.
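The initialization can be sketched with an eigendecomposition of the empirical covariance — an illustrative snippet with hypothetical variable names, not the factory's actual code:

```python
import jax.numpy as jnp

# Toy outputs [N, p] with obvious correlation structure.
t = jnp.linspace(0.0, 1.0, 50)
y = jnp.stack([jnp.sin(t), jnp.cos(t), 0.5 * jnp.sin(t)], axis=-1)

m = 2
emp_cov = jnp.cov(y.T)                           # empirical [p, p] covariance
eigvals, eigvecs = jnp.linalg.eigh(emp_cov)      # ascending eigenvalue order

idx = jnp.argsort(eigvals)[::-1][:m]             # indices of the top-m components
U0 = eigvecs[:, idx]                             # initial U: top M eigenvectors
S0 = jnp.maximum(eigvals[idx], 1e-6)             # initial S: eigenvalues, clamped

print(jnp.allclose(U0.T @ U0, jnp.eye(m), atol=1e-5))
```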

Parameters:

  • dataset (Dataset) –

    Training data with y [N, P]

  • num_latent_gps (int) –

    Number of latent GPs (m)

  • key (Array) –

    JAX PRNG key

kernel (AbstractKernel | None, default: None) –

    Kernel for latent GPs (default: RBF)

  • mean_function (Any, default: None ) –

    Mean function (default: Zero)

Returns:

  • OILMMModel –

    OILMMModel with U initialized to the top M eigenvectors and S to the top M eigenvalues

Example

>>> import gpjax as gpx
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> X = jnp.linspace(0, 1, 50).reshape(-1, 1)
>>> y = jnp.column_stack([jnp.sin(X), jnp.cos(X)])
>>> data = gpx.Dataset(X=X, y=y)
>>> model = gpx.models.create_oilmm_from_data(
...     dataset=data,
...     num_latent_gps=2,
...     key=jr.key(42),
... )