The multi-output notebook introduces the Intrinsic
Coregionalisation Model (ICM) and the Linear Model of Coregionalisation (LCM),
which capture cross-output correlations through coregionalisation matrices. Both
approaches form the full joint covariance over all $n$ inputs and $p$ outputs.
In the general case (LCM with $Q > 1$ components) this costs
$\mathcal{O}((np)^3)$; even the single-component ICM, which enjoys Kronecker
structure, still requires $\mathcal{O}(n^3 + p^3)$. When $p$ is large, both
become prohibitively expensive in both computation and memory.
The Orthogonal Instantaneous Linear Mixing Model (OILMM) of
Bruinsma et al. (2020) resolves this
bottleneck. It models the $p$ outputs as linear mixtures of $m \le p$ latent
Gaussian processes through a mixing matrix $H$ whose columns are
mutually orthogonal. This orthogonality causes the projected observation noise
to be diagonal, which in turn allows inference to decompose into $m$ independent
single-output GP problems. The overall cost drops to $\mathcal{O}(n^3 m)$:
linear in the number of latent processes and entirely independent of $p$.

This notebook derives the OILMM mathematics step by step, implements a
five-output example in GPJax, optimises the model's parameters via the OILMM
log marginal likelihood, and visualises the model's predictions.
A common formulation for multi-output GPs assumes that the $p$-dimensional
output vector is generated by linearly mixing $m$ independent latent GPs:

$$\mathbf{y}(t) = H \mathbf{x}(t) + \boldsymbol{\varepsilon}(t),$$

where

- $\mathbf{x}(t) = (x_1(t), \dots, x_m(t))^\top$ collects $m$ independent latent GPs, each with kernel $k_i$,
- $H \in \mathbb{R}^{p \times m}$ is the mixing matrix that maps from latent space to output space, and
- $\boldsymbol{\varepsilon}(t) \sim \mathcal{N}(\mathbf{0}, \sigma^2 I_p)$ is i.i.d. observation noise.
Given $n$ input locations, we stack all observations into a vector
$\bar{\mathbf{y}} \in \mathbb{R}^{np}$. Its joint covariance is

$$\operatorname{cov}[\bar{\mathbf{y}}] = \sum_{i=1}^{m} \mathbf{h}_i \mathbf{h}_i^\top \otimes K_i + \sigma^2 I_{np},$$

where $\mathbf{h}_i$ is the $i$-th column of $H$ and $K_i$ is the $n \times n$ Gram matrix of the $i$-th latent
kernel. Inverting this $np \times np$ matrix naively costs
$\mathcal{O}(n^3 p^3)$, which is impractical when $p$ is even moderately large.
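To make the scaling concrete, here is a small sketch in plain `jax.numpy` (the RBF Gram-matrix helper and the random mixing matrix are illustrative, not part of the notebook's example) that assembles the full $np \times np$ covariance of the general mixing model via Kronecker products:

```python
import jax.numpy as jnp
import jax.random as jr

def rbf_gram(t, lengthscale=1.0):
    # Simple RBF Gram matrix, used here only to illustrate the covariance structure.
    sq_dists = (t[:, None] - t[None, :]) ** 2
    return jnp.exp(-0.5 * sq_dists / lengthscale**2)

n, p, m, noise = 50, 5, 2, 0.1
t = jnp.linspace(0.0, 10.0, n)
H = jr.normal(jr.PRNGKey(0), (p, m))  # arbitrary mixing matrix for illustration

# cov[ȳ] = Σ_i (h_i h_iᵀ) ⊗ K_i + σ² I_{np}: an (np × np) matrix.
K_full = sum(
    jnp.kron(jnp.outer(H[:, i], H[:, i]), rbf_gram(t, lengthscale=1.0 + i))
    for i in range(m)
) + noise**2 * jnp.eye(n * p)
print(K_full.shape)  # (250, 250): inverting this scales as O((np)^3)
```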
The OILMM parameterisation
OILMM constrains $H$ to have orthogonal columns by writing

$$H = U S^{1/2},$$

where $U \in \mathbb{R}^{p \times m}$ has orthonormal columns
($U^\top U = I_m$) and
$S = \operatorname{diag}(s_1, \dots, s_m)$ with each $s_i > 0$ is a
positive diagonal scaling matrix.

The corresponding projection matrix is the left pseudo-inverse of $H$:

$$T = S^{-1/2} U^\top \quad\Longrightarrow\quad TH = S^{-1/2}\, \underbrace{U^\top U}_{=\, I_m}\, S^{1/2} = I_m.$$
Applying $T$ to the observed outputs projects them into the latent space:

$$\tilde{\mathbf{y}}(t) = T \mathbf{y}(t) = \underbrace{TH}_{=\, I_m} \mathbf{x}(t) + T \boldsymbol{\varepsilon}(t) = \mathbf{x}(t) + \tilde{\boldsymbol{\varepsilon}}(t).$$
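As a quick sanity check, the following minimal `jax.numpy` sketch (using an arbitrary orthonormal $U$ obtained via QR, not GPJax's own parameterisation) confirms that $T$ is a left inverse of $H$:

```python
import jax.numpy as jnp
import jax.random as jr

p, m = 5, 2

# U with orthonormal columns, here via a QR decomposition of a random matrix.
U, _ = jnp.linalg.qr(jr.normal(jr.PRNGKey(1), (p, m)))
s = jnp.array([2.0, 0.5])          # positive diagonal of S
H = U * jnp.sqrt(s)                # H = U S^{1/2}
T = (U / jnp.sqrt(s)).T            # T = S^{-1/2} Uᵀ

print(jnp.allclose(T @ H, jnp.eye(m), atol=1e-6))  # True: TH = I_m
```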
Diagonal projected noise
The crux of OILMM is that the projected noise
$\tilde{\boldsymbol{\varepsilon}} = T \boldsymbol{\varepsilon}$ has a
diagonal covariance:

$$\operatorname{cov}[\tilde{\boldsymbol{\varepsilon}}] = \sigma^2 T T^\top = \sigma^2 S^{-1/2}\, \underbrace{U^\top U}_{=\, I_m}\, S^{-1/2} = \sigma^2 S^{-1}.$$

Because $S$ is diagonal, the projected noise components are
independent: the $i$-th latent observation has noise variance $\sigma^2 / s_i$.
This is the result that makes OILMM tractable.
GPJax additionally supports per-latent heterogeneous noise
$D = \operatorname{diag}(d_1, \dots, d_m)$ with each $d_i \ge 0$.
Including this term, the full projected noise covariance is

$$\Sigma_T = \sigma^2 S^{-1} + D,$$

which remains diagonal.
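The claim is easy to verify numerically. The sketch below (plain `jax.numpy`, again with an arbitrary orthonormal $U$ rather than a fitted one) computes $\sigma^2 T T^\top + D$ and checks that it equals the diagonal matrix $\sigma^2 S^{-1} + D$:

```python
import jax.numpy as jnp
import jax.random as jr

p, m, sigma = 5, 2, 0.3
U, _ = jnp.linalg.qr(jr.normal(jr.PRNGKey(2), (p, m)))  # orthonormal columns
s = jnp.array([2.0, 0.5])                                # diagonal of S
d = jnp.array([0.01, 0.04])                              # per-latent noise D

T = (U / jnp.sqrt(s)).T                                  # T = S^{-1/2} Uᵀ
projected_noise_cov = sigma**2 * T @ T.T + jnp.diag(d)   # cov of Tε, plus D

# Matches the closed form σ² S^{-1} + D, which is diagonal.
print(jnp.allclose(projected_noise_cov, jnp.diag(sigma**2 / s + d), atol=1e-6))
```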
Independent latent inference
Because the projected noise is diagonal, each projected observation
$\tilde{y}_i(t) = x_i(t) + \tilde{\varepsilon}_i(t)$ constitutes a standard
single-output GP regression problem with known noise variance
$\sigma^2 / s_i + d_i$. We can therefore condition each latent GP independently
using the standard conjugate formulae.

The cost of conditioning one GP on $n$ observations is $\mathcal{O}(n^3)$
(dominated by the Cholesky factorisation), so conditioning all $m$ latent GPs
costs $\mathcal{O}(n^3 m)$. This is dramatically cheaper than the
$\mathcal{O}(n^3 p^3)$ cost of the general linear mixing model.
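For reference, here is a minimal sketch of that per-latent conditioning step in `jax.numpy`, using the standard conjugate GP formulae with a Cholesky factorisation (`rbf_gram` is an illustrative helper, not a GPJax function):

```python
import jax.numpy as jnp
import jax.scipy.linalg as jsl

def rbf_gram(a, b, lengthscale=1.0):
    # RBF kernel Gram matrix between two sets of 1-D inputs.
    return jnp.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / lengthscale**2)

def condition_latent_gp(t_train, y_tilde_i, t_test, noise_var):
    """Posterior mean/covariance of one latent GP given its projected observations."""
    Kff = rbf_gram(t_train, t_train) + noise_var * jnp.eye(t_train.shape[0])
    Kfs = rbf_gram(t_train, t_test)
    Kss = rbf_gram(t_test, t_test)

    L = jnp.linalg.cholesky(Kff)                 # the O(n^3) step
    alpha = jsl.cho_solve((L, True), y_tilde_i)  # Kff^{-1} ỹ_i
    V = jsl.solve_triangular(L, Kfs, lower=True)

    mean = Kfs.T @ alpha
    cov = Kss - V.T @ V
    return mean, cov
```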
Reconstructing predictions in output space
After conditioning, we obtain posterior means $\mu_i$ and
covariances $\Sigma_i$ for each latent GP at $n_*$ test locations.
The output-space predictive distribution is recovered by applying the mixing
matrix: because the latent posteriors are independent,

$$\mathbb{E}[f_j(t)] = \sum_{i=1}^{m} H_{ji}\, \mu_i(t), \qquad \operatorname{var}[f_j(t)] = \sum_{i=1}^{m} H_{ji}^2\, \Sigma_i(t, t).$$
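In matrix form this is a couple of lines. Here is a sketch assuming `latent_means` and `latent_vars` are arrays of shape `(n_star, m)` holding each latent posterior's mean and marginal variance (these names are illustrative, not GPJax attributes):

```python
import jax.numpy as jnp

def reconstruct_outputs(H, latent_means, latent_vars):
    """Map independent latent posteriors back to output space.

    H: (p, m) mixing matrix; latent_means, latent_vars: (n_star, m).
    Returns per-output means and marginal variances, each of shape (n_star, p).
    """
    output_means = latent_means @ H.T       # E[f_j(t)]   = Σ_i H_ji μ_i(t)
    output_vars = latent_vars @ (H**2).T    # var[f_j(t)] = Σ_i H_ji² Σ_i(t,t)
    return output_means, output_vars
```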
Each of the five outputs is a different weighted combination of the two
underlying latent functions. We plot them alongside the noiseless signal.
```python
fig, axes = plt.subplots(num_outputs, 1, figsize=(10, 1.5 * num_outputs), sharex=True)
for p in range(num_outputs):
    plot_output_panel(axes[p], p, X_train, y_train, y_clean, cols)
    if p == 0:
        axes[p].legend(loc="upper right", fontsize=7)
axes[-1].set_xlabel(r"$t$")
plt.suptitle(
    f"{num_outputs} Observed Outputs from {num_latent} Latent Sources", fontsize=13
)
```
Constructing the OILMM
GPJax provides `create_oilmm_from_data`, which initialises the mixing matrix
using the empirical correlation structure of the outputs. Under the hood it
first computes the empirical covariance matrix

$$\hat{\Sigma} = \frac{1}{n} Y_c^\top Y_c,$$

where $Y_c$ is the column-centred observation matrix. The next step
extracts the top $m$ eigenvectors and eigenvalues of $\hat{\Sigma}$.
The function then sets $U_{\text{latent}}$ to these eigenvectors, which
ensures that after SVD orthogonalisation the columns of $U$ align with
the principal directions of output variation. Finally, $S$ is set to
the corresponding eigenvalues, giving the scaling an informative starting point;
in this final step the eigenvalues are clamped to at least $10^{-6}$ for numerical stability.

This is analogous to initialising with PCA: the first $m$ principal components
capture the most variance and provide a reasonable starting point for $H$.
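A rough sketch of that initialisation logic in `jax.numpy` (illustrative only; the names and the exact clamping rule inside `create_oilmm_from_data` may differ):

```python
import jax.numpy as jnp

def pca_style_init(Y, m, jitter=1e-6):
    """PCA-style starting values for U_latent and S from an (n, p) output matrix Y."""
    n = Y.shape[0]
    Y_c = Y - Y.mean(axis=0)                      # column-centre the observations
    Sigma_hat = (Y_c.T @ Y_c) / n                 # empirical covariance, (p, p)

    eigvals, eigvecs = jnp.linalg.eigh(Sigma_hat) # eigenvalues in ascending order
    top = jnp.argsort(eigvals)[::-1][:m]          # indices of the top-m components

    U_latent = eigvecs[:, top]                    # principal directions -> U_latent
    S_diag = jnp.maximum(eigvals[top], jitter)    # clamp for numerical stability
    return U_latent, S_diag
```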
The `OrthogonalMixingMatrix` parameter stores an unconstrained matrix
$U_{\text{latent}} \in \mathbb{R}^{p \times m}$ and projects it onto
the Stiefel manifold (the set of matrices with orthonormal columns) at each
forward pass using SVD:

$$U_{\text{SVD}},\, \_,\, V^\top = \operatorname{SVD}(U_{\text{latent}}), \qquad U = U_{\text{SVD}} V^\top.$$

This ensures $U^\top U = I_m$ exactly, regardless
of the optimiser's updates to the unconstrained representation.
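The projection itself is a one-liner; here is a standalone `jax.numpy` sketch of the same idea (not GPJax's internal code):

```python
import jax.numpy as jnp
import jax.random as jr

def project_to_stiefel(U_latent):
    """Closest matrix with orthonormal columns, via a thin SVD."""
    U_svd, _, Vt = jnp.linalg.svd(U_latent, full_matrices=False)
    return U_svd @ Vt

U_latent = jr.normal(jr.PRNGKey(3), (5, 2))          # arbitrary unconstrained parameter
U = project_to_stiefel(U_latent)
print(jnp.allclose(U.T @ U, jnp.eye(2), atol=1e-6))  # True: UᵀU = I_m
```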
Before optimising any parameters, we condition with the PCA-initialised
defaults to establish a baseline. Calling `condition_on_observations` executes
the OILMM inference algorithm:

1. Project: compute $\tilde{Y} = T Y^\top$ in $\mathcal{O}(nmp)$.
2. Condition: for each latent GP $i$, form a single-output dataset from $\tilde{\mathbf{y}}_i$ with noise variance $\sigma^2 / s_i + d_i$, then condition using the standard conjugate formulae in $\mathcal{O}(n^3)$.
3. Return: an `OILMMPosterior` wrapping the $m$ independent posteriors.
We first inspect the model's output-space predictions using the default
PCA-initialised parameters. This serves as a baseline against which we can
later compare the optimised model.
We optimise the model's parameters by maximising the OILMM log marginal likelihood.
Proposition 9 of Bruinsma et al. (2020) gives the exact expression:
$$\log p(Y) \;=\; \underbrace{-\tfrac{n}{2} \log |S|}_{\text{scaling penalty}} \;\underbrace{-\,\tfrac{n(p-m)}{2} \log (2\pi\sigma^2)}_{\text{residual noise}} \;\underbrace{-\,\tfrac{1}{2\sigma^2} \big\lVert (I_p - U U^\top)\, Y^\top \big\rVert_F^2}_{\text{projection residual}} \;+\; \sum_{i=1}^{m} \underbrace{\log \mathcal{N}\!\big(\tilde{\mathbf{y}}_i \,\big|\, \mathbf{0},\, K_i + (\sigma^2/s_i + d_i)\, I_n \big)}_{\text{latent GP marginal likelihood}}.$$
The first three terms are correction factors that account for the
deterministic projection from output space to latent space:

- Scaling penalty: penalises very large or small $s_i$ values, preventing the model from trivially inflating the likelihood by rescaling.
- Residual noise: the log-probability of the $(p - m)$ directions orthogonal to $U$, which are explained purely by observation noise.
- Projection residual: the squared Frobenius norm of the data component that lies outside the column space of $U$, scaled by $1/(2\sigma^2)$.

The final summation is simply the sum of $m$ standard single-output GP log
marginal likelihoods, each evaluated on the projected data.
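To make the expression concrete, here is a hedged `jax.numpy` sketch that evaluates it term by term (the Gram-matrix callables and the argument layout are illustrative, not the signature of GPJax's `oilmm_mll`):

```python
import jax.numpy as jnp
import jax.scipy.linalg as jsl

def oilmm_log_marginal(Y, gram_fns, U, s, d, sigma):
    """OILMM log marginal likelihood (Proposition 9 of Bruinsma et al., 2020).

    Y: (n, p) observations; gram_fns: list of m callables returning (n, n) Gram
    matrices; U: (p, m) orthonormal; s, d: (m,) diagonals of S and D; sigma: noise.
    """
    n, p = Y.shape
    m = U.shape[1]
    T = (U / jnp.sqrt(s)).T                   # S^{-1/2} Uᵀ
    Y_tilde = T @ Y.T                         # (m, n) projected observations

    scaling_penalty = -0.5 * n * jnp.sum(jnp.log(s))
    residual_noise = -0.5 * n * (p - m) * jnp.log(2 * jnp.pi * sigma**2)
    residual = (jnp.eye(p) - U @ U.T) @ Y.T
    projection_residual = -0.5 * jnp.sum(residual**2) / sigma**2

    latent_terms = 0.0
    for i in range(m):
        K = gram_fns[i]() + (sigma**2 / s[i] + d[i]) * jnp.eye(n)
        L = jnp.linalg.cholesky(K)
        alpha = jsl.cho_solve((L, True), Y_tilde[i])
        latent_terms += (
            -0.5 * Y_tilde[i] @ alpha                  # quadratic term
            - jnp.sum(jnp.log(jnp.diag(L)))            # -0.5 log|K|
            - 0.5 * n * jnp.log(2 * jnp.pi)
        )
    return scaling_penalty + residual_noise + projection_residual + latent_terms
```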
GPJax implements this in `oilmm_mll(model, data)`, which takes the
pre-conditioning `OILMMModel` (not a posterior) together with the training
`Dataset`. We negate it for minimisation with `fit_scipy`.

We maximise the OILMM log marginal likelihood using L-BFGS via `fit_scipy`.
The optimiser tunes all `Parameter` leaves: the kernel hyperparameters
of each latent GP, the unconstrained mixing matrix $U_{\text{latent}}$,
the diagonal scaling $S$, and the noise variances ($\sigma^2$ and $D$).
Current function value: -75.275478
Iterations: 191
Function evaluations: 329
Gradient evaluations: 309
Initial MLL: -500.240
Optimised MLL: 75.275
Post-optimisation predictions
We re-condition the optimised model on the training data, then predict at the
same test locations.
We display the baseline (left) and optimised (right) predictions side by side
for each output. The grey dashed line shows the noiseless ground-truth signal.
```python
fig, axes = plt.subplots(num_outputs, 2, figsize=(14, 1.8 * num_outputs), sharex=True)
for p in range(num_outputs):
    for j, (mean, std, title) in enumerate(
        [
            (pre_opt_mean, pre_opt_std, "Before Optimisation"),
            (post_opt_mean, post_opt_std, "After Optimisation"),
        ]
    ):
        plot_output_panel(
            axes[p, j], p, X_train, y_train, y_clean, cols, X_test, mean, func_std=std
        )
        if p == 0:
            axes[p, j].set_title(title, fontsize=11)
axes[-1, 0].set_xlabel(r"$t$")
axes[-1, 1].set_xlabel(r"$t$")
plt.suptitle("OILMM Predictions: Default Parameters vs Optimised", fontsize=13, y=1.01)
```
Predictive uncertainty: latent function vs noisy observations
The previous figure shows uncertainty over the latent noise-free function
$\mathbf{f}(t) = H \mathbf{x}(t)$. To visualise uncertainty over
observed outputs $\mathbf{y}(t)$, we add output-space noise:

$$\operatorname{var}[y_j(t)] = \operatorname{var}[f_j(t)] + \sigma^2 + \sum_{i=1}^{m} H_{ji}^2\, d_i.$$
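As a short sketch (assuming `func_vars` has shape `(n_star, p)` and holds the latent-function variances from the earlier reconstruction sketch; the names are illustrative):

```python
import jax.numpy as jnp

def observation_variance(func_vars, H, sigma, d):
    # Add homoscedastic output noise plus the mixed per-latent noise D.
    return func_vars + sigma**2 + (H**2) @ d  # (p,) broadcasts across all test points
```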
The wider band below is the predictive standard deviation of the noisy observations,
while the narrower band is the latent function's standard deviation.
```python
fig, axes = plt.subplots(num_outputs, 2, figsize=(14, 1.8 * num_outputs), sharex=True)
for p in range(num_outputs):
    for j, (mean, fstd, ostd, title) in enumerate(
        [
            (pre_opt_mean, pre_opt_std, pre_obs_std, "Before Optimisation"),
            (post_opt_mean, post_opt_std, post_obs_std, "After Optimisation"),
        ]
    ):
        plot_output_panel(
            axes[p, j], p, X_train, y_train, y_clean, cols, X_test, mean,
            func_std=fstd, obs_std=ostd,
        )
        if p == 0:
            axes[p, j].set_title(title, fontsize=11)
        if p == 0 and j == 1:
            axes[p, j].legend(loc="upper right", fontsize=7)
axes[-1, 0].set_xlabel(r"$t$")
axes[-1, 1].set_xlabel(r"$t$")
plt.suptitle(
    "OILMM Predictive Intervals: Latent Function vs Noisy Observations",
    fontsize=13, y=1.01,
)
```
Latent space after optimisation
The decomposition into independent single-output problems is the heart of
OILMM. We plot each latent GP's projected training data
$\tilde{\mathbf{y}}_i$ alongside its posterior predictive distribution. These
$m$ panels are completely independent: each latent GP knows nothing about the
others.

Note that the latent functions are identified only up to a sign/scale ambiguity:
flipping the sign of a column of $U$ and the corresponding latent GP
leaves the output-space predictions unchanged.
```python
fig, axes = plt.subplots(1, num_latent, figsize=(5 * num_latent, 3), sharey=False)
for i in range(num_latent):
    ax = axes[i]
    lat_y = opt_posterior.latent_datasets[i].y.squeeze()
    ax.plot(X_train, lat_y, "o", color=cols[i], alpha=0.4, ms=3, label="Projected data")
    lat_pred = opt_posterior.latent_posteriors[i].predict(
        X_test, train_data=opt_posterior.latent_datasets[i]
    )
    lat_mean = lat_pred.mean
    lat_std = jnp.sqrt(jnp.diag(lat_pred.covariance()))
    ax.plot(X_test, lat_mean, color=cols[i], linewidth=2, label="Posterior mean")
    ax.fill_between(
        X_test.squeeze(),
        lat_mean - 2 * lat_std,
        lat_mean + 2 * lat_std,
        color=cols[i],
        alpha=0.2,
        label="Two sigma",
    )
    ax.set_xlabel(r"$t$")
    ax.set_title(f"Latent GP {i+1}")
    ax.legend(loc="best", fontsize=7)
```
Each latent GP has learned a smooth function that explains the projected
observations. The prediction intervals are narrow where data is dense and widen
towards the boundaries. Crucially, these $m$ regression problems were solved
independently, with total cost $\mathcal{O}(n^3 m)$.
Heterogeneous kernels
By default, passing a single kernel to `OILMMModel` deep-copies it $m$ times
so that each latent GP has independent hyperparameters. If the latent
processes operate at fundamentally different characteristic scales, you can go
further and assign entirely different kernel families:
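A hypothetical sketch of that idea (the `kernels` keyword below is an assumption made for illustration, not the verified `OILMMModel` signature; consult the GPJax API reference for the exact argument names):

```python
import gpjax as gpx

# Hypothetical: one RBF and one Matérn-5/2 latent kernel for a two-latent OILMM.
latent_kernels = [gpx.kernels.RBF(), gpx.kernels.Matern52()]
# model = OILMMModel(kernels=latent_kernels, ...)  # sketch only; keyword name unverified
```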
The first latent GP would then use an infinitely differentiable RBF kernel
while the second uses the rougher Matern-5/2. This is analogous to the
advantage of LCM over ICM in the multi-output setting, where different
components can capture different spectral characteristics.