Sparse Gaussian Process Regression¶
In this notebook we consider sparse Gaussian process regression (SGPR) of Titsias (2009). This provides a tractable solution for medium- to large-scale conjugate regression problems. To arrive at a computationally tractable method, the approximate posterior is parameterised via a set of $m$ pseudo-points (inducing inputs) $\boldsymbol{z}$. Critically, the approach reduces the cost to $\mathcal{O}(nm^2)$ for approximate maximum-likelihood learning and $\mathcal{O}(m^2)$ per test point for prediction.
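For reference, and assuming a zero-mean prior with kernel matrices $\mathbf{K}_{nn}$, $\mathbf{K}_{nm}$ and $\mathbf{K}_{mm}$ evaluated at the training and inducing inputs respectively, the collapsed bound of Titsias (2009) on the log marginal likelihood takes the form
$$\log p(\boldsymbol{y}) \geq \log \mathcal{N}\left(\boldsymbol{y} \,\middle|\, \boldsymbol{0}, \mathbf{Q}_{nn} + \sigma^2 \mathbf{I}\right) - \frac{1}{2\sigma^2}\operatorname{tr}\left(\mathbf{K}_{nn} - \mathbf{Q}_{nn}\right), \qquad \mathbf{Q}_{nn} = \mathbf{K}_{nm}\mathbf{K}_{mm}^{-1}\mathbf{K}_{mn},$$
where $\sigma^2$ is the observation noise variance. Evaluating this bound only requires factorising the $m \times m$ matrix $\mathbf{K}_{mm}$, which is the source of the $\mathcal{O}(nm^2)$ cost quoted above.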
# Enable Float64 for more stable matrix inversions.
from jax import config
config.update("jax_enable_x64", True)
from jax import jit
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox
from docs.examples.utils import clean_legend
with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
key = jr.PRNGKey(123)
plt.style.use(
"https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
Dataset¶
With the necessary modules imported, we simulate a dataset $\mathcal{D} = (\boldsymbol{x}, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{2500}$ with inputs $\boldsymbol{x}$ sampled uniformly on $(-3, 3)$ and corresponding independent noisy outputs
$$\boldsymbol{y} \sim \mathcal{N}\left(\sin(2\boldsymbol{x}) + \boldsymbol{x}\cos(5\boldsymbol{x}), \; 0.5^2 \mathbf{I}\right).$$
We store our data $\mathcal{D}$ as a GPJax Dataset and create test inputs and labels for later.
n = 2500
noise = 0.5
key, subkey = jr.split(key)
x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,)).reshape(-1, 1)
f = lambda x: jnp.sin(2 * x) + x * jnp.cos(5 * x)
signal = f(x)
y = signal + jr.normal(subkey, shape=signal.shape) * noise
D = gpx.Dataset(X=x, y=y)
xtest = jnp.linspace(-3.1, 3.1, 500).reshape(-1, 1)
ytest = f(xtest)
To better understand what we have simulated, we plot both the underlying latent function and the observed data that is subject to Gaussian noise. We also plot an initial set of inducing points over the space.
n_inducing = 50
z = jnp.linspace(-3.0, 3.0, n_inducing).reshape(-1, 1)
fig, ax = plt.subplots()
ax.scatter(x, y, alpha=0.25, label="Observations", color=cols[0])
ax.plot(xtest, ytest, label="Latent function", linewidth=2, color=cols[1])
ax.vlines(
x=z,
ymin=y.min(),
ymax=y.max(),
alpha=0.3,
linewidth=0.5,
label="Inducing point",
color=cols[2],
)
ax.legend(loc="best")
plt.show()
Next we define the true posterior model for the data. Note that, whilst we can construct this object, evaluating it exactly scales cubically in the number of datapoints and therefore quickly becomes costly.
meanf = gpx.mean_functions.Constant()
kernel = gpx.kernels.RBF()
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
posterior = prior * likelihood
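As a point of reference, the exact posterior defined above could be queried directly at $\mathcal{O}(n^3)$ cost; a minimal sketch, assuming the conjugate posterior exposes the same predict method used in GPJax's regression example:
# Exact GP predictive distribution (O(n^3)); shown for comparison only.
exact_latent_dist = posterior.predict(xtest, train_data=D)
exact_predictive_dist = posterior.likelihood(exact_latent_dist)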
We now define the SGPR model through CollapsedVariationalGaussian. Through a set of inducing points $\boldsymbol{z}$, this object builds an approximation to the true posterior distribution. Accordingly, we pass the true posterior and the initial inducing points into the constructor as arguments.
q = gpx.variational_families.CollapsedVariationalGaussian(
posterior=posterior, inducing_inputs=z
)
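The inducing inputs are stored on the variational family (as the inducing_inputs attribute used again after optimisation below) and are treated as trainable parameters; as a quick check of their shape:
# The m = 50 initial inducing locations, with shape (n_inducing, 1).
print(q.inducing_inputs.shape)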
We define our optimisation objective through CollapsedELBO. This implements the collapsed variational free energy bound considered in Titsias (2009).
elbo = gpx.objectives.CollapsedELBO(negative=True)
For researchers, GPJax can print the BibTeX citation for objects such as the ELBO through the cite() function.
print(gpx.cite(elbo))
@inproceedings{titsias2009variational,
  authors = {Titsias, Michalis},
  title = {Variational learning of inducing variables in sparse Gaussian processes},
  year = {2009},
  booktitle = {International Conference on Artificial Intelligence and Statistics},
}
JIT-compiling expensive-to-compute functions such as the ELBO is
advisable. This can be achieved by wrapping the function in jax.jit()
.
elbo = jit(elbo)
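It can be useful to evaluate the compiled objective once at the initial parameters, both as a sanity check and to trigger compilation ahead of training; a minimal sketch, assuming the objective can be called directly on the variational family and dataset (as it is, via its step method, in the runtime comparison at the end of this notebook):
# Negative collapsed ELBO at the initial parameters.
print(elbo(q, D))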
We now train our model, akin to a standard Gaussian process regression model, via the fit abstraction. Unlike the regression example given in the conjugate regression notebook, the inducing locations that parameterise our variational posterior distribution are now part of the model's parameters. Using a gradient-based optimiser, we can therefore optimise their locations such that the evidence lower bound is maximised.
opt_posterior, history = gpx.fit(
model=q,
objective=elbo,
train_data=D,
optim=ox.adamw(learning_rate=1e-2),
num_iters=500,
key=key,
)
fig, ax = plt.subplots()
ax.plot(history, color=cols[1])
ax.set(xlabel="Training iterate", ylabel="ELBO")
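With training complete, it can also be instructive to inspect a few of the optimised hyperparameters; a brief sketch, assuming the kernel is reachable through the wrapped posterior in the same way as in the prediction code below:
# Optimised RBF kernel hyperparameters of the wrapped exact posterior.
print(opt_posterior.posterior.prior.kernel.lengthscale)
print(opt_posterior.posterior.prior.kernel.variance)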
We show predictions of our model with the learned inducing points overlaid in grey.
latent_dist = opt_posterior(xtest, train_data=D)
predictive_dist = opt_posterior.posterior.likelihood(latent_dist)
inducing_points = opt_posterior.inducing_inputs
samples = latent_dist.sample(seed=key, sample_shape=(20,))
predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()
fig, ax = plt.subplots()
ax.plot(x, y, "x", label="Observations", color=cols[0], alpha=0.1)
ax.plot(
xtest,
ytest,
label="Latent function",
color=cols[1],
linestyle="-",
linewidth=1,
)
ax.plot(xtest, predictive_mean, label="Predictive mean", color=cols[1])
ax.fill_between(
xtest.squeeze(),
predictive_mean - 2 * predictive_std,
predictive_mean + 2 * predictive_std,
alpha=0.2,
color=cols[1],
label="Two sigma",
)
ax.plot(
xtest,
predictive_mean - 2 * predictive_std,
color=cols[1],
linestyle="--",
linewidth=0.5,
)
ax.plot(
xtest,
predictive_mean + 2 * predictive_std,
color=cols[1],
linestyle="--",
linewidth=0.5,
)
ax.vlines(
x=inducing_points,
ymin=ytest.min(),
ymax=ytest.max(),
alpha=0.3,
linewidth=0.5,
label="Inducing point",
color=cols[2],
)
ax.legend()
ax.set(xlabel=r"$x$", ylabel=r"$f(x)$")
plt.show()
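One simple way to quantify the fit shown above is the root-mean-square error of the predictive mean against the noise-free test labels; a rough sketch:
# RMSE of the predictive mean against the noise-free latent function values.
rmse = jnp.sqrt(jnp.mean((predictive_mean - ytest.squeeze()) ** 2))
print(rmse)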
Runtime comparison¶
Given the size of the dataset considered here, inference in a GP with a full-rank covariance matrix is still possible, albeit slow. We can therefore compare the cost of evaluating the sparse approximation's collapsed bound on the marginal log-likelihood against the cost of evaluating the exact marginal log-likelihood of the full model.
full_rank_model = gpx.gps.Prior(
mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
) * gpx.likelihoods.Gaussian(num_datapoints=D.n)
negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True).step)
%timeit negative_mll(full_rank_model, D).block_until_ready()
528 ms ± 9.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
negative_elbo = jit(gpx.objectives.CollapsedELBO(negative=True).step)
%timeit negative_elbo(q, D).block_until_ready()
1.79 ms ± 162 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
As we can see from the timings above, evaluating the sparse bound is almost 300 times faster than evaluating the exact marginal log-likelihood of the full-rank model on this example.
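As a rough back-of-the-envelope check, the exact marginal log-likelihood costs $\mathcal{O}(n^3)$ to evaluate whereas the collapsed bound costs $\mathcal{O}(nm^2)$, so the asymptotic ratio is
$$\frac{n^3}{n m^2} = \left(\frac{n}{m}\right)^2 = \left(\frac{2500}{50}\right)^2 = 2500.$$
The observed speedup is smaller than this because constant factors and fixed overheads dominate at moderate $n$, but the gap widens as $n$ grows for a fixed number of inducing points.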
System configuration¶
%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Daniel Dodd'
Author: Daniel Dodd
Last updated: Sun Dec 03 2023
Python implementation: CPython
Python version: 3.10.13
IPython version: 8.17.2
matplotlib: 3.8.1
jax: 0.4.20
gpjax: 0.8.0
optax: 0.1.7
Watermark: 2.4.3