Introduction to Kernels¶
In this guide we provide an introduction to kernels, and the role they play in Gaussian process models.
# Enable Float64 for more stable matrix inversions.
from jax import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook, Float
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox
import pandas as pd
from docs.examples.utils import clean_legend
with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
    from gpjax.typing import Array
from sklearn.preprocessing import StandardScaler
key = jr.PRNGKey(42)
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
Using Gaussian Processes (GPs) to model functions can offer several advantages over alternative methods, such as deep neural networks. One key advantage is their rich quantification of uncertainty; not only do they provide point estimates for the values taken by a function throughout its domain, but they provide a full predictive posterior distribution over the range of values the function may take. This rich quantification of uncertainty is useful in many applications, such as Bayesian optimisation, which relies on being able to make uncertainty-aware decisions.
Another advantage of GPs is the ability to place priors over the functions being modelled. For instance, one may know that the underlying function exhibits certain characteristics, such as being periodic or having a certain level of smoothness. The kernel, or covariance function, is the primary means of encoding such prior knowledge about the function being modelled. This enables one to equip the GP with inductive biases that allow it to learn from data more efficiently, whilst generalising to unseen data more effectively.
In this notebook we'll develop some intuition for what kinds of priors are encoded through the use of different kernels, and how this can be useful when modelling different types of functions.
What is a Kernel?¶
Intuitively, for a function $f$, the kernel defines the notion of similarity between the value of the function at two points, $f(\mathbf{x})$ and $f(\mathbf{x}')$, and will be denoted as $k(\mathbf{x}, \mathbf{x}')$:
$$\begin{aligned} k(\mathbf{x}, \mathbf{x}') &= \text{Cov}[f(\mathbf{x}), f(\mathbf{x}')] \\ &= \mathbb{E}[(f(\mathbf{x}) - \mathbb{E}[f(\mathbf{x})])(f(\mathbf{x}') - \mathbb{E}[f(\mathbf{x}')])] \end{aligned}$$
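To make this notion of similarity concrete, the short sketch below evaluates a kernel at pairs of points that lie progressively further apart, showing how the covariance decays with distance. It assumes that GPJax kernels can be evaluated pointwise via kernel(x, y), and that the RBF kernel's default hyperparameters are a unit lengthscale and unit variance.
# A minimal sketch of pointwise kernel evaluation; the call signature
# kernel(x, y) for single D-dimensional inputs is assumed.
kernel = gpx.kernels.RBF()  # default lengthscale and variance assumed to be 1.0
x1 = jnp.array([0.0])
for distance in [0.0, 0.5, 1.0, 2.0, 3.0]:
    x2 = jnp.array([distance])
    print(f"k(0.0, {distance:.1f}) = {float(kernel(x1, x2)):.4f}")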
One would expect that, given a previously unobserved test point $\mathbf{x}^*$, the training points which are closest to this unobserved point will be most similar to it. As such, the kernel is used to define this notion of similarity within the GP framework. It is up to the user to select a kernel function which is appropriate for the function being modelled. In this notebook we are going to give some examples of commonly used kernels, and try to develop an understanding of when one may wish to use one kernel over another. However, before we do this, it is worth discussing the necessary conditions for a function to be a valid kernel/covariance function. This requires a little bit of maths, so for those of you who just wish to obtain an intuitive understanding, feel free to skip to the section introducing the Matérn family of kernels.
What are the necessary conditions for a function to be a valid kernel?¶
Whilst intuitively the kernel function is used to define the notion of similarity within the GP framework, it is important to note that there are two necessary conditions that a kernel function must satisfy in order to be a valid covariance function. For clarity, we will refer to any function mapping two inputs to a scalar output as a kernel function, and we will refer to a valid kernel function satisfying the two necessary conditions as a covariance function. However, it is worth noting that the GP community often uses the terms kernel function and covariance function interchangeably.
The first necessary condition is that the covariance function must be symmetric, i.e. $k(\mathbf{x}, \mathbf{x}') = k(\mathbf{x}', \mathbf{x})$. This is because the covariance between two random variables $X$ and $X'$ is symmetric; if one looks at the definition of covariance given above, it is clear that it is invariant to swapping the order of the inputs $\mathbf{x}$ and $\mathbf{x}'$.
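This symmetry is easy to sanity-check numerically, again assuming the pointwise kernel(x, y) evaluation used above.
# Symmetry sanity check: k(x, x') should equal k(x', x).
xa = jnp.array([0.3])
xb = jnp.array([-1.2])
kernel = gpx.kernels.Matern32()
print(float(kernel(xa, xb)), float(kernel(xb, xa)))  # the two values should agree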
The second necessary condition is that the covariance function must be positive semi-definite (PSD). In order to understand this condition, it is useful to first introduce the concept of a Gram matrix. We'll use the same notation as the GP introduction notebook, and denote $n$ input points as $\mathbf{X} = \{\mathbf{x}_1, \ldots, \mathbf{x}_n\}$. Given these input points and a kernel function $k$ the Gram matrix stores the pairwise kernel evaluations between all input points. Mathematically, this leads to the Gram matrix being defined as:
$$K(\mathbf{X}, \mathbf{X}) = \begin{bmatrix} k(\mathbf{x}_1, \mathbf{x}_1) & \cdots & k(\mathbf{x}_1, \mathbf{x}_n) \\ \vdots & \ddots & \vdots \\ k(\mathbf{x}_n, \mathbf{x}_1) & \cdots & k(\mathbf{x}_n, \mathbf{x}_n) \end{bmatrix}$$
such that $K(\mathbf{X}, \mathbf{X})_{ij} = k(\mathbf{x}_i, \mathbf{x}_j)$.
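The sketch below constructs a Gram matrix explicitly by evaluating the kernel at every pair of inputs with jax.vmap. In practice one would rely on GPJax's own covariance machinery rather than building the matrix by hand, but the construction makes the definition concrete; the gram_matrix helper is illustrative only, and the pointwise kernel(x, y) call is again assumed.
# Build the Gram matrix K(X, X) by evaluating k at every pair of inputs.
from jax import vmap

def gram_matrix(kernel, X):
    # X has shape (n, d); the result has shape (n, n) with entries k(x_i, x_j).
    return vmap(lambda xi: vmap(lambda xj: kernel(xi, xj))(X))(X)

X = jnp.linspace(-3.0, 3.0, num=5).reshape(-1, 1)
K = gram_matrix(gpx.kernels.RBF(), X)
print(K.shape)  # (5, 5)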
In order for $k$ to be a valid covariance function, the corresponding Gram matrix must be positive semi-definite for any choice of input points $\mathbf{X}$. In this case the Gram matrix is referred to as a covariance matrix. A real $n \times n$ matrix $K$ is positive semi-definite if and only if, for all vectors $\mathbf{z} \in \mathbb{R}^n$:
$$\mathbf{z}^\top K \mathbf{z} \geq 0$$
Alternatively, a real $n \times n$ matrix $K$ is positive semi-definite if and only if all of its eigenvalues are non-negative.
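Both characterisations can be checked numerically for a given Gram matrix. The sketch below reuses the illustrative gram_matrix helper defined above, and verifies that the eigenvalues and a quadratic form with a random vector are non-negative (up to floating-point round-off).
# Numerical check of positive semi-definiteness for a Gram matrix.
X = jr.uniform(key, shape=(20, 1), minval=-3.0, maxval=3.0)
K = gram_matrix(gpx.kernels.Matern52(), X)

eigenvalues = jnp.linalg.eigvalsh(K)  # symmetric eigensolver
print("Smallest eigenvalue:", eigenvalues.min())  # non-negative, up to round-off

z = jr.normal(key, shape=(20,))
print("Quadratic form z^T K z:", z @ K @ z)  # should also be non-negative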
Therefore, the two necessary conditions for a function to be a valid covariance function are that it must be symmetric and positive semi-definite. In this section we have referred to any function from two inputs to a scalar output as a kernel function, with its corresponding matrix of pairwise evaluations referred to as the Gram matrix, and a function satisfying the two necessary conditions as a covariance function, with its corresponding matrix of pairwise evaluations referred to as the covariance matrix. This enabled us to easily define the necessary conditions for a function to be a valid covariance function. However, as noted previously, the GP community often uses these terms interchangeably, and so we will for the remainder of this notebook.
Introducing a Common Family of Kernels - The Matérn Family¶
One of the most widely used families of kernels is the Matérn family (Matérn, 1960). These kernels take on the following form:
$$k_{\nu}(\mathbf{x}, \mathbf{x'}) = \sigma^2 \frac{2^{1 - \nu}}{\Gamma(\nu)}\left(\sqrt{2\nu} \frac{|\mathbf{x} - \mathbf{x'}|}{\kappa}\right)^{\nu} K_{\nu} \left(\sqrt{2\nu} \frac{|\mathbf{x} - \mathbf{x'}|}{\kappa}\right)$$
where $K_{\nu}$ is a modified Bessel function of the second kind, $\nu$, $\kappa$ and $\sigma^2$ are hyperparameters specifying the mean-square differentiability, lengthscale and variance respectively, and $|\cdot|$ is used to denote the Euclidean norm. Note that for those of you less interested in the mathematical underpinnings of kernels, it isn't necessary to understand the exact functional form of the Matérn kernels to gain an understanding of how they behave. The key takeaway is that they are parameterised by several hyperparameters, and that these hyperparameters dictate the behaviour of functions sampled from the corresponding GP. The plots below will provide some more intuition for how these hyperparameters affect the behaviour of functions sampled from the corresponding GP.
Some commonly used Matérn kernels use half-integer values of $\nu$, such as $\nu = 1/2$ or $\nu = 3/2$. The fraction is sometimes omitted when naming the kernel, so that $\nu = 1/2$ is referred to as the Matérn12 kernel, and $\nu = 3/2$ is referred to as the Matérn32 kernel. When $\nu$ takes a half-integer value, $\nu = k + 1/2$, the kernel can be expressed as the product of a polynomial of order $k$ and an exponential:
$$k_{k + 1/2}(\mathbf{x}, \mathbf{x'}) = \sigma^2 \exp\left(-\frac{\sqrt{2\nu}|\mathbf{x} - \mathbf{x'}|}{\kappa}\right) \frac{\Gamma(k+1)}{\Gamma(2k+1)} \times \sum_{i= 0}^k \frac{(k+i)!}{i!(k-i)!} \left(\frac{(\sqrt{8\nu}|\mathbf{x} - \mathbf{x'}|)}{\kappa}\right)^{k-i}$$
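To make this expression concrete, here is a small sketch of it in JAX. The helper name matern_half_integer and its arguments are illustrative, not part of GPJax; for $k = 1$ ($\nu = 3/2$) it should reduce to the familiar closed form $\sigma^2 (1 + \sqrt{3}|\mathbf{x} - \mathbf{x}'|/\kappa)\exp(-\sqrt{3}|\mathbf{x} - \mathbf{x}'|/\kappa)$, which the final two lines check numerically.
# A sketch of the half-integer Matérn formula above; names are illustrative only.
from math import factorial

def matern_half_integer(r, k, lengthscale=1.0, variance=1.0):
    # r is the Euclidean distance |x - x'|, and nu = k + 1/2.
    nu = k + 0.5
    polynomial = sum(
        factorial(k + i) / (factorial(i) * factorial(k - i))
        * (jnp.sqrt(8.0 * nu) * r / lengthscale) ** (k - i)
        for i in range(k + 1)
    )
    return (
        variance
        * jnp.exp(-jnp.sqrt(2.0 * nu) * r / lengthscale)
        * factorial(k) / factorial(2 * k)
        * polynomial
    )

r = jnp.array(0.7)
print(float(matern_half_integer(r, k=1)))  # nu = 3/2
print(float((1.0 + jnp.sqrt(3.0) * r) * jnp.exp(-jnp.sqrt(3.0) * r)))  # closed form, should match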
In the limit of $\nu \to \infty$ this yields the squared-exponential, or radial basis function (RBF), kernel, which is infinitely mean-square differentiable:
$$k_{\infty}(\mathbf{x}, \mathbf{x'}) = \sigma^2 \exp\left(-\frac{|\mathbf{x} - \mathbf{x'}|^2}{2\kappa^2}\right)$$
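The squared-exponential form can be written down directly. The comparison below checks the expression against GPJax's RBF kernel, assuming its default hyperparameters are a unit lengthscale and unit variance.
# Direct implementation of the squared-exponential (RBF) kernel above.
def rbf(x, y, lengthscale=1.0, variance=1.0):
    return variance * jnp.exp(-jnp.sum((x - y) ** 2) / (2.0 * lengthscale**2))

xa, xb = jnp.array([0.25]), jnp.array([1.5])
print(float(rbf(xa, xb)))
print(float(gpx.kernels.RBF()(xa, xb)))  # assumed default hyperparameters; should agree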
But what kind of functions does this kernel encode prior knowledge about? Let's take a look at some samples from GP priors defined using Matérn kernels with different values of $\nu$:
kernels = [
gpx.kernels.Matern12(),
gpx.kernels.Matern32(),
gpx.kernels.Matern52(),
gpx.kernels.RBF(),
]
fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(7, 6), tight_layout=True)
x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1)
meanf = gpx.mean_functions.Zero()
for k, ax in zip(kernels, axes.ravel()):
    prior = gpx.gps.Prior(mean_function=meanf, kernel=k)
    rv = prior(x)
    y = rv.sample(seed=key, sample_shape=(10,))
    ax.plot(x, y.T, alpha=0.7)
    ax.set_title(k.name)