Regression
In this notebook we demonstrate how to fit a Gaussian process regression model.
# Enable Float64 for more stable matrix inversions.
from jax import config
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook
import matplotlib as mpl
import matplotlib.pyplot as plt
from examples.utils import (
    clean_legend,
    use_mpl_style,
)
config.update("jax_enable_x64", True)
with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
key = jr.key(123)
# set the default style for plotting
use_mpl_style()
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
Dataset
With the necessary modules imported, we simulate a dataset \(\mathcal{D} = (\boldsymbol{x}, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{100}\) with inputs \(\boldsymbol{x}\) sampled uniformly on \((-3, 3)\) and corresponding independent noisy outputs
$$ \boldsymbol{y} \sim \mathcal{N}\left(\sin(4\boldsymbol{x}) + \cos(2\boldsymbol{x}), \mathbf{I} \cdot 0.3^2\right). $$
We store our data \(\mathcal{D}\) as a GPJax Dataset and create test inputs and labels for later.
n = 100
noise = 0.3
key, subkey = jr.split(key)
x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,)).reshape(-1, 1)
f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)
signal = f(x)
y = signal + jr.normal(subkey, shape=signal.shape) * noise
D = gpx.Dataset(X=x, y=y)
xtest = jnp.linspace(-3.5, 3.5, 500).reshape(-1, 1)
ytest = f(xtest)
To better understand what we have simulated, we plot both the underlying latent function and the observed data that is subject to Gaussian noise.
fig, ax = plt.subplots()
ax.plot(x, y, "o", label="Observations", color=cols[0])
ax.plot(xtest, ytest, label="Latent function", color=cols[1])
ax.legend(loc="best")
Our aim in this tutorial will be to reconstruct the latent function from our noisy observations \(\mathcal{D}\) via Gaussian process regression. We begin by defining a Gaussian process prior in the next section.
Defining the prior
A zero-mean Gaussian process (GP) places a prior distribution over real-valued functions \(f(\cdot)\) where \(f(\boldsymbol{x}) \sim \mathcal{N}(0, \mathbf{K}_{\boldsymbol{x}\boldsymbol{x}})\) for any finite collection of inputs \(\boldsymbol{x}\).
Here \(\mathbf{K}_{\boldsymbol{x}\boldsymbol{x}}\) is the Gram matrix generated by a user-specified symmetric, non-negative definite kernel function \(k(\cdot, \cdot')\) with \([\mathbf{K}_{\boldsymbol{x}\boldsymbol{x}}]_{i, j} = k(x_i, x_j)\). The choice of kernel function is critical as, among other things, it governs the smoothness of the outputs that our GP can generate.
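For simplicity, we consider a radial basis function (RBF) kernel with lengthscale \(\ell\) and variance \(\sigma^2\):
$$ k(x, x') = \sigma^2 \exp\left(-\frac{\lVert x - x' \rVert_2^2}{2 \ell^2}\right). $$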
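To make the Gram matrix concrete, here is a minimal NumPy sketch (not GPJax's implementation) that evaluates an RBF kernel on three one-dimensional inputs. The helper name rbf_gram and the inputs x_demo are illustrative assumptions; the lengthscale and variance both default to 1.

```python
import numpy as np


def rbf_gram(x, lengthscale=1.0, variance=1.0):
    """Gram matrix with K[i, j] = variance * exp(-(x_i - x_j)^2 / (2 * lengthscale^2)).

    Hypothetical helper for 1-D inputs only; it is not part of GPJax.
    """
    x = np.asarray(x, dtype=float).reshape(-1, 1)
    sq_dists = (x - x.T) ** 2  # pairwise squared distances, shape (n, n)
    return variance * np.exp(-0.5 * sq_dists / lengthscale**2)


x_demo = np.array([-1.0, 0.0, 2.0])
K_demo = rbf_gram(x_demo)

# The matrix is symmetric with the kernel variance on its diagonal, and
# entries shrink towards zero as the two inputs move further apart.
print(np.round(K_demo, 3))
```

A shorter lengthscale makes the off-diagonal entries decay faster, which corresponds to functions that vary more rapidly.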
On paper, a GP is written as \(f(\cdot) \sim \mathcal{GP}(\textbf{0}, k(\cdot, \cdot'))\). We can replicate this construction in GPJax by defining a Prior with our chosen RBF kernel.
kernel = gpx.kernels.RBF() # 1-dimensional input
meanf = gpx.mean_functions.Zero()
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
The above construction forms the foundation for GPJax's models. Moreover, the GP prior we have just defined can be represented by a TensorFlow Probability multivariate Gaussian distribution. Such functionality enables straightforward sampling and evaluation of the GP's mean and covariance.
prior_dist = prior.predict(xtest)
prior_mean = prior_dist.mean()
prior_std = prior_dist.stddev()  # standard deviation, so the band below is mean ± one std
samples = prior_dist.sample(seed=key, sample_shape=(20,))
fig, ax = plt.subplots()
ax.plot(xtest, samples.T, alpha=0.5, color=cols[0], label="Prior samples")
ax.plot(xtest, prior_mean, color=cols[1], label="Prior mean")
ax.fill_between(
    xtest.flatten(),
    prior_mean - prior_std,
    prior_mean + prior_std,
    alpha=0.3,
    color=cols[1],
    label="Prior standard deviation",
)
ax.legend(loc="best")
ax = clean_legend(ax)
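Sampling from a GP prior at a finite set of inputs amounts to drawing from a multivariate Gaussian whose covariance is the Gram matrix. As a rough illustration, assuming a zero mean and an RBF kernel with unit hyperparameters, this can be sketched in plain NumPy with a Cholesky factor. The names rbf_gram, x_grid and the jitter value are hypothetical choices for this sketch and say nothing about how GPJax draws its samples internally.

```python
import numpy as np


def rbf_gram(x, lengthscale=1.0, variance=1.0):
    """RBF Gram matrix for 1-D inputs (hypothetical helper, not GPJax)."""
    x = np.asarray(x, dtype=float).reshape(-1, 1)
    return variance * np.exp(-0.5 * (x - x.T) ** 2 / lengthscale**2)


rng = np.random.default_rng(0)
x_grid = np.linspace(-3.5, 3.5, 200)

# A small jitter on the diagonal keeps the Cholesky factorisation stable,
# since RBF Gram matrices on dense grids are nearly singular.
jitter = 1e-6
K = rbf_gram(x_grid) + jitter * np.eye(x_grid.size)
L = np.linalg.cholesky(K)

# If z ~ N(0, I) then L @ z ~ N(0, K); each column of z gives one sample path.
z = rng.standard_normal((x_grid.size, 20))
prior_draws = (L @ z).T  # shape (20, 200): 20 functions evaluated on the grid
```

Each row of prior_draws plays the same role as one of the sample paths plotted above.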
Constructing the posterior
Having defined our GP, we proceed to define a description of our data \(\mathcal{D}\) conditional on our knowledge of \(f(\cdot)\); this is exactly the notion of a likelihood function \(p(\mathcal{D} | f(\cdot))\). While the choice of likelihood is critical in Bayesian modelling, for simplicity we consider a Gaussian with noise parameter \(\alpha\)
$$ p(\mathcal{D} | f(\cdot)) = \mathcal{N}(\boldsymbol{y}; f(\boldsymbol{x}), \textbf{I} \alpha^2). $$
This is defined in GPJax by instantiating a Gaussian likelihood.
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
The posterior is proportional to the prior multiplied by the likelihood, written as
$$ p(f(\cdot) | \mathcal{D}) \propto p(f(\cdot)) * p(\mathcal{D} | f(\cdot)). $$
Mimicking this construct, the posterior is established in GPJax through the * operator.
posterior = prior * likelihood
Parameter state
As outlined in the PyTrees documentation, parameters are contained within the model and form the leaves of the PyTree. Consequently, in this particular model, we have three parameters: the kernel lengthscale, kernel variance and the observation noise variance. Whilst we have initialised each of these to 1, we can learn Type 2 MLEs for each of these parameters by optimising the marginal log-likelihood (MLL). Before optimising, we evaluate the negative MLL at the initial parameter values.
print(-gpx.objectives.conjugate_mll(posterior, D))
124.80517341248417
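For a Gaussian likelihood the marginal log-likelihood has a closed form. Writing \(\mathbf{K}_y = \mathbf{K}_{\boldsymbol{x}\boldsymbol{x}} + \alpha^2 \mathbf{I}\), the negative MLL is \(\tfrac{1}{2} \boldsymbol{y}^\top \mathbf{K}_y^{-1} \boldsymbol{y} + \tfrac{1}{2} \log\lvert\mathbf{K}_y\rvert + \tfrac{n}{2} \log 2\pi\). The sketch below evaluates this expression in NumPy, assuming a zero mean and an RBF kernel. The function names and the toy data are illustrative assumptions, and because the block draws its own data its value will not reproduce the numbers printed in this notebook.

```python
import numpy as np


def rbf_gram(x, lengthscale=1.0, variance=1.0):
    """RBF Gram matrix for 1-D inputs (hypothetical helper, not GPJax)."""
    x = np.asarray(x, dtype=float).reshape(-1, 1)
    return variance * np.exp(-0.5 * (x - x.T) ** 2 / lengthscale**2)


def negative_mll(x, y, lengthscale, variance, noise_std):
    """Negative marginal log-likelihood of a zero-mean GP with Gaussian noise."""
    n = len(x)
    K_y = rbf_gram(x, lengthscale, variance) + noise_std**2 * np.eye(n)
    L = np.linalg.cholesky(K_y)
    # Solve K_y @ weights = y using the Cholesky factor instead of an explicit inverse.
    weights = np.linalg.solve(L.T, np.linalg.solve(L, y))
    data_fit = 0.5 * y @ weights
    complexity = np.sum(np.log(np.diag(L)))  # equals 0.5 * log|K_y|
    return data_fit + complexity + 0.5 * n * np.log(2.0 * np.pi)


# Toy data drawn in the same spirit as the dataset in this notebook.
rng = np.random.default_rng(0)
x_toy = rng.uniform(-3.0, 3.0, size=30)
y_toy = np.sin(4 * x_toy) + np.cos(2 * x_toy) + 0.3 * rng.standard_normal(30)

# All three hyperparameters set to 1, mirroring the initialisation described above.
nmll_init = negative_mll(x_toy, y_toy, lengthscale=1.0, variance=1.0, noise_std=1.0)
print(nmll_init)
```

The first term rewards fitting the data while the log-determinant penalises overly flexible models; minimising their sum over the hyperparameters is what yields the Type 2 MLEs.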
We can now define an optimiser. For this example we'll use the bfgs optimiser.
opt_posterior, history = gpx.fit_scipy(
    model=posterior,
    # we use the negative mll as we are minimising
    objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
    train_data=D,
)
print(-gpx.objectives.conjugate_mll(opt_posterior, D))
Optimization terminated successfully.
Current function value: 55.074053
Iterations: 16
Function evaluations: 19
Gradient evaluations: 19
55.07405321945008
Prediction
Equipped with the posterior and a set of optimised hyperparameter values, we are now in a position to query our GP's predictive distribution at novel test inputs. To do this, we use our defined posterior and likelihood at our test inputs to obtain the predictive distribution as a multivariate Gaussian of the same type as the prior distribution above, upon which mean and stddev can be used to extract the predictive mean and standard deviation.
latent_dist = opt_posterior.predict(xtest, train_data=D)
predictive_dist = opt_posterior.likelihood(latent_dist)
predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()
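For intuition about what these two calls return, recall the standard closed-form result for a zero-mean GP with Gaussian noise: the latent mean at test inputs is \(\mathbf{K}_{*\boldsymbol{x}} \mathbf{K}_y^{-1} \boldsymbol{y}\) and the latent covariance is \(\mathbf{K}_{**} - \mathbf{K}_{*\boldsymbol{x}} \mathbf{K}_y^{-1} \mathbf{K}_{\boldsymbol{x}*}\), with \(\mathbf{K}_y = \mathbf{K}_{\boldsymbol{x}\boldsymbol{x}} + \alpha^2 \mathbf{I}\); passing through the likelihood adds \(\alpha^2\) to the variance. The following NumPy sketch implements these equations under the assumption of an RBF kernel. The names rbf_cross and gp_predict and the hyperparameter values are hypothetical and do not reflect GPJax internals or the optimised values found above.

```python
import numpy as np


def rbf_cross(a, b, lengthscale=1.0, variance=1.0):
    """RBF kernel between two sets of 1-D inputs, shape (len(a), len(b))."""
    a = np.asarray(a, dtype=float).reshape(-1, 1)
    b = np.asarray(b, dtype=float).reshape(1, -1)
    return variance * np.exp(-0.5 * (a - b) ** 2 / lengthscale**2)


def gp_predict(x, y, x_new, lengthscale, variance, noise_std):
    """Predictive mean and standard deviation of noisy outputs at x_new."""
    K_y = rbf_cross(x, x, lengthscale, variance) + noise_std**2 * np.eye(len(x))
    K_star = rbf_cross(x_new, x, lengthscale, variance)
    L = np.linalg.cholesky(K_y)
    weights = np.linalg.solve(L.T, np.linalg.solve(L, y))
    mean = K_star @ weights
    # The diagonal of K_** for an RBF kernel is the kernel variance.
    v = np.linalg.solve(L, K_star.T)
    latent_var = np.maximum(variance - np.sum(v**2, axis=0), 0.0)
    # Adding the noise variance turns the latent variance into the predictive one.
    return mean, np.sqrt(latent_var + noise_std**2)


rng = np.random.default_rng(0)
x_toy = rng.uniform(-3.0, 3.0, size=30)
y_toy = np.sin(4 * x_toy) + np.cos(2 * x_toy) + 0.3 * rng.standard_normal(30)
x_new = np.linspace(-3.5, 3.5, 100)

# Hand-picked hyperparameters for illustration only.
mean_toy, std_toy = gp_predict(x_toy, y_toy, x_new, lengthscale=0.4, variance=1.0, noise_std=0.3)
```

The predictive standard deviation never falls below the noise level, and it grows towards the prior level as the test inputs move away from the training data.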
With the predictions and their uncertainty acquired, we illustrate the GP's performance at explaining the data \(\mathcal{D}\) and recovering the underlying latent function of interest.
fig, ax = plt.subplots(figsize=(7.5, 2.5))
ax.plot(x, y, "x", label="Observations", color=cols[0], alpha=0.5)
ax.fill_between(
    xtest.squeeze(),
    predictive_mean - 2 * predictive_std,
    predictive_mean + 2 * predictive_std,
    alpha=0.2,
    label="Two sigma",
    color=cols[1],
)
ax.plot(
    xtest,
    predictive_mean - 2 * predictive_std,
    linestyle="--",
    linewidth=1,
    color=cols[1],
)
ax.plot(
    xtest,
    predictive_mean + 2 * predictive_std,
    linestyle="--",
    linewidth=1,
    color=cols[1],
)
ax.plot(
    xtest, ytest, label="Latent function", color=cols[0], linestyle="--", linewidth=2
)
ax.plot(xtest, predictive_mean, label="Predictive mean", color=cols[1])
ax.legend(loc="center left", bbox_to_anchor=(0.975, 0.5))
System configuration
Author: Thomas Pinder & Daniel Dodd
Last updated: Tue Nov 26 2024
Python implementation: CPython
Python version : 3.10.15
IPython version : 8.29.0
jaxtyping : 0.2.36
gpjax : 0.9.3
matplotlib: 3.9.2
jax : 0.4.27
Watermark: 2.5.0