Kernel Guide¶
In this guide, we introduce the kernels available in GPJax and demonstrate how to create custom kernels.
In [1]:
Copied!
# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)

from dataclasses import dataclass
from typing import Dict

from jax import jit
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
    Array,
    Float,
    install_import_hook,
)
import matplotlib.pyplot as plt
import numpy as np
from simple_pytree import static_field
import tensorflow_probability.substrates.jax as tfp

# Import GPJax under the beartype import hook so all GPJax calls are
# runtime type-checked against their jaxtyping annotations.
with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
    from gpjax.base.param import param_field

# Fixed PRNG key so every sample drawn in this guide is reproducible.
key = jr.PRNGKey(123)
tfb = tfp.bijectors

# Shared plotting style for all GPJax documentation examples.
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = plt.rcParams["axes.prop_cycle"].by_key()["color"]
# NOTE(review): this cell is a byte-for-byte duplicate of the cell above —
# an artifact of the docs export ("Copied!" widget). Safe to delete; kept
# here unchanged so the document structure is preserved.
# Enable Float64 for more stable matrix inversions.
from jax import config
config.update("jax_enable_x64", True)
from dataclasses import dataclass
from typing import Dict
from jax import jit
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
Array,
Float,
install_import_hook,
)
import matplotlib.pyplot as plt
import numpy as np
from simple_pytree import static_field
import tensorflow_probability.substrates.jax as tfp
with install_import_hook("gpjax", "beartype.beartype"):
import gpjax as gpx
from gpjax.base.param import param_field
key = jr.PRNGKey(123)
tfb = tfp.bijectors
plt.style.use(
"https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = plt.rcParams["axes.prop_cycle"].by_key()["color"]
Supported Kernels¶
The following kernels are natively supported in GPJax.
- Matérn 1/2, 3/2 and 5/2.
- RBF (or squared exponential).
- Rational quadratic.
- Powered exponential.
- Polynomial.
- White noise.
- Linear.
- Graph kernels.
While the syntax is consistent, each kernel's type influences the characteristics of the sample paths drawn. We visualise this below with 10 function draws per kernel.
In [2]:
Copied!
# One instance of each kernel we want to visualise. The two Polynomial
# entries differ only in degree (default vs. degree=2).
kernels = [
    gpx.kernels.Matern12(),
    gpx.kernels.Matern32(),
    gpx.kernels.Matern52(),
    gpx.kernels.RBF(),
    gpx.kernels.Polynomial(),
    gpx.kernels.Polynomial(degree=2),
]

fig, axes = plt.subplots(ncols=3, nrows=2, figsize=(10, 6), tight_layout=True)

# Test inputs: 200 evenly spaced points on [-3, 3], shaped (200, 1) since
# GPJax kernels expect a 2D (n, d) input array.
x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1)
meanf = gpx.mean_functions.Zero()

# For each kernel, build a zero-mean GP prior, evaluate it at x, and draw
# 10 sample paths to show how the kernel choice shapes the functions.
for k, ax in zip(kernels, axes.ravel()):
    prior = gpx.gps.Prior(mean_function=meanf, kernel=k)
    rv = prior(x)
    y = rv.sample(seed=key, sample_shape=(10,))
    ax.plot(x, y.T, alpha=0.7)
    ax.set_title(k.name)
# NOTE(review): byte-for-byte duplicate of the sampling cell above — a docs
# export ("Copied!") artifact. Safe to delete; kept unchanged here.
kernels = [
gpx.kernels.Matern12(),
gpx.kernels.Matern32(),
gpx.kernels.Matern52(),
gpx.kernels.RBF(),
gpx.kernels.Polynomial(),
gpx.kernels.Polynomial(degree=2),
]
fig, axes = plt.subplots(ncols=3, nrows=2, figsize=(10, 6), tight_layout=True)
x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1)
meanf = gpx.mean_functions.Zero()
for k, ax in zip(kernels, axes.ravel()):
prior = gpx.gps.Prior(mean_function=meanf, kernel=k)
rv = prior(x)
y = rv.sample(seed=key, sample_shape=(10,))
ax.plot(x, y.T, alpha=0.7)
ax.set_title(k.name)