In this notebook we consider sparse Gaussian process regression (SGPR)
(Titsias, 2009), a solution for medium- to large-scale conjugate regression
problems. To arrive at a computationally tractable method, the approximate
posterior is parameterised via a set of $m$ pseudo-points $\boldsymbol{z}$.
Critically, the approach leads to $\mathcal{O}(nm^2)$ complexity for
approximate maximum likelihood learning and $\mathcal{O}(m^2)$ per test point
for prediction.
```python
# Enable Float64 for more stable matrix inversions.
from jax import config, jit
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox

from examples.utils import use_mpl_style

config.update("jax_enable_x64", True)

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
    from gpjax.parameters import Parameter

# set the default style for plotting
use_mpl_style()

key = jr.key(42)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
```
Dataset
With the necessary modules imported, we simulate a dataset
$\mathcal{D} = (\boldsymbol{x}, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{500}$
with inputs $\boldsymbol{x}$ sampled uniformly on $(-3., 3)$ and corresponding
independent noisy outputs
$\boldsymbol{y} \sim \mathcal{N}\left(\sin(7\boldsymbol{x}) + \boldsymbol{x}\cos(2\boldsymbol{x}), \mathbf{I} \cdot 0.5^2\right).$
We store our data D as a GPJax Dataset and create test inputs and
labels for later.
To better understand what we have simulated, we plot both the underlying latent
function and the observed data that is subject to Gaussian noise. We also plot an
initial set of inducing points over the space.
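One way this plot might be sketched is shown below; the evenly spaced grid of $m = 50$ inducing locations and the variable name `z` are assumptions, and a stand-in latent function is used so the snippet is self-contained.

```python
import jax.numpy as jnp
import matplotlib.pyplot as plt

# An evenly spaced initial grid of m = 50 inducing locations over the
# input domain (the choice of m and the name z are assumptions).
z = jnp.linspace(-3.0, 3.0, 50).reshape(-1, 1)

# Stand-in latent function for illustration; in the notebook this is
# the function used to simulate the data.
x = jnp.linspace(-3.0, 3.0, 500).reshape(-1, 1)
latent = jnp.sin(7 * x) + x * jnp.cos(2 * x)

fig, ax = plt.subplots()
ax.plot(x, latent, label="Latent function")
ax.vlines(
    z.ravel(),
    ymin=float(latent.min()),
    ymax=float(latent.max()),
    alpha=0.3,
    label="Inducing points",
)
ax.legend()
```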
We now define the SGPR model through `CollapsedVariationalGaussian`. Given a
set of inducing points $\boldsymbol{z}$, this object builds an approximation
to the true posterior distribution, so we pass the true posterior and the
initial inducing points into the constructor as arguments.
We now train our model, much as we would a standard Gaussian process
regression model, via the `fit` abstraction. Unlike the regression example
given in the conjugate regression notebook, the inducing locations that
parameterise our variational posterior distribution are now part of the
model's parameters. Using a gradient-based optimiser, we can then optimise
their locations such that the evidence lower bound is maximised.
```python
# Use the enhanced fit API with trainable parameter filtering
opt_posterior, history = gpx.fit(
    model=q,
    # we want to minimize the *negative* ELBO
    objective=lambda p, d: -gpx.objectives.collapsed_elbo(p, d),
    train_data=D,
    optim=ox.adamw(learning_rate=1e-2),
    num_iters=500,
    key=key,
    trainable=Parameter,
)
```
Given the size of the data considered here, inference in a GP with a
full-rank covariance matrix is possible, albeit quite slow. We can therefore
measure the speedup gained from the sparse approximation by comparing
evaluation of its bound on the marginal log-likelihood against evaluation of
the exact marginal log-likelihood of the full model.