In this notebook we demonstrate how to implement sparse variational Gaussian
processes (SVGPs) of
Hensman et al. (2015). In
particular, this approximation framework provides a tractable option for working with
non-conjugate Gaussian processes with more than ~5000 data points. However, for
conjugate models of less than 5000 data points, we recommend using the marginal
log-likelihood approach presented in the
regression notebook.
Though we illustrate SVGPs here with a conjugate regression example, the same GPJax
code works for general likelihoods, such as a Bernoulli for classification.
```python
# Enable Float64 for more stable matrix inversions.
from jax import config
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox

from examples.utils import use_mpl_style

config.update("jax_enable_x64", True)

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
    import gpjax.kernels as jk
    from gpjax.parameters import Parameter

key = jr.key(123)

# set the default style for plotting
use_mpl_style()
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
```
Dataset
With the necessary modules imported, we simulate a dataset

$$\mathcal{D} = (\boldsymbol{x}, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{5000}$$

with inputs $\boldsymbol{x}$ sampled uniformly on $(-5, 5)$ and corresponding noisy outputs

$$\boldsymbol{y} \sim \mathcal{N}\left(\sin(4\boldsymbol{x}) + \sin(2\boldsymbol{x}), \mathbf{I} \cdot (0.2)^2\right).$$

We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs for later.
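As a self-contained sketch of this simulation (using NumPy in place of the JAX utilities above; the sorting of inputs and the test-input grid are illustrative choices):

```python
import numpy as np

rng = np.random.default_rng(123)
n = 5000

# Inputs sampled uniformly on (-5, 5), sorted for plotting convenience.
x = np.sort(rng.uniform(-5.0, 5.0, size=(n, 1)), axis=0)

# Latent signal sin(4x) + sin(2x), observed under N(0, 0.2^2) noise.
f = np.sin(4 * x) + np.sin(2 * x)
y = f + rng.normal(0.0, 0.2, size=(n, 1))

# Test inputs for later predictions.
xtest = np.linspace(-5.5, 5.5, 500).reshape(-1, 1)
```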
Despite their elegant theoretical properties, GPs are burdened with prohibitive $\mathcal{O}(n^3)$ inference and $\mathcal{O}(n^2)$ memory costs in the number of data points $n$, due to the need to compute inverses and determinants of the kernel Gram matrix $\mathbf{K}_{\mathbf{xx}}$ during inference and hyperparameter learning.
Sparse GPs seek to resolve tractability through low-rank approximations. Their name originates from the idea of using subsets of the data to approximate the kernel matrix, with sparseness arising from the selection of the data points. Given inputs $\boldsymbol{x}$ and outputs $\boldsymbol{y}$, the task is to select an $m < n$ lower-dimensional dataset $(\boldsymbol{z}, \tilde{\boldsymbol{y}}) \subset (\boldsymbol{x}, \boldsymbol{y})$ on which to train a Gaussian process instead.
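A minimal sketch of this naive subset-of-data idea, assuming NumPy and synthetic data (the names `idx` and `y_tilde` are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 5000, 50

x = np.sort(rng.uniform(-5.0, 5.0, size=n))
y = np.sin(4 * x) + rng.normal(0.0, 0.2, size=n)

# Naive sparse approach: keep only m of the n observed pairs
# and train on this reduced dataset instead.
idx = rng.choice(n, size=m, replace=False)
z, y_tilde = x[idx], y[idx]
```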
By generalising the set of selected points $\boldsymbol{z}$, known as inducing inputs, to remove the restriction of being part of the dataset, we arrive at a flexible low-rank approximation framework of the model, using functions of $\mathbf{K}_{\mathbf{zz}}$ to replace the true covariance matrix $\mathbf{K}_{\mathbf{xx}}$ at significantly lower cost. For example, Quiñonero-Candela and Rasmussen (2005) review many popular approximation schemes in this vein. However, because the model and the approximation are intertwined, assigning performance and faults to one or the other becomes tricky.
On the other hand, sparse variational Gaussian processes (SVGPs) approximate the posterior, not the model. They provide a low-rank approximation scheme via variational inference. Here we posit a family of densities parameterised by "variational parameters". We then seek the family member closest to the posterior by minimising the Kullback-Leibler divergence over the variational parameters. The fitted variational density then serves as a proxy for the exact posterior. This procedure makes variational methods efficiently solvable via off-the-shelf optimisation techniques whilst retaining the true underlying model. Furthermore, SVGPs offer additional cost reductions through mini-batch stochastic gradient descent and address non-conjugacy.
We show a cost comparison between the approaches below, where $b$ is the mini-batch size.

|                | GPs                | sparse GPs          | SVGP                      |
|----------------|--------------------|---------------------|---------------------------|
| Inference cost | $\mathcal{O}(n^3)$ | $\mathcal{O}(nm^2)$ | $\mathcal{O}(bm^2 + m^3)$ |
| Memory cost    | $\mathcal{O}(n^2)$ | $\mathcal{O}(nm)$   | $\mathcal{O}(bm + m^2)$   |
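Plugging in concrete values matching this notebook ($n = 5000$, $m = 50$, $b = 128$) shows the scale of the savings; the figures below count only the dominant terms from the table:

```python
# Dominant inference-cost terms for n data points, m inducing points,
# and mini-batch size b.
n, m, b = 5000, 50, 128

full_gp_inference = n**3            # O(n^3)
sparse_gp_inference = n * m**2      # O(n m^2)
svgp_inference = b * m**2 + m**3    # O(b m^2 + m^3)

print(full_gp_inference / svgp_inference)  # speed-up factor per step
```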
To apply SVGP inference to our dataset, we begin by initialising $m = 50$ equally spaced inducing inputs $\boldsymbol{z}$ across our observed data's support. These are depicted below via black lines. The inducing inputs will summarise our dataset and, since they are treated as variational parameters, their locations will be optimised. The next step of SVGP is to define a variational family.
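A sketch of this initialisation with NumPy (in GPJax such an array would be supplied as the inducing inputs; `z` here is a plain array):

```python
import numpy as np

# m equally spaced inducing inputs spanning the data support (-5, 5).
m = 50
z = np.linspace(-5.0, 5.0, m).reshape(-1, 1)
```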
Defining the variational process
We begin by considering the form of the posterior distribution for all function values $f(\cdot)$

$$p(f(\cdot) \mid \mathcal{D}) = \int p(f(\cdot) \mid f(\boldsymbol{x}))\, p(f(\boldsymbol{x}) \mid \mathcal{D})\, \mathrm{d}f(\boldsymbol{x}). \qquad (\dagger)$$

To arrive at an approximation framework, we assume some redundancy in the data. Instead of predicting $f(\cdot)$ with function values at the datapoints $f(\boldsymbol{x})$, we assume this can be achieved with only function values at $m$ inducing inputs $\boldsymbol{z}$

$$p(f(\cdot) \mid \mathcal{D}) \approx \int p(f(\cdot) \mid f(\boldsymbol{z}))\, p(f(\boldsymbol{z}) \mid \mathcal{D})\, \mathrm{d}f(\boldsymbol{z}). \qquad (\star)$$
This lower-dimensional integral results in computational savings in the model's predictive component, from $p(f(\cdot) \mid f(\boldsymbol{x}))$ to $p(f(\cdot) \mid f(\boldsymbol{z}))$, where inverting $\mathbf{K}_{\mathbf{xx}}$ is replaced with inverting $\mathbf{K}_{\mathbf{zz}}$.
However, since we did not observe our data $\mathcal{D}$ at $\boldsymbol{z}$, we ask: what exactly is the posterior $p(f(\boldsymbol{z}) \mid \mathcal{D})$? Notice this is simply obtained by substituting $\boldsymbol{z}$ into $(\dagger)$, but we arrive back at square one with computing the expensive integral. To side-step this, we consider replacing $p(f(\boldsymbol{z}) \mid \mathcal{D})$ in $(\star)$ with a cheap-to-compute approximate distribution $q(f(\boldsymbol{z}))$

$$q(f(\cdot)) = \int p(f(\cdot) \mid f(\boldsymbol{z}))\, q(f(\boldsymbol{z}))\, \mathrm{d}f(\boldsymbol{z}). \qquad (\times)$$
To measure the quality of the approximation, we consider the Kullback-Leibler divergence $\operatorname{KL}(\cdot \,\|\, \cdot)$ from our approximate process $q(f(\cdot))$ to the true process $p(f(\cdot) \mid \mathcal{D})$. By parametrising $q(f(\boldsymbol{z}))$ over a variational family of distributions, we can optimise the Kullback-Leibler divergence with respect to the variational parameters. Moreover, since the inducing input locations $\boldsymbol{z}$ augment the model, they themselves can be treated as variational parameters without altering the true underlying model $p(f(\boldsymbol{z}) \mid \mathcal{D})$. This is exactly what gives SVGPs great flexibility whilst retaining robustness to overfitting.
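For intuition, once a Gaussian variational family is chosen (as below), the Kullback-Leibler divergence between two multivariate Gaussians has a closed form. A NumPy sketch (the helper `gaussian_kl` is illustrative, not part of the GPJax API):

```python
import numpy as np

def gaussian_kl(mu_q, S_q, mu_p, S_p):
    """Closed-form KL( N(mu_q, S_q) || N(mu_p, S_p) )."""
    d = mu_q.shape[0]
    S_p_inv = np.linalg.inv(S_p)
    diff = mu_p - mu_q
    return 0.5 * (
        np.trace(S_p_inv @ S_q)          # trace term
        + diff @ S_p_inv @ diff          # Mahalanobis term
        - d                              # dimensionality
        + np.log(np.linalg.det(S_p) / np.linalg.det(S_q))  # log-det ratio
    )
```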
It is popular to elect a Gaussian variational distribution $q(f(\boldsymbol{z})) = \mathcal{N}(f(\boldsymbol{z}); \mathbf{m}, \mathbf{S})$ with parameters $\{\boldsymbol{z}, \mathbf{m}, \mathbf{S}\}$, since conjugacy between $q(f(\boldsymbol{z}))$ and $p(f(\cdot) \mid f(\boldsymbol{z}))$ ensures that the resulting variational process $q(f(\cdot))$ is a GP. We can implement this in GPJax by the following.
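The mathematics of this Gaussian family's induced predictive distribution can also be sketched directly in NumPy (the helpers `rbf` and `q_predict` are illustrative, not the GPJax implementation; a practical version would use Cholesky solves rather than explicit inverses):

```python
import numpy as np

def rbf(a, b, lengthscale=1.0, variance=1.0):
    # Squared-exponential kernel on 1-D inputs.
    d = a[:, None] - b[None, :]
    return variance * np.exp(-0.5 * (d / lengthscale) ** 2)

def q_predict(xstar, z, m_var, S, jitter=1e-8):
    """Mean and covariance of q(f(xstar)) = ∫ p(f(xstar)|f(z)) q(f(z)) df(z)."""
    Kzz = rbf(z, z) + jitter * np.eye(len(z))
    Ksz = rbf(xstar, z)
    Kss = rbf(xstar, xstar)
    A = Ksz @ np.linalg.inv(Kzz)             # K_*z K_zz^{-1}
    mean = A @ m_var                         # projected variational mean
    cov = Kss - A @ Ksz.T + A @ S @ A.T      # prior cov corrected by q(f(z))
    return mean, cov
```

Note that setting $\mathbf{S} = \mathbf{K}_{\mathbf{zz}}$ and $\mathbf{m} = \mathbf{0}$ recovers the prior, a useful sanity check.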
Here, the variational process $q(\cdot)$ depends on the prior through $p(f(\cdot) \mid f(\boldsymbol{z}))$ in $(\times)$.
Inference
Evidence lower bound
With our model defined, we seek to infer the optimal inducing inputs $\boldsymbol{z}$, variational mean $\mathbf{m}$ and covariance $\mathbf{S}$ that define our approximate posterior. To achieve this, we maximise the evidence lower bound (ELBO) with respect to $\{\boldsymbol{z}, \mathbf{m}, \mathbf{S}\}$, a proxy for minimising the Kullback-Leibler divergence. Moreover, as hinted by its name, the ELBO is a lower bound to the marginal log-likelihood, providing a tractable objective to optimise the model's hyperparameters akin to the conjugate setting. For further details, see Sections 3.1 and 4.1 of the review by Leibfried et al. (2020).
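As a sketch of the ELBO's structure for a Gaussian likelihood: the expected log-likelihood under $q$ has a closed form, namely the exact log-density penalised by half the predictive variance over the noise variance, and the ELBO subtracts the KL term from the summed expectations. Assuming NumPy and illustrative helper names:

```python
import numpy as np

def expected_log_lik(y, mu, var_f, noise):
    """E_{f ~ N(mu, var_f)}[ log N(y | f, noise) ], elementwise."""
    return (
        -0.5 * np.log(2 * np.pi * noise)
        - 0.5 * (y - mu) ** 2 / noise
        - 0.5 * var_f / noise   # penalty from predictive uncertainty
    )

def elbo(y, mu, var_f, noise, kl):
    # Sum of per-datapoint expected log-likelihoods minus the KL term.
    return np.sum(expected_log_lik(y, mu, var_f, noise)) - kl
```

With zero predictive variance and zero KL, this reduces to the exact Gaussian log-likelihood.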
Mini-batching
Despite introducing inducing inputs into our model, inference can still be
intractable with large datasets. To circumvent this, optimisation can be done using
stochastic mini-batches.
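The mini-batch estimate replaces the sum over all $n$ datapoint terms with a rescaled sum over a random batch of size $b$, which keeps the estimator unbiased. A NumPy sketch with illustrative stand-in values:

```python
import numpy as np

rng = np.random.default_rng(42)
n, b = 5000, 128

# Stand-in for the n per-datapoint expected log-likelihood terms.
per_point_terms = rng.normal(size=n)

# Sample a random mini-batch and rescale by n/b so that, in expectation,
# the estimate equals the full sum over all n terms.
idx = rng.choice(n, size=b, replace=False)
estimate = (n / b) * per_point_terms[idx].sum()
```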
```python
schedule = ox.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.02,
    warmup_steps=75,
    decay_steps=2000,
    end_value=0.001,
)

opt_posterior, history = gpx.fit(
    model=q,
    # we are minimizing the elbo so we negate it
    objective=lambda p, d: -gpx.objectives.elbo(p, d),
    train_data=D,
    optim=ox.adam(learning_rate=schedule),
    num_iters=3000,
    key=jr.key(42),
    batch_size=128,
    trainable=Parameter,
)
```
Predictions
With optimisation complete, we can use our inferred parameter set to make predictions at novel inputs on our variational process object $q(\cdot)$, akin to all other models within GPJax (for example, see the regression notebook).