Sparse Stochastic Variational Inference
In this notebook we demonstrate how to implement sparse variational Gaussian processes (SVGPs) of Hensman et al. (2015). In particular, this approximation framework provides a tractable option for working with non-conjugate Gaussian processes on more than ~5000 data points. However, for conjugate models with fewer than ~5000 data points, we recommend using the marginal log-likelihood approach presented in the regression notebook. Though we illustrate SVGPs here with a conjugate regression example, the same GPJax code works for general likelihoods, such as a Bernoulli likelihood for classification.
# Enable Float64 for more stable matrix inversions.
from jax import config
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox
import tensorflow_probability.substrates.jax as tfp
from examples.utils import use_mpl_style
config.update("jax_enable_x64", True)
with install_import_hook("gpjax", "beartype.beartype"):
import gpjax as gpx
import gpjax.kernels as jk
tfb = tfp.bijectors
key = jr.key(123)
# set the default style for plotting
use_mpl_style()
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
Dataset
With the necessary modules imported, we simulate a dataset \(\mathcal{D} = (\boldsymbol{x}, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{50000}\) with inputs \(\boldsymbol{x}\) sampled uniformly on \((-5, 5)\) and corresponding noisy outputs \(y_i = f(x_i) + \varepsilon_i\), where \(f(x) = \sin(4x) + \cos(2x)\) and \(\varepsilon_i \sim \mathcal{N}(0, 0.2^2)\).
We store our data \(\mathcal{D}\) as a GPJax Dataset
and create test inputs for later.
n = 50000
noise = 0.2
key, subkey = jr.split(key)
x = jr.uniform(key=key, minval=-5.0, maxval=5.0, shape=(n,)).reshape(-1, 1)
f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)
signal = f(x)
y = signal + jr.normal(subkey, shape=signal.shape) * noise
D = gpx.Dataset(X=x, y=y)
xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1)
Sparse GPs via inducing inputs
Despite their elegant theoretical properties, GPs are burdened with prohibitive \(\mathcal{O}(n^3)\) inference and \(\mathcal{O}(n^2)\) memory costs in the number of data points \(n\) due to the necessity of computing inverses and determinants of the kernel Gram matrix \(\mathbf{K}_{\boldsymbol{x}\boldsymbol{x}}\) during inference and hyperparameter learning. Sparse GPs seek to restore tractability through low-rank approximations.
Their name originates from the idea of using subsets of the data to approximate the kernel matrix, with sparseness arising through the selection of the data points. Given inputs \(\boldsymbol{x}\) and outputs \(\boldsymbol{y}\), the original task was to select a smaller dataset \((\boldsymbol{z},\boldsymbol{\tilde{y}}) \subset (\boldsymbol{x}, \boldsymbol{y})\) of size \(m<n\) on which to train a Gaussian process instead. By generalising the set of selected points \(\boldsymbol{z}\), known as inducing inputs, to remove the restriction of being part of the dataset, we arrive at a flexible low-rank approximation framework that uses functions of \(\mathbf{K}_{\boldsymbol{z}\boldsymbol{z}}\) in place of the true covariance matrix \(\mathbf{K}_{\boldsymbol{x}\boldsymbol{x}}\) at significantly lower cost. Quiñonero-Candela and Rasmussen (2005), for example, review many popular approximation schemes in this vein. However, because the model and the approximation are intertwined, assigning performance and faults to one or the other becomes tricky.
On the other hand, sparse variational Gaussian processes (SVGPs) approximate the posterior, not the model. They provide a low-rank approximation scheme via variational inference. Here we posit a family of densities parameterised by "variational parameters". We then seek the family member closest to the posterior by minimising the Kullback-Leibler divergence over the variational parameters. The fitted variational density then serves as a proxy for the exact posterior. This procedure makes variational methods efficiently solvable via off-the-shelf optimisation techniques whilst retaining the true underlying model. Furthermore, SVGPs offer further cost reductions through mini-batch stochastic gradient descent and naturally handle non-conjugate likelihoods (Hensman et al., 2015). We show a cost comparison between the approaches below, where \(b\) is the mini-batch size.
|  | GPs | sparse GPs | SVGP |
|---|---|---|---|
| Inference cost | \(\mathcal{O}(n^3)\) | \(\mathcal{O}(n m^2)\) | \(\mathcal{O}(b m^2 + m^3)\) |
| Memory cost | \(\mathcal{O}(n^2)\) | \(\mathcal{O}(n m)\) | \(\mathcal{O}(b m + m^2)\) |
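To make these asymptotic costs concrete, the quick back-of-the-envelope sketch below (illustrative only: constants and lower-order terms are ignored, and the variable names n_pts, m_ind and b_batch are ours) plugs in this notebook's values of \(n = 50000\), \(m = 50\) and \(b = 128\).
# Rough comparison of the dominant inference cost terms (constants ignored).
n_pts, m_ind, b_batch = 50_000, 50, 128
print("Exact GP  :", n_pts**3)                        # O(n^3)
print("Sparse GP :", n_pts * m_ind**2)                # O(n m^2)
print("SVGP      :", b_batch * m_ind**2 + m_ind**3)   # O(b m^2 + m^3)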
To apply SVGP inference to our dataset, we begin by initialising \(m = 50\) equally spaced inducing inputs \(\boldsymbol{z}\) across our observed data's support. These are depicted below via vertical lines.
z = jnp.linspace(-5.0, 5.0, 50).reshape(-1, 1)
fig, ax = plt.subplots()
ax.vlines(
z,
ymin=y.min(),
ymax=y.max(),
alpha=0.3,
linewidth=1,
label="Inducing point",
color=cols[2],
)
ax.scatter(x, y, alpha=0.2, color=cols[0], label="Observations")
ax.plot(xtest, f(xtest), color=cols[1], label="Latent function")
ax.legend()
ax.set(xlabel=r"$x$", ylabel=r"$f(x)$")
The inducing inputs will summarise our dataset and, since they are treated as variational parameters, their locations will be optimised. The next step in SVGP is to define a variational family.
Defining the variational process
We begin by considering the form of the posterior distribution for all function values \(f(\cdot)\)

$$ p(f(\cdot)|\mathcal{D}) = \int p(f(\cdot)|f(\boldsymbol{x})) p(f(\boldsymbol{x})|\mathcal{D}) \text{d}f(\boldsymbol{x}). \qquad (\dagger) $$

To arrive at an approximation framework, we assume some redundancy in the data. Instead of predicting \(f(\cdot)\) with function values at the datapoints \(f(\boldsymbol{x})\), we assume this can be achieved with only function values at \(m\) inducing inputs \(\boldsymbol{z}\)

$$ p(f(\cdot)|\mathcal{D}) \approx \int p(f(\cdot)|f(\boldsymbol{z})) p(f(\boldsymbol{z})|\mathcal{D}) \text{d}f(\boldsymbol{z}). \qquad (\star) $$
This lower dimensional integral results in computational savings in the model's predictive component from \(p(f(\cdot)|f(\boldsymbol{x}))\) to \(p(f(\cdot)|f(\boldsymbol{z}))\) where inverting \(\mathbf{K}_{\boldsymbol{x}\boldsymbol{x}}\) is replaced with inverting \(\mathbf{K}_{\boldsymbol{z}\boldsymbol{z}}\). However, since we did not observe our data \(\mathcal{D}\) at \(\boldsymbol{z}\) we ask, what exactly is the posterior \(p(f(\boldsymbol{z})|\mathcal{D})\)?
Notice this is simply obtained by substituting \(\boldsymbol{z}\) into \((\dagger)\), but we arrive back at square one with computing the expensive integral. To side-step this, we consider replacing \(p(f(\boldsymbol{z})|\mathcal{D})\) in \((\star)\) with a cheap-to-compute approximate distribution \(q(f(\boldsymbol{z}))\)
$$ q(f(\cdot)) = \int p(f(\cdot)|f(\boldsymbol{z})) q(f(\boldsymbol{z})) \text{d}f(\boldsymbol{z}). \qquad (\times) $$
To measure the quality of the approximation, we consider the Kullback-Leibler divergence \(\operatorname{KL}(\cdot || \cdot)\) from our approximate process \(q(f(\cdot))\) to the true process \(p(f(\cdot)|\mathcal{D})\). By parametrising \(q(f(\boldsymbol{z}))\) over a variational family of distributions, we can minimise the Kullback-Leibler divergence with respect to the variational parameters. Moreover, since the inducing input locations \(\boldsymbol{z}\) augment the model, they themselves can be treated as variational parameters without altering the true underlying model \(p(f(\boldsymbol{z})|\mathcal{D})\). This is exactly what gives SVGPs great flexibility whilst retaining robustness to overfitting.
It is popular to elect a Gaussian variational distribution \(q(f(\boldsymbol{z})) = \mathcal{N}(f(\boldsymbol{z}); \mathbf{m}, \mathbf{S})\) with parameters \(\{\boldsymbol{z}, \mathbf{m}, \mathbf{S}\}\), since conjugacy is provided between \(q(f(\boldsymbol{z}))\) and \(p(f(\cdot)|f(\boldsymbol{z}))\) so that the resulting variational process \(q(f(\cdot))\) is a GP. We can implement this in GPJax by the following.
meanf = gpx.mean_functions.Zero()
likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
kernel = jk.RBF() # 1-dimensional inputs
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
p = prior * likelihood
q = gpx.variational_families.VariationalGaussian(posterior=p, inducing_inputs=z)
Here, the variational process \(q(\cdot)\) depends on the prior through \(p(f(\cdot)|f(\boldsymbol{z}))\) in \((\times)\).
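To make the connection to \((\times)\) explicit, the sketch below is our own illustration rather than GPJax functionality: it computes the marginals of \(q(f(\cdot))\) at the test inputs directly from \(\{\boldsymbol{z}, \mathbf{m}, \mathbf{S}\}\), assuming a zero mean function and a unit RBF kernel, and reusing z and xtest from above. The helper rbf and the initial values of m_var and S are hypothetical placeholders.
# From-scratch marginals of q(f(.)) under a zero-mean prior and a unit RBF kernel.
def rbf(a, b, lengthscale=1.0, variance=1.0):
    # a: (N, 1), b: (M, 1) -> (N, M) kernel matrix for 1-dimensional inputs
    return variance * jnp.exp(-0.5 * (a - b.T) ** 2 / lengthscale**2)

m_var = jnp.zeros((z.shape[0], 1))            # variational mean m (initialised at zero)
S = jnp.eye(z.shape[0])                       # variational covariance S (initialised at identity)

Kzz = rbf(z, z) + 1e-6 * jnp.eye(z.shape[0])  # add jitter for numerical stability
Ktz = rbf(xtest, z)
A = jnp.linalg.solve(Kzz, Ktz.T)              # Kzz^{-1} Kzt

q_mean = A.T @ m_var                               # K_tz Kzz^{-1} m
q_cov = rbf(xtest, xtest) - Ktz @ A + A.T @ S @ A  # K_tt - K_tz Kzz^{-1} K_zt + K_tz Kzz^{-1} S Kzz^{-1} K_zt
Setting \(\mathbf{S} = \mathbf{K}_{\boldsymbol{z}\boldsymbol{z}}\) recovers the prior marginals exactly; optimisation instead moves \(\{\boldsymbol{z}, \mathbf{m}, \mathbf{S}\}\) so that \(q(f(\cdot))\) tracks the posterior.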
Inference
Evidence lower bound
With our model defined, we seek to infer the optimal inducing inputs \(\boldsymbol{z}\), variational mean \(\mathbf{m}\) and covariance \(\mathbf{S}\) that define our approximate posterior. To achieve this, we maximise the evidence lower bound (ELBO) with respect to \(\{\boldsymbol{z}, \mathbf{m}, \mathbf{S} \}\), which is a proxy for minimising the Kullback-Leibler divergence. Moreover, as hinted by its name, the ELBO is a lower bound on the marginal log-likelihood, providing a tractable objective for optimising the model's hyperparameters akin to the conjugate setting. For further details, see Sections 3.1 and 4.1 of the excellent review paper of Leibfried et al. (2020).
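For reference, with the Gaussian variational family above the ELBO takes the standard form of Hensman et al. (2015): a sum of per-datum expected log-likelihoods under the variational marginals, regularised by a KL term on the inducing variables,

$$ \operatorname{ELBO}(\boldsymbol{z}, \mathbf{m}, \mathbf{S}) = \sum_{i=1}^{n} \mathbb{E}_{q(f(x_i))}\left[\log p(y_i | f(x_i))\right] - \operatorname{KL}\left(q(f(\boldsymbol{z}))\,\|\,p(f(\boldsymbol{z}))\right). $$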
Mini-batching
Despite introducing inducing inputs into our model, the ELBO still contains a sum over all \(n\) data points at each optimisation step, which can be costly for large datasets. To circumvent this, optimisation can be done using stochastic mini-batches.
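Concretely, an unbiased estimate of the ELBO (and of its gradients) is obtained by drawing a random mini-batch \(\mathcal{B}\) of size \(b\) and rescaling the data-fit term,

$$ \widehat{\operatorname{ELBO}} = \frac{n}{b} \sum_{i \in \mathcal{B}} \mathbb{E}_{q(f(x_i))}\left[\log p(y_i | f(x_i))\right] - \operatorname{KL}\left(q(f(\boldsymbol{z}))\,\|\,p(f(\boldsymbol{z}))\right), $$

and this is the kind of estimator that mini-batch training via the batch_size argument of gpx.fit below relies on.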
schedule = ox.warmup_cosine_decay_schedule(
init_value=0.0,
peak_value=0.02,
warmup_steps=75,
decay_steps=2000,
end_value=0.001,
)
opt_posterior, history = gpx.fit(
model=q,
# gpx.fit minimises its objective, so we negate the ELBO in order to maximise it
objective=lambda p, d: -gpx.objectives.elbo(p, d),
train_data=D,
optim=ox.adam(learning_rate=schedule),
num_iters=3000,
key=jr.key(42),
batch_size=128,
)
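As an optional sanity check (our own addition, not part of the original workflow), the loss values returned in history, i.e. the negative ELBO at each iteration, can be plotted to confirm that the optimiser has converged.
fig, ax = plt.subplots()
ax.plot(history, color=cols[1])
ax.set(xlabel="Iteration", ylabel="Negative ELBO")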
Predictions
With optimisation complete, we can use our inferred parameter set to make predictions at novel inputs on our variational process object \(q(\cdot)\), just as with all other models within GPJax (for example, see the regression notebook).
latent_dist = opt_posterior(xtest)
predictive_dist = opt_posterior.posterior.likelihood(latent_dist)
meanf = predictive_dist.mean()
sigma = predictive_dist.stddev()
fig, ax = plt.subplots()
ax.scatter(x, y, alpha=0.15, label="Training Data", color=cols[0])
ax.plot(xtest, meanf, label="Posterior mean", color=cols[1])
ax.fill_between(
xtest.flatten(),
meanf - 2 * sigma,
meanf + 2 * sigma,
alpha=0.3,
color=cols[1],
label="Two sigma",
)
ax.vlines(
opt_posterior.inducing_inputs.value,
ymin=y.min(),
ymax=y.max(),
alpha=0.3,
linewidth=1,
label="Inducing point",
color=cols[2],
)
ax.legend()
Custom transformations
To learn the covariance matrix \(\mathbf{S}\), GPJax uses the tfb.FillTriangular transformation by default, which fills a 1-D vector into a lower-triangular matrix. Users can replace this default with another valid transformation of their choice. For example, a Square transformation on the diagonal can also serve the purpose.
params_bijection = gpx.parameters.DEFAULT_BIJECTION.copy()
params_bijection[gpx.parameters.LowerTriangular] = tfb.FillScaleTriL(
diag_bijector=tfb.Square(), diag_shift=jnp.array(q.jitter)
)
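As a quick illustration (our own check; the length-3 input vector is arbitrary), the forward pass of this bijector maps an unconstrained vector into a lower-triangular matrix whose diagonal has been squared and shifted, and hence is non-negative.
# Map an arbitrary unconstrained 3-vector to a valid 2x2 lower-triangular factor.
chol_example = params_bijection[gpx.parameters.LowerTriangular].forward(jnp.array([0.5, -1.0, 2.0]))
print(chol_example)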
opt_rep, history = gpx.fit(
model=q,
objective=lambda p, d: -gpx.objectives.elbo(p, d),
train_data=D,
optim=ox.adam(learning_rate=0.01),
num_iters=3000,
key=jr.key(42),
batch_size=128,
# pass the custom bijection so that it is used when fitting
params_bijection=params_bijection,
)
latent_dist = opt_rep(xtest)
predictive_dist = opt_rep.posterior.likelihood(latent_dist)
meanf = predictive_dist.mean()
sigma = predictive_dist.stddev()
fig, ax = plt.subplots()
ax.scatter(x, y, alpha=0.15, label="Training Data", color=cols[0])
ax.plot(xtest, meanf, label="Posterior mean", color=cols[1])
ax.fill_between(
xtest.flatten(),
meanf - 2 * sigma,
meanf + 2 * sigma,
alpha=0.3,
color=cols[1],
label="Two sigma",
)
ax.vlines(
opt_rep.inducing_inputs.value,
ymin=y.min(),
ymax=y.max(),
alpha=0.3,
linewidth=1,
label="Inducing point",
color=cols[2],
)
ax.legend()
We can see that the Square transformation is able to achieve a relatively better fit than Softplus with the same number of iterations, but Softplus is recommended over Square for the stability of optimisation.
System configuration
Author: Thomas Pinder, Daniel Dodd & Zeel B Patel
Last updated: Tue Nov 26 2024
Python implementation: CPython
Python version : 3.10.15
IPython version : 8.29.0
jaxtyping : 0.2.36
optax : 0.2.4
tensorflow_probability: 0.25.0
jax : 0.4.27
matplotlib : 3.9.2
gpjax : 0.9.3
Watermark: 2.5.0