
Give me the code

Introduction to Kernels

# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook, Float
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox
import pandas as pd
from docs.examples.utils import clean_legend

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
from gpjax.typing import Array
from sklearn.preprocessing import StandardScaler

key = jr.key(42)
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]

kernels = [
    gpx.kernels.Matern12(),
    gpx.kernels.Matern32(),
    gpx.kernels.Matern52(),
    gpx.kernels.RBF(),
]
fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(7, 6), tight_layout=True)

x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1)

meanf = gpx.mean_functions.Zero()

for k, ax in zip(kernels, axes.ravel()):
    prior = gpx.gps.Prior(mean_function=meanf, kernel=k)
    rv = prior(x)
    y = rv.sample(seed=key, sample_shape=(10,))
    ax.plot(x, y.T, alpha=0.7)
    ax.set_title(k.name)

# Forrester function
def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
    return (6 * x - 2) ** 2 * jnp.sin(12 * x - 4)
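# Quick check (not part of the original example): the Forrester function's
# global minimum is roughly -6.02, attained near x = 0.757, the point marked
# as the "True Optimum" in the Bayesian optimisation examples later on.
print(forrester(jnp.array([0.757])))  # approximately -6.02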


n = 13

training_x = jr.uniform(key=key, minval=0, maxval=1, shape=(n,)).reshape(-1, 1)
training_y = forrester(training_x)
D = gpx.Dataset(X=training_x, y=training_y)

test_x = jnp.linspace(0, 1, 100).reshape(-1, 1)
test_y = forrester(test_x)

mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.Matern52(
    lengthscale=jnp.array(0.1)
)  # Initialise our kernel lengthscale to 0.1

prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)

likelihood = gpx.likelihoods.Gaussian(
    num_datapoints=D.n, obs_stddev=jnp.array(1e-3)
)  # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
likelihood = likelihood.replace_trainable(obs_stddev=False)

no_opt_posterior = prior * likelihood

negative_mll = gpx.objectives.ConjugateMLL(negative=True)
negative_mll(no_opt_posterior, train_data=D)

opt_posterior, history = gpx.fit_scipy(
    model=no_opt_posterior,
    objective=negative_mll,
    train_data=D,
)

def plot_ribbon(ax, x, dist, color):
    mean = dist.mean()
    std = dist.stddev()
    ax.plot(x, mean, label="Predictive mean", color=color)
    ax.fill_between(
        x.squeeze(),
        mean - 2 * std,
        mean + 2 * std,
        alpha=0.2,
        label="Two sigma",
        color=color,
    )
    ax.plot(x, mean - 2 * std, linestyle="--", linewidth=1, color=color)
    ax.plot(x, mean + 2 * std, linestyle="--", linewidth=1, color=color)

opt_latent_dist = opt_posterior.predict(test_x, train_data=D)
opt_predictive_dist = opt_posterior.likelihood(opt_latent_dist)

opt_predictive_mean = opt_predictive_dist.mean()
opt_predictive_std = opt_predictive_dist.stddev()

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(5, 6))
ax1.plot(
    test_x, test_y, label="Latent function", color=cols[0], linestyle="--", linewidth=2
)
ax1.plot(training_x, training_y, "x", label="Observations", color="k", zorder=5)
plot_ribbon(ax1, test_x, opt_predictive_dist, color=cols[1])
ax1.set_title("Posterior with Hyperparameter Optimisation")
ax1.legend(loc="center left", bbox_to_anchor=(0.975, 0.5))

no_opt_latent_dist = no_opt_posterior.predict(test_x, train_data=D)
no_opt_predictive_dist = no_opt_posterior.likelihood(no_opt_latent_dist)

ax2.plot(
    test_x, test_y, label="Latent function", color=cols[0], linestyle="--", linewidth=2
)
ax2.plot(training_x, training_y, "x", label="Observations", color="k", zorder=5)
plot_ribbon(ax2, test_x, no_opt_predictive_dist, color=cols[1])
ax2.set_title("Posterior without Hyperparameter Optimisation")
ax2.legend(loc="center left", bbox_to_anchor=(0.975, 0.5))

no_opt_lengthscale = no_opt_posterior.prior.kernel.lengthscale
no_opt_variance = no_opt_posterior.prior.kernel.variance
opt_lengthscale = opt_posterior.prior.kernel.lengthscale
opt_variance = opt_posterior.prior.kernel.variance

print(f"Optimised Lengthscale: {opt_lengthscale} and Variance: {opt_variance}")
print(
    f"Non-Optimised Lengthscale: {no_opt_lengthscale} and Variance: {no_opt_variance}"
)

mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.Periodic()
prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)

x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1)
rv = prior(x)
y = rv.sample(seed=key, sample_shape=(10,))

fig, ax = plt.subplots()
ax.plot(x, y.T, alpha=0.7)
ax.set_title("Samples from the Periodic Kernel")
plt.show()

mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.Linear()
prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)

x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1)
rv = prior(x)
y = rv.sample(seed=key, sample_shape=(10,))

fig, ax = plt.subplots()
ax.plot(x, y.T, alpha=0.7)
ax.set_title("Samples from the Linear Kernel")
plt.show()

kernel_one = gpx.kernels.Linear()
kernel_two = gpx.kernels.Periodic()
sum_kernel = gpx.kernels.SumKernel(kernels=[kernel_one, kernel_two])
mean = gpx.mean_functions.Zero()
prior = gpx.gps.Prior(mean_function=mean, kernel=sum_kernel)

x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1)
rv = prior(x)
y = rv.sample(seed=key, sample_shape=(10,))
fig, ax = plt.subplots()
ax.plot(x, y.T, alpha=0.7)
ax.set_title("Samples from a GP Prior with Kernel = Linear + Periodic")
plt.show()

kernel_one = gpx.kernels.Linear()
kernel_two = gpx.kernels.Periodic()
prod_kernel = gpx.kernels.ProductKernel(kernels=[kernel_one, kernel_two])
mean = gpx.mean_functions.Zero()
prior = gpx.gps.Prior(mean_function=mean, kernel=prod_kernel)

x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1)
rv = prior(x)
y = rv.sample(seed=key, sample_shape=(10,))
fig, ax = plt.subplots()
ax.plot(x, y.T, alpha=0.7)
ax.set_title("Samples from a GP with Kernel = Linear x Periodic")
plt.show()

co2_data = pd.read_csv(
    "https://gml.noaa.gov/webdata/ccgg/trends/co2/co2_mm_mlo.csv", comment="#"
)
co2_data = co2_data.loc[co2_data["decimal date"] < 2022 + 11 / 12]
train_x = co2_data["decimal date"].values[:, None]
train_y = co2_data["average"].values[:, None]

fig, ax = plt.subplots()
ax.plot(train_x, train_y)
ax.set_title("CO2 Concentration in the Atmosphere")
ax.set_xlabel("Year")
ax.set_ylabel("CO2 Concentration (ppm)")
plt.show()

test_x = jnp.linspace(1950, 2030, 5000, dtype=jnp.float64).reshape(-1, 1)
y_scaler = StandardScaler().fit(train_y)
standardised_train_y = y_scaler.transform(train_y)

D = gpx.Dataset(X=train_x, y=standardised_train_y)

mean = gpx.mean_functions.Zero()
rbf_kernel = gpx.kernels.RBF(lengthscale=100.0)
periodic_kernel = gpx.kernels.Periodic()
linear_kernel = gpx.kernels.Linear(variance=0.001)
sum_kernel = gpx.kernels.SumKernel(kernels=[linear_kernel, periodic_kernel])
final_kernel = gpx.kernels.SumKernel(kernels=[rbf_kernel, sum_kernel])
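# Note (not part of the original example): kernel addition is associative, so
# this nested sum yields the same covariance function as the flat composition
# gpx.kernels.SumKernel(kernels=[rbf_kernel, linear_kernel, periodic_kernel]).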

prior = gpx.gps.Prior(mean_function=mean, kernel=final_kernel)
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)

posterior = prior * likelihood

negative_mll = gpx.objectives.ConjugateMLL(negative=True)
negative_mll(posterior, train_data=D)

opt_posterior, history = gpx.fit(
    model=posterior,
    objective=negative_mll,
    train_data=D,
    optim=ox.adamw(learning_rate=1e-2),
    num_iters=500,
    key=key,
)

latent_dist = opt_posterior.predict(test_x, train_data=D)
predictive_dist = opt_posterior.likelihood(latent_dist)

predictive_mean = predictive_dist.mean().reshape(-1, 1)
predictive_std = predictive_dist.stddev().reshape(-1, 1)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(
    train_x, standardised_train_y, "x", label="Observations", color=cols[0], alpha=0.5
)
ax.fill_between(
    test_x.squeeze(),
    predictive_mean.squeeze() - 2 * predictive_std.squeeze(),
    predictive_mean.squeeze() + 2 * predictive_std.squeeze(),
    alpha=0.2,
    label="Two sigma",
    color=cols[1],
)
ax.plot(
    test_x,
    predictive_mean - 2 * predictive_std,
    linestyle="--",
    linewidth=1,
    color=cols[1],
)
ax.plot(
    test_x,
    predictive_mean + 2 * predictive_std,
    linestyle="--",
    linewidth=1,
    color=cols[1],
)
ax.plot(test_x, predictive_mean, label="Predictive mean", color=cols[1])
ax.set_xlabel("Year")
ax.legend(loc="center left", bbox_to_anchor=(0.975, 0.5))

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(
    train_x[train_x >= 2010],
    standardised_train_y[train_x >= 2010],
    "x",
    label="Observations",
    color=cols[0],
    alpha=0.5,
)
ax.fill_between(
    test_x[test_x >= 2010].squeeze(),
    predictive_mean[test_x >= 2010] - 2 * predictive_std[test_x >= 2010],
    predictive_mean[test_x >= 2010] + 2 * predictive_std[test_x >= 2010],
    alpha=0.2,
    label="Two sigma",
    color=cols[1],
)
ax.plot(
    test_x[test_x >= 2010],
    predictive_mean[test_x >= 2010] - 2 * predictive_std[test_x >= 2010],
    linestyle="--",
    linewidth=1,
    color=cols[1],
)
ax.plot(
    test_x[test_x >= 2010],
    predictive_mean[test_x >= 2010] + 2 * predictive_std[test_x >= 2010],
    linestyle="--",
    linewidth=1,
    color=cols[1],
)
ax.plot(
    test_x[test_x >= 2010],
    predictive_mean[test_x >= 2010],
    label="Predictive mean",
    color=cols[1],
)
ax.set_xlabel("Year")
ax.legend(loc="center left", bbox_to_anchor=(0.975, 0.5))

print(
    "Periodic Kernel Period:"
    f" {[i for i in opt_posterior.prior.kernel.kernels if isinstance(i, gpx.kernels.Periodic)][0].period}"
)

%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Thomas Christie'

Classification

# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)

from time import time
import blackjax
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import jax.tree_util as jtu
from jaxtyping import (
    Array,
    Float,
    install_import_hook,
)
import matplotlib.pyplot as plt
import optax as ox
import tensorflow_probability.substrates.jax as tfp
from tqdm import trange

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx

tfd = tfp.distributions
identity_matrix = jnp.eye
key = jr.key(123)
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = plt.rcParams["axes.prop_cycle"].by_key()["color"]

key, subkey = jr.split(key)
x = jr.uniform(key, shape=(100, 1), minval=-1.0, maxval=1.0)
y = 0.5 * jnp.sign(jnp.cos(3 * x + jr.normal(subkey, shape=x.shape) * 0.05)) + 0.5

D = gpx.Dataset(X=x, y=y)

xtest = jnp.linspace(-1.0, 1.0, 500).reshape(-1, 1)

fig, ax = plt.subplots()
ax.scatter(x, y)

kernel = gpx.kernels.RBF()
meanf = gpx.mean_functions.Constant()
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n)

posterior = prior * likelihood
print(type(posterior))

negative_lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=True))

optimiser = ox.adam(learning_rate=0.01)

opt_posterior, history = gpx.fit(
    model=posterior,
    objective=negative_lpd,
    train_data=D,
    optim=optimiser,
    num_iters=1000,
    key=key,
)

map_latent_dist = opt_posterior.predict(xtest, train_data=D)
predictive_dist = opt_posterior.likelihood(map_latent_dist)

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()

fig, ax = plt.subplots()
ax.scatter(x, y, label="Observations", color=cols[0])
ax.plot(xtest, predictive_mean, label="Predictive mean", color=cols[1])
ax.fill_between(
    xtest.squeeze(),
    predictive_mean - predictive_std,
    predictive_mean + predictive_std,
    alpha=0.2,
    color=cols[1],
    label="One sigma",
)
ax.plot(
    xtest,
    predictive_mean - predictive_std,
    color=cols[1],
    linestyle="--",
    linewidth=1,
)
ax.plot(
    xtest,
    predictive_mean + predictive_std,
    color=cols[1],
    linestyle="--",
    linewidth=1,
)

ax.legend()

import cola
from gpjax.lower_cholesky import lower_cholesky

gram, cross_covariance = (kernel.gram, kernel.cross_covariance)
jitter = 1e-6

# Compute (latent) function value map estimates at training points:
Kxx = opt_posterior.prior.kernel.gram(x)
Kxx += identity_matrix(D.n) * jitter
Kxx = cola.PSD(Kxx)
Lx = lower_cholesky(Kxx)
f_hat = Lx @ opt_posterior.latent

# Negative Hessian, H = -∇² p_tilde(y|f):
H = jax.jacfwd(jax.jacrev(negative_lpd))(opt_posterior, D).latent.latent[:, 0, :, 0]

L = jnp.linalg.cholesky(H + identity_matrix(D.n) * jitter)

# H⁻¹ = H⁻¹ I = (LLᵀ)⁻¹ I = L⁻ᵀ L⁻¹ I
L_inv = jsp.linalg.solve_triangular(L, identity_matrix(D.n), lower=True)
H_inv = jsp.linalg.solve_triangular(L.T, L_inv, lower=False)
LH = jnp.linalg.cholesky(H_inv)
laplace_approximation = tfd.MultivariateNormalTriL(f_hat.squeeze(), LH)
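# Optional sanity check (not part of the original example): on a small SPD
# matrix, the pair of triangular solves above recovers the ordinary inverse,
# i.e. A⁻¹ = L⁻ᵀ L⁻¹ when A = L Lᵀ.
A = jnp.array([[2.0, 0.5], [0.5, 1.0]])
La = jnp.linalg.cholesky(A)
La_inv = jsp.linalg.solve_triangular(La, jnp.eye(2), lower=True)
A_inv = jsp.linalg.solve_triangular(La.T, La_inv, lower=False)
assert jnp.allclose(A_inv, jnp.linalg.inv(A))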

def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNormalTriL:
    map_latent_dist = opt_posterior.predict(test_inputs, train_data=D)

    Kxt = opt_posterior.prior.kernel.cross_covariance(x, test_inputs)
    Kxx = opt_posterior.prior.kernel.gram(x)
    Kxx += identity_matrix(D.n) * jitter
    Kxx = cola.PSD(Kxx)

    # Kxx⁻¹ Kxt
    Kxx_inv_Kxt = cola.solve(Kxx, Kxt)

    # Ktx Kxx⁻¹[ H⁻¹ ] Kxx⁻¹ Kxt
    laplace_cov_term = jnp.matmul(jnp.matmul(Kxx_inv_Kxt.T, H_inv), Kxx_inv_Kxt)

    mean = map_latent_dist.mean()
    covariance = map_latent_dist.covariance() + laplace_cov_term
    L = jnp.linalg.cholesky(covariance)
    return tfd.MultivariateNormalTriL(jnp.atleast_1d(mean.squeeze()), L)

laplace_latent_dist = construct_laplace(xtest)
predictive_dist = opt_posterior.likelihood(laplace_latent_dist)

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()

fig, ax = plt.subplots()
ax.scatter(x, y, label="Observations", color=cols[0])
ax.plot(xtest, predictive_mean, label="Predictive mean", color=cols[1])
ax.fill_between(
    xtest.squeeze(),
    predictive_mean - predictive_std,
    predictive_mean + predictive_std,
    alpha=0.2,
    color=cols[1],
    label="One sigma",
)
ax.plot(
    xtest,
    predictive_mean - predictive_std,
    color=cols[1],
    linestyle="--",
    linewidth=1,
)
ax.plot(
    xtest,
    predictive_mean + predictive_std,
    color=cols[1],
    linestyle="--",
    linewidth=1,
)
ax.legend()

num_adapt = 500
num_samples = 500

lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=False))
unconstrained_lpd = jax.jit(lambda tree: lpd(tree.constrain(), D))

adapt = blackjax.window_adaptation(
    blackjax.nuts, unconstrained_lpd, num_adapt, target_acceptance_rate=0.65
)

# Initialise the chain
start = time()
last_state, kernel, _ = adapt.run(key, posterior.unconstrain())
print(f"Adaption time taken: {time() - start: .1f} seconds")


def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return states, infos


# Sample from the posterior distribution
start = time()
states, infos = inference_loop(key, kernel, last_state, num_samples)
print(f"Sampling time taken: {time() - start: .1f} seconds")

acceptance_rate = jnp.mean(infos.acceptance_probability)
print(f"Acceptance rate: {acceptance_rate:.2f}")

fig, (ax0, ax1, ax2) = plt.subplots(ncols=3, figsize=(10, 3))
ax0.plot(states.position.prior.kernel.lengthscale)
ax1.plot(states.position.prior.kernel.variance)
ax2.plot(states.position.latent[:, 1, :])
ax0.set_title("Kernel Lengthscale")
ax1.set_title("Kernel Variance")
ax2.set_title("Latent Function (index = 1)")

thin_factor = 20
posterior_samples = []

for i in trange(0, num_samples, thin_factor, desc="Drawing posterior samples"):
    sample = jtu.tree_map(lambda samples, i=i: samples[i], states.position)
    sample = sample.constrain()
    latent_dist = sample.predict(xtest, train_data=D)
    predictive_dist = sample.likelihood(latent_dist)
    posterior_samples.append(predictive_dist.sample(seed=key, sample_shape=(10,)))

posterior_samples = jnp.vstack(posterior_samples)
lower_ci, upper_ci = jnp.percentile(posterior_samples, jnp.array([2.5, 97.5]), axis=0)
expected_val = jnp.mean(posterior_samples, axis=0)

fig, ax = plt.subplots()
ax.scatter(x, y, color=cols[0], label="Observations", zorder=2, alpha=0.7)
ax.plot(xtest, expected_val, color=cols[1], label="Predicted mean", zorder=1)
ax.fill_between(
    xtest.flatten(),
    lower_ci.flatten(),
    upper_ci.flatten(),
    alpha=0.2,
    color=cols[1],
    label="95\\% CI",
)
ax.plot(
    xtest,
    lower_ci.flatten(),
    color=cols[1],
    linestyle="--",
    linewidth=1,
)
ax.plot(
    xtest,
    upper_ci.flatten(),
    color=cols[1],
    linestyle="--",
    linewidth=1,
)
ax.legend()

%load_ext watermark
%watermark -n -u -v -iv -w -a "Thomas Pinder & Daniel Dodd"

Gaussian Processes Barycentres

# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)

import typing as tp

import jax
import jax.numpy as jnp
import jax.random as jr
import jax.scipy.linalg as jsl
from jaxtyping import install_import_hook
import matplotlib.pyplot as plt
import optax as ox
import tensorflow_probability.substrates.jax.distributions as tfd

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx


key = jr.key(123)
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = plt.rcParams["axes.prop_cycle"].by_key()["color"]

n = 100
n_test = 200
n_datasets = 5

x = jnp.linspace(-5.0, 5.0, n).reshape(-1, 1)
xtest = jnp.linspace(-5.5, 5.5, n_test).reshape(-1, 1)
f = lambda x, a, b: a + jnp.sin(b * x)

ys = []
for _i in range(n_datasets):
    key, subkey = jr.split(key)
    vertical_shift = jr.uniform(subkey, minval=0.0, maxval=2.0)
    period = jr.uniform(subkey, minval=0.75, maxval=1.25)
    noise_amount = jr.uniform(subkey, minval=0.01, maxval=0.5)
    noise = jr.normal(subkey, shape=x.shape) * noise_amount
    ys.append(f(x, vertical_shift, period) + noise)

y = jnp.hstack(ys)

fig, ax = plt.subplots()
ax.plot(x, y, "x")
plt.show()

def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
    if y.ndim == 1:
        y = y.reshape(-1, 1)
    D = gpx.Dataset(X=x, y=y)

    likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
    posterior = (
        gpx.gps.Prior(
            mean_function=gpx.mean_functions.Constant(), kernel=gpx.kernels.RBF()
        )
        * likelihood
    )

    opt_posterior, _ = gpx.fit_scipy(
        model=posterior,
        objective=gpx.objectives.ConjugateMLL(negative=True),
        train_data=D,
    )
    latent_dist = opt_posterior.predict(xtest, train_data=D)
    return opt_posterior.likelihood(latent_dist)


posterior_preds = [fit_gp(x, i) for i in ys]

def sqrtm(A: jax.Array):
    return jnp.real(jsl.sqrtm(A))


def wasserstein_barycentres(
    distributions: tp.List[tfd.MultivariateNormalFullCovariance], weights: jax.Array
):
    covariances = [d.covariance() for d in distributions]
    cov_stack = jnp.stack(covariances)
    stack_sqrt = jax.vmap(sqrtm)(cov_stack)

    def step(covariance_candidate: jax.Array, idx: None):
        inner_term = jax.vmap(sqrtm)(
            jnp.matmul(jnp.matmul(stack_sqrt, covariance_candidate), stack_sqrt)
        )
        fixed_point = jnp.tensordot(weights, inner_term, axes=1)
        return fixed_point, fixed_point

    return step

weights = jnp.ones((n_datasets,)) / n_datasets

means = jnp.stack([d.mean() for d in posterior_preds])
barycentre_mean = jnp.tensordot(weights, means, axes=1)

step_fn = jax.jit(wasserstein_barycentres(posterior_preds, weights))
initial_covariance = jnp.eye(n_test)

barycentre_covariance, sequence = jax.lax.scan(
    step_fn, initial_covariance, jnp.arange(100)
)
L = jnp.linalg.cholesky(barycentre_covariance)

barycentre_process = tfd.MultivariateNormalTriL(barycentre_mean, L)

def plot(
    dist: tfd.MultivariateNormalTriL,
    ax,
    color: str,
    label: str = None,
    ci_alpha: float = 0.2,
    linewidth: float = 1.0,
    zorder: int = 0,
):
    mu = dist.mean()
    sigma = dist.stddev()
    ax.plot(xtest, mu, linewidth=linewidth, color=color, label=label, zorder=zorder)
    ax.fill_between(
        xtest.squeeze(),
        mu - sigma,
        mu + sigma,
        alpha=ci_alpha,
        color=color,
        zorder=zorder,
    )


fig, ax = plt.subplots()
[plot(d, ax, color=cols[1], ci_alpha=0.1) for d in posterior_preds]
plot(
    barycentre_process,
    ax,
    color=cols[0],
    label="Barycentre",
    ci_alpha=0.5,
    linewidth=2,
    zorder=1,
)
ax.legend()

%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Thomas Pinder'

Deep Kernel Learning

# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)

from dataclasses import (
    dataclass,
    field,
)
from typing import Any

import flax
from flax import linen as nn
import jax
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
    Array,
    Float,
    install_import_hook,
)
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox
from scipy.signal import sawtooth
from gpjax.base import static_field

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
    from gpjax.base import param_field
    import gpjax.kernels as jk
    from gpjax.kernels import DenseKernelComputation
    from gpjax.kernels.base import AbstractKernel
    from gpjax.kernels.computations import AbstractKernelComputation

key = jr.key(123)
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]

n = 500
noise = 0.2

key, subkey = jr.split(key)
x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(n,)).reshape(-1, 1)
f = lambda x: jnp.asarray(sawtooth(2 * jnp.pi * x))
signal = f(x)
y = signal + jr.normal(subkey, shape=signal.shape) * noise

D = gpx.Dataset(X=x, y=y)

xtest = jnp.linspace(-2.0, 2.0, 500).reshape(-1, 1)
ytest = f(xtest)

fig, ax = plt.subplots()
ax.plot(x, y, "o", label="Training data", alpha=0.5)
ax.plot(xtest, ytest, label="True function")
ax.legend(loc="best")

@dataclass
class DeepKernelFunction(AbstractKernel):
    base_kernel: AbstractKernel = None
    network: nn.Module = static_field(None)
    dummy_x: jax.Array = static_field(None)
    key: jax.Array = static_field(jr.key(123))
    nn_params: Any = field(init=False, repr=False)

    def __post_init__(self):
        if self.base_kernel is None:
            raise ValueError("base_kernel must be specified")
        if self.network is None:
            raise ValueError("network must be specified")
        self.nn_params = flax.core.unfreeze(self.network.init(self.key, self.dummy_x))

    def __call__(
        self, x: Float[Array, " D"], y: Float[Array, " D"]
    ) -> Float[Array, "1"]:
        xt = self.network.apply(self.nn_params, x)
        yt = self.network.apply(self.nn_params, y)
        return self.base_kernel(xt, yt)

feature_space_dim = 3


class Network(nn.Module):
    """A simple MLP."""

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=32)(x)
        x = nn.relu(x)
        x = nn.Dense(features=64)(x)
        x = nn.relu(x)
        x = nn.Dense(features=feature_space_dim)(x)
        return x


forward_linear = Network()

base_kernel = gpx.kernels.Matern52(
    active_dims=list(range(feature_space_dim)),
    lengthscale=jnp.ones((feature_space_dim,)),
)
kernel = DeepKernelFunction(
    network=forward_linear, base_kernel=base_kernel, key=key, dummy_x=x
)
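# Note (not part of the original example): the MLP's final layer outputs
# feature_space_dim = 3 values, which is why the Matern52 base kernel above
# is given three active dimensions and a three-dimensional lengthscale.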
meanf = gpx.mean_functions.Zero()
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
posterior = prior * likelihood

schedule = ox.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.01,
    warmup_steps=75,
    decay_steps=700,
    end_value=0.0,
)

optimiser = ox.chain(
    ox.clip(1.0),
    ox.adamw(learning_rate=schedule),
)

opt_posterior, history = gpx.fit(
    model=posterior,
    objective=jax.jit(gpx.objectives.ConjugateMLL(negative=True)),
    train_data=D,
    optim=optimiser,
    num_iters=800,
    key=key,
)

latent_dist = opt_posterior(xtest, train_data=D)
predictive_dist = opt_posterior.likelihood(latent_dist)

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()

fig, ax = plt.subplots()
ax.plot(x, y, "o", label="Observations", color=cols[0])
ax.plot(xtest, predictive_mean, label="Predictive mean", color=cols[1])
ax.fill_between(
    xtest.squeeze(),
    predictive_mean - 2 * predictive_std,
    predictive_mean + 2 * predictive_std,
    alpha=0.2,
    color=cols[1],
    label="Two sigma",
)
ax.plot(
    xtest,
    predictive_mean - 2 * predictive_std,
    color=cols[1],
    linestyle="--",
    linewidth=1,
)
ax.plot(
    xtest,
    predictive_mean + 2 * predictive_std,
    color=cols[1],
    linestyle="--",
    linewidth=1,
)
ax.legend()

%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Thomas Pinder'

Sparse Stochastic Variational Inference

# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)

from jax import jit
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox
import tensorflow_probability.substrates.jax as tfp

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
    import gpjax.kernels as jk

key = jr.key(123)
tfb = tfp.bijectors
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]

n = 50000
noise = 0.2

key, subkey = jr.split(key)
x = jr.uniform(key=key, minval=-5.0, maxval=5.0, shape=(n,)).reshape(-1, 1)
f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)
signal = f(x)
y = signal + jr.normal(subkey, shape=signal.shape) * noise
D = gpx.Dataset(X=x, y=y)

xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1)

z = jnp.linspace(-5.0, 5.0, 50).reshape(-1, 1)
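# Note (not part of the original example): these 50 inducing inputs act as a
# compressed summary of the n = 50,000 observations; in the sparse variational
# scheme below, the dominant cost scales with the number of inducing points
# rather than with n.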

fig, ax = plt.subplots()
ax.vlines(
    z,
    ymin=y.min(),
    ymax=y.max(),
    alpha=0.3,
    linewidth=1,
    label="Inducing point",
    color=cols[2],
)
ax.scatter(x, y, alpha=0.2, color=cols[0], label="Observations")
ax.plot(xtest, f(xtest), color=cols[1], label="Latent function")
ax.legend()
ax.set(xlabel=r"$x$", ylabel=r"$f(x)$")

meanf = gpx.mean_functions.Zero()
likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
prior = gpx.gps.Prior(mean_function=meanf, kernel=jk.RBF())
p = prior * likelihood
q = gpx.variational_families.VariationalGaussian(posterior=p, inducing_inputs=z)

negative_elbo = gpx.objectives.ELBO(negative=True)

print(gpx.cite(negative_elbo))

negative_elbo = jit(negative_elbo)

schedule = ox.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.01,
    warmup_steps=75,
    decay_steps=1500,
    end_value=0.001,
)

opt_posterior, history = gpx.fit(
    model=q,
    objective=negative_elbo,
    train_data=D,
    optim=ox.adam(learning_rate=schedule),
    num_iters=3000,
    key=jr.key(42),
    batch_size=128,
)

latent_dist = opt_posterior(xtest)
predictive_dist = opt_posterior.posterior.likelihood(latent_dist)

meanf = predictive_dist.mean()
sigma = predictive_dist.stddev()

fig, ax = plt.subplots()
ax.scatter(x, y, alpha=0.15, label="Training Data", color=cols[0])
ax.plot(xtest, meanf, label="Posterior mean", color=cols[1])
ax.fill_between(
    xtest.flatten(),
    meanf - 2 * sigma,
    meanf + 2 * sigma,
    alpha=0.3,
    color=cols[1],
    label="Two sigma",
)
ax.vlines(
    opt_posterior.inducing_inputs,
    ymin=y.min(),
    ymax=y.max(),
    alpha=0.3,
    linewidth=1,
    label="Inducing point",
    color=cols[2],
)
ax.legend()

triangular_transform = tfb.FillScaleTriL(
    diag_bijector=tfb.Square(), diag_shift=jnp.array(q.jitter)
)
reparameterised_q = q.replace_bijector(variational_root_covariance=triangular_transform)

opt_rep, history = gpx.fit(
    model=reparameterised_q,
    objective=negative_elbo,
    train_data=D,
    optim=ox.adam(learning_rate=0.01),
    num_iters=3000,
    key=jr.key(42),
    batch_size=128,
)

latent_dist = opt_rep(xtest)
predictive_dist = opt_rep.posterior.likelihood(latent_dist)

meanf = predictive_dist.mean()
sigma = predictive_dist.stddev()

fig, ax = plt.subplots()
ax.scatter(x, y, alpha=0.15, label="Training Data", color=cols[0])
ax.plot(xtest, meanf, label="Posterior mean", color=cols[1])
ax.fill_between(
    xtest.flatten(),
    meanf - 2 * sigma,
    meanf + 2 * sigma,
    alpha=0.3,
    color=cols[1],
    label="Two sigma",
)
ax.vlines(
    opt_rep.inducing_inputs,
    ymin=y.min(),
    ymax=y.max(),
    alpha=0.3,
    linewidth=1,
    label="Inducing point",
    color=cols[2],
)
ax.legend()

%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Thomas Pinder, Daniel Dodd & Zeel B Patel'

New to Gaussian Processes?

import warnings

import jax.numpy as jnp
import jax.random as jr
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import tensorflow_probability.substrates.jax as tfp
from docs.examples.utils import confidence_ellipse

plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
tfd = tfp.distributions

ud1 = tfd.Normal(0.0, 1.0)
ud2 = tfd.Normal(-1.0, 0.5)
ud3 = tfd.Normal(0.25, 1.5)

xs = jnp.linspace(-5.0, 5.0, 500)

fig, ax = plt.subplots()
for d in [ud1, ud2, ud3]:
    ax.plot(
        xs,
        jnp.exp(d.log_prob(xs)),
        label=f"$\\mathcal{{N}}({{{float(d.mean())}}},\\  {{{float(d.stddev())}}}^2)$",
    )
    ax.fill_between(xs, jnp.zeros_like(xs), jnp.exp(d.log_prob(xs)), alpha=0.2)
ax.legend(loc="best")

key = jr.key(123)

d1 = tfd.MultivariateNormalDiag(loc=jnp.zeros(2), scale_diag=jnp.ones(2))
d2 = tfd.MultivariateNormalTriL(
    jnp.zeros(2), jnp.linalg.cholesky(jnp.array([[1.0, 0.9], [0.9, 1.0]]))
)
d3 = tfd.MultivariateNormalTriL(
    jnp.zeros(2), jnp.linalg.cholesky(jnp.array([[1.0, -0.5], [-0.5, 1.0]]))
)

dists = [d1, d2, d3]
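# Illustrative aside (not part of the original example): for the rho = 0.9
# bivariate Gaussian d2 above, conditioning on x gives
# y | x ~ N(rho * x, 1 - rho**2), so observing x shrinks our uncertainty in y.
rho = 0.9
x_obs = 1.0
print(rho * x_obs, 1.0 - rho**2)  # conditional mean 0.9, conditional variance 0.19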

xvals = jnp.linspace(-5.0, 5.0, 500)
yvals = jnp.linspace(-5.0, 5.0, 500)

xx, yy = jnp.meshgrid(xvals, yvals)

pos = jnp.empty(xx.shape + (2,))
pos = pos.at[:, :, 0].set(xx)
pos = pos.at[:, :, 1].set(yy)

fig, (ax0, ax1, ax2) = plt.subplots(figsize=(10, 3), ncols=3, tight_layout=True)
titles = [r"$\rho = 0$", r"$\rho = 0.9$", r"$\rho = -0.5$"]

cmap = mpl.colors.LinearSegmentedColormap.from_list("custom", ["white", cols[1]], N=256)

for a, t, d in zip([ax0, ax1, ax2], titles, dists):
    d_prob = d.prob(jnp.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])).reshape(
        xx.shape
    )
    cntf = a.contourf(xx, yy, d_prob, levels=20, antialiased=True, cmap=cmap)
    for c in cntf.collections:
        c.set_edgecolor("face")
    a.set_xlim(-2.75, 2.75)
    a.set_ylim(-2.75, 2.75)
    samples = d.sample(seed=key, sample_shape=(5000,))
    xsample, ysample = samples[:, 0], samples[:, 1]
    confidence_ellipse(
        xsample, ysample, a, edgecolor="#3f3f3f", n_std=1.0, linestyle="--", alpha=0.8
    )
    confidence_ellipse(
        xsample, ysample, a, edgecolor="#3f3f3f", n_std=2.0, linestyle="--"
    )
    a.plot(0, 0, "x", color=cols[0], markersize=8, mew=2)
    a.set(xlabel="x", ylabel="y", title=t)

n = 1000
x = tfd.Normal(loc=0.0, scale=1.0).sample(seed=key, sample_shape=(n,))
key, subkey = jr.split(key)
y = tfd.Normal(loc=0.25, scale=0.5).sample(seed=subkey, sample_shape=(n,))
key, subkey = jr.split(subkey)
xfull = tfd.Normal(loc=0.0, scale=1.0).sample(seed=subkey, sample_shape=(n * 10,))
key, subkey = jr.split(subkey)
yfull = tfd.Normal(loc=0.25, scale=0.5).sample(seed=subkey, sample_shape=(n * 10,))
key, subkey = jr.split(subkey)
df = pd.DataFrame({"x": x, "y": y, "idx": jnp.ones(n)})

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    g = sns.jointplot(
        data=df,
        x="x",
        y="y",
        hue="idx",
        marker=".",
        space=0.0,
        xlim=(-4.0, 4.0),
        ylim=(-4.0, 4.0),
        height=4,
        marginal_ticks=False,
        legend=False,
        palette="inferno",
        marginal_kws={
            "fill": True,
            "linewidth": 1,
            "color": cols[1],
            "alpha": 0.3,
            "bw_adjust": 2,
            "cmap": cmap,
        },
        joint_kws={"color": cols[1], "size": 3.5, "alpha": 0.4, "cmap": cmap},
    )
    g.ax_joint.annotate(text=r"$p(\mathbf{x}, \mathbf{y})$", xy=(-3, -1.75))
    g.ax_marg_x.annotate(text=r"$p(\mathbf{x})$", xy=(-2.0, 0.225))
    g.ax_marg_y.annotate(text=r"$p(\mathbf{y})$", xy=(0.4, -0.78))
    confidence_ellipse(
        xfull,
        yfull,
        g.ax_joint,
        edgecolor="#3f3f3f",
        n_std=1.0,
        linestyle="--",
        linewidth=0.5,
    )
    confidence_ellipse(
        xfull,
        yfull,
        g.ax_joint,
        edgecolor="#3f3f3f",
        n_std=2.0,
        linestyle="--",
        linewidth=0.5,
    )
    confidence_ellipse(
        xfull,
        yfull,
        g.ax_joint,
        edgecolor="#3f3f3f",
        n_std=3.0,
        linestyle="--",
        linewidth=0.5,
    )

%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Thomas Pinder'

UCI Data Benchmarking

# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)

from jax import jit
import jax.random as jr
from jaxtyping import install_import_hook
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import (
    mean_squared_error,
    r2_score,
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx

key = jr.key(123)
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]

try:
    yacht = pd.read_fwf("data/yacht_hydrodynamics.data", header=None).values[:-1, :]
except FileNotFoundError:
    yacht = pd.read_fwf(
        "docs/examples/data/yacht_hydrodynamics.data", header=None
    ).values[:-1, :]

X = yacht[:, :-1]
y = yacht[:, -1].reshape(-1, 1)

Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.3, random_state=42)

log_ytr = np.log(ytr)
log_yte = np.log(yte)

y_scaler = StandardScaler().fit(log_ytr)
scaled_ytr = y_scaler.transform(log_ytr)
scaled_yte = y_scaler.transform(log_yte)

fig, ax = plt.subplots(ncols=3, figsize=(9, 2.5))
ax[0].hist(ytr, bins=30, color=cols[1])
ax[0].set_title("y")
ax[1].hist(log_ytr, bins=30, color=cols[1])
ax[1].set_title("log(y)")
ax[2].hist(scaled_ytr, bins=30, color=cols[1])
ax[2].set_title("scaled log(y)")

x_scaler = StandardScaler().fit(Xtr)
scaled_Xtr = x_scaler.transform(Xtr)
scaled_Xte = x_scaler.transform(Xte)

n_train, n_covariates = scaled_Xtr.shape
kernel = gpx.kernels.RBF(
    active_dims=list(range(n_covariates)),
    variance=np.var(scaled_ytr),
    lengthscale=0.1 * np.ones((n_covariates,)),
)
meanf = gpx.mean_functions.Zero()
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)

likelihood = gpx.likelihoods.Gaussian(num_datapoints=n_train)

posterior = prior * likelihood

training_data = gpx.Dataset(X=scaled_Xtr, y=scaled_ytr)

negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True))

opt_posterior, history = gpx.fit_scipy(
    model=posterior,
    objective=negative_mll,
    train_data=training_data,
)

latent_dist = opt_posterior(scaled_Xte, training_data)
predictive_dist = likelihood(latent_dist)

predictive_mean = predictive_dist.mean()
predictive_stddev = predictive_dist.stddev()

rmse = np.sqrt(mean_squared_error(y_true=scaled_yte.squeeze(), y_pred=predictive_mean))
r2 = r2_score(y_true=scaled_yte.squeeze(), y_pred=predictive_mean)
print(f"Results:\n\tRMSE: {rmse: .4f}\n\tR2: {r2: .2f}")

residuals = scaled_yte.squeeze() - predictive_mean

fig, ax = plt.subplots(ncols=3, figsize=(9, 2.5), tight_layout=True)

ax[0].scatter(predictive_mean, scaled_yte.squeeze(), color=cols[1])
ax[0].plot([0, 1], [0, 1], color=cols[0], transform=ax[0].transAxes)
ax[0].set(xlabel="Predicted", ylabel="Actual", title="Predicted vs Actual")

ax[1].scatter(predictive_mean.squeeze(), residuals, color=cols[1])
ax[1].plot([0, 1], [0.5, 0.5], color=cols[0], transform=ax[1].transAxes)
ax[1].set_ylim([-1.0, 1.0])
ax[1].set(xlabel="Predicted", ylabel="Residuals", title="Predicted vs Residuals")

ax[2].hist(np.asarray(residuals), bins=30, color=cols[1])
ax[2].set_title("Residuals")

%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Thomas Pinder'

Introduction to Decision Making with GPJax

# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)


import jax.numpy as jnp
import jax.random as jr
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox

import gpjax as gpx
from gpjax.decision_making.utility_functions import (
    ThompsonSampling,
)
from gpjax.decision_making.utility_maximizer import (
    ContinuousSinglePointUtilityMaximizer,
)
from gpjax.decision_making.decision_maker import UtilityDrivenDecisionMaker
from gpjax.decision_making.utils import (
    OBJECTIVE,
    build_function_evaluator,
)
from gpjax.decision_making.posterior_handler import PosteriorHandler
from gpjax.decision_making.search_space import ContinuousSearchSpace
from gpjax.typing import (
    Array,
    Float,
)

key = jr.key(42)
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]

def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
    return (6 * x - 2) ** 2 * jnp.sin(12 * x - 4)

function_evaluator = build_function_evaluator({OBJECTIVE: forrester})

lower_bounds = jnp.array([0.0])
upper_bounds = jnp.array([1.0])
search_space = ContinuousSearchSpace(
    lower_bounds=lower_bounds, upper_bounds=upper_bounds
)

initial_x = search_space.sample(5, key)
initial_datasets = function_evaluator(initial_x)

mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.Matern52()
prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)

likelihood_builder = lambda n: gpx.likelihoods.Gaussian(
    num_datapoints=n, obs_stddev=jnp.array(1e-3)
)

posterior_handler = PosteriorHandler(
    prior,
    likelihood_builder=likelihood_builder,
    optimization_objective=gpx.objectives.ConjugateMLL(negative=True),
    optimizer=ox.adam(learning_rate=0.01),
    num_optimization_iters=1000,
)
posterior_handlers = {OBJECTIVE: posterior_handler}

utility_function_builder = ThompsonSampling(num_features=500)

acquisition_maximizer = ContinuousSinglePointUtilityMaximizer(
    num_initial_samples=100, num_restarts=1
)

def plot_bo_iteration(
    dm: UtilityDrivenDecisionMaker, last_queried_points: Float[Array, "B D"]
):
    posterior = dm.posteriors[OBJECTIVE]
    dataset = dm.datasets[OBJECTIVE]
    plt_x = jnp.linspace(0, 1, 1000).reshape(-1, 1)
    forrester_y = forrester(plt_x.squeeze(axis=-1))
    utility_fn = dm.current_utility_functions[0]
    sample_y = -utility_fn(plt_x)

    latent_dist = posterior.predict(plt_x, train_data=dataset)
    predictive_dist = posterior.likelihood(latent_dist)

    predictive_mean = predictive_dist.mean()
    predictive_std = predictive_dist.stddev()

    fig, ax = plt.subplots()
    ax.plot(plt_x.squeeze(), predictive_mean, label="Predictive Mean", color=cols[1])
    ax.fill_between(
        plt_x.squeeze(),
        predictive_mean - 2 * predictive_std,
        predictive_mean + 2 * predictive_std,
        alpha=0.2,
        label="Two sigma",
        color=cols[1],
    )
    ax.plot(
        plt_x.squeeze(),
        predictive_mean - 2 * predictive_std,
        linestyle="--",
        linewidth=1,
        color=cols[1],
    )
    ax.plot(
        plt_x.squeeze(),
        predictive_mean + 2 * predictive_std,
        linestyle="--",
        linewidth=1,
        color=cols[1],
    )
    ax.plot(plt_x.squeeze(), sample_y, label="Posterior Sample")
    ax.plot(
        plt_x.squeeze(),
        forrester_y,
        label="Forrester Function",
        color=cols[0],
        linestyle="--",
        linewidth=2,
    )
    ax.axvline(x=0.757, linestyle=":", color=cols[3], label="True Optimum")
    ax.scatter(dataset.X, dataset.y, label="Observations", color=cols[2], zorder=2)
    ax.scatter(
        last_queried_points[0],
        -utility_fn(last_queried_points[0][None, ...]),
        label="Posterior Sample Optimum",
        marker="*",
        color=cols[3],
        zorder=3,
    )
    ax.legend(loc="center left", bbox_to_anchor=(0.950, 0.5))
    plt.show()

dm = UtilityDrivenDecisionMaker(
    search_space=search_space,
    posterior_handlers=posterior_handlers,
    datasets=initial_datasets,
    utility_function_builder=utility_function_builder,
    utility_maximizer=acquisition_maximizer,
    batch_size=1,
    key=key,
    post_ask=[plot_bo_iteration],
    post_tell=[],
)

results = dm.run(
    6,
    black_box_function_evaluator=function_evaluator,
)

%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Thomas Christie'

Count data regression

import blackjax
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import matplotlib as mpl
import matplotlib.pyplot as plt
import tensorflow_probability.substrates.jax as tfp
from jax import config
from jaxtyping import install_import_hook

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx

# Enable Float64 for more stable matrix inversions.
config.update("jax_enable_x64", True)
tfd = tfp.distributions
key = jr.key(123)
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]

key, subkey = jr.split(key)
n = 50
x = jr.uniform(key, shape=(n, 1), minval=-2.0, maxval=2.0)
f = lambda x: 2.0 * jnp.sin(3 * x) + 0.5 * x  # latent function
y = jr.poisson(key, jnp.exp(f(x)))

D = gpx.Dataset(X=x, y=y)

xtest = jnp.linspace(-2.0, 2.0, 500).reshape(-1, 1)

fig, ax = plt.subplots()
ax.plot(x, y, "o", label="Observations", color=cols[1])
ax.plot(xtest, jnp.exp(f(xtest)), label=r"Rate $\lambda$")
ax.legend()

kernel = gpx.kernels.RBF()
meanf = gpx.mean_functions.Constant()
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
likelihood = gpx.likelihoods.Poisson(num_datapoints=D.n)

posterior = prior * likelihood
print(type(posterior))

# Adapted from BlackJax's introduction notebook.
num_adapt = 100
num_samples = 200

lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=False))
unconstrained_lpd = jax.jit(lambda tree: lpd(tree.constrain(), D))

adapt = blackjax.window_adaptation(
    blackjax.nuts, unconstrained_lpd, num_adapt, target_acceptance_rate=0.65
)

# Initialise the chain
last_state, kernel, _ = adapt.run(key, posterior.unconstrain())


def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return states, infos


# Sample from the posterior distribution
states, infos = inference_loop(key, kernel, last_state, num_samples)

acceptance_rate = jnp.mean(infos.acceptance_probability)
print(f"Acceptance rate: {acceptance_rate:.2f}")

fig, (ax0, ax1, ax2) = plt.subplots(ncols=3, figsize=(10, 3))
ax0.plot(states.position.constrain().prior.kernel.variance)
ax1.plot(states.position.constrain().prior.kernel.lengthscale)
ax2.plot(states.position.constrain().prior.mean_function.constant)
ax0.set_title("Kernel variance")
ax1.set_title("Kernel lengthscale")
ax2.set_title("Mean function constant")

thin_factor = 10
samples = []

for i in range(0, num_samples, thin_factor):
    sample = jtu.tree_map(lambda samples, i=i: samples[i], states.position)
    sample = sample.constrain()
    latent_dist = sample.predict(xtest, train_data=D)
    predictive_dist = sample.likelihood(latent_dist)
    samples.append(predictive_dist.sample(seed=key, sample_shape=(10,)))

samples = jnp.vstack(samples)

lower_ci, upper_ci = jnp.percentile(samples, jnp.array([2.5, 97.5]), axis=0)
expected_val = jnp.mean(samples, axis=0)

fig, ax = plt.subplots()
ax.plot(
    x, y, "o", markersize=5, color=cols[1], label="Observations", zorder=2, alpha=0.7
)
ax.plot(
    xtest, expected_val, linewidth=2, color=cols[0], label="Predicted mean", zorder=1
)
ax.fill_between(
    xtest.flatten(),
    lower_ci.flatten(),
    upper_ci.flatten(),
    alpha=0.2,
    color=cols[0],
    label="95% CI",
)

%load_ext watermark
%watermark -n -u -v -iv -w -a "Francesco Zanetta"

Introduction to Bayesian Optimisation

# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)

import jax
from jax import jit
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook, Float, Int
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm
import optax as ox
import tensorflow_probability.substrates.jax as tfp
from typing import List, Tuple

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
from gpjax.typing import Array, FunctionalSample, ScalarFloat
from jaxopt import ScipyBoundedMinimize

key = jr.key(42)
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]

def standardised_forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
    mean = 0.45321
    std = 4.4258
    return ((6 * x - 2) ** 2 * jnp.sin(12 * x - 4) - mean) / std
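# Quick check (not part of the original example): standardising the Forrester
# minimum of roughly -6.02 gives (-6.02 - 0.45321) / 4.4258, about -1.46, the
# "True Minimum" drawn on the regret plot below.
print(standardised_forrester(jnp.array([[0.757]])))  # approximately -1.46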

lower_bound = jnp.array([0.0])
upper_bound = jnp.array([1.0])
initial_sample_num = 5

initial_x = tfp.mcmc.sample_halton_sequence(
    dim=1, num_results=initial_sample_num, seed=key, dtype=jnp.float64
).reshape(-1, 1)
initial_y = standardised_forrester(initial_x)
D = gpx.Dataset(X=initial_x, y=initial_y)

def return_optimised_posterior(
    data: gpx.Dataset, prior: gpx.base.Module, key: Array
) -> gpx.base.Module:
    likelihood = gpx.likelihoods.Gaussian(
        num_datapoints=data.n, obs_stddev=jnp.array(1e-6)
    )  # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
    likelihood = likelihood.replace_trainable(obs_stddev=False)

    posterior = prior * likelihood

    negative_mll = gpx.objectives.ConjugateMLL(negative=True)
    negative_mll(posterior, train_data=data)
    negative_mll = jit(negative_mll)

    opt_posterior, _ = gpx.fit(
        model=posterior,
        objective=negative_mll,
        train_data=data,
        optim=ox.adam(learning_rate=0.01),
        num_iters=1000,
        safe=True,
        key=key,
        verbose=False,
    )

    return opt_posterior


mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.Matern52()
prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
opt_posterior = return_optimised_posterior(D, prior, key)

approx_sample = opt_posterior.sample_approx(
    num_samples=1, train_data=D, key=key, num_features=500
)
utility_fn = lambda x: approx_sample(x)[0][0]

def optimise_sample(
    sample: FunctionalSample,
    key: Int[Array, ""],
    lower_bound: Float[Array, "D"],
    upper_bound: Float[Array, "D"],
    num_initial_sample_points: int,
) -> ScalarFloat:
    initial_sample_points = jr.uniform(
        key,
        shape=(num_initial_sample_points, lower_bound.shape[0]),
        dtype=jnp.float64,
        minval=lower_bound,
        maxval=upper_bound,
    )
    initial_sample_y = sample(initial_sample_points)
    best_x = jnp.array([initial_sample_points[jnp.argmin(initial_sample_y)]])

    # The optimiser performs minimisation, and Thompson sampling minimises the
    # drawn sample directly, so the sample plays the role of a negative utility
    # function here.
    negative_utility_fn = lambda x: sample(x)[0][0]
    lbfgsb = ScipyBoundedMinimize(fun=negative_utility_fn, method="l-bfgs-b")
    bounds = (lower_bound, upper_bound)
    x_star = lbfgsb.run(best_x, bounds=bounds).params
    return x_star


x_star = optimise_sample(approx_sample, key, lower_bound, upper_bound, 100)
y_star = standardised_forrester(x_star)

def plot_bayes_opt(
    posterior: gpx.base.Module,
    sample: FunctionalSample,
    dataset: gpx.Dataset,
    queried_x: ScalarFloat,
) -> None:
    plt_x = jnp.linspace(0, 1, 1000).reshape(-1, 1)
    forrester_y = standardised_forrester(plt_x)
    sample_y = sample(plt_x)

    latent_dist = posterior.predict(plt_x, train_data=dataset)
    predictive_dist = posterior.likelihood(latent_dist)

    predictive_mean = predictive_dist.mean()
    predictive_std = predictive_dist.stddev()

    fig, ax = plt.subplots()
    ax.plot(plt_x, predictive_mean, label="Predictive Mean", color=cols[1])
    ax.fill_between(
        plt_x.squeeze(),
        predictive_mean - 2 * predictive_std,
        predictive_mean + 2 * predictive_std,
        alpha=0.2,
        label="Two sigma",
        color=cols[1],
    )
    ax.plot(
        plt_x,
        predictive_mean - 2 * predictive_std,
        linestyle="--",
        linewidth=1,
        color=cols[1],
    )
    ax.plot(
        plt_x,
        predictive_mean + 2 * predictive_std,
        linestyle="--",
        linewidth=1,
        color=cols[1],
    )
    ax.plot(plt_x, sample_y, label="Posterior Sample")
    ax.plot(
        plt_x,
        forrester_y,
        label="Forrester Function",
        color=cols[0],
        linestyle="--",
        linewidth=2,
    )
    ax.axvline(x=0.757, linestyle=":", color=cols[3], label="True Optimum")
    ax.scatter(dataset.X, dataset.y, label="Observations", color=cols[2], zorder=2)
    ax.scatter(
        queried_x,
        sample(queried_x),
        label="Posterior Sample Optimum",
        marker="*",
        color=cols[3],
        zorder=3,
    )
    ax.legend(loc="center left", bbox_to_anchor=(0.975, 0.5))
    plt.show()


plot_bayes_opt(opt_posterior, approx_sample, D, x_star)

bo_iters = 5

# Set up initial dataset
initial_x = tfp.mcmc.sample_halton_sequence(
    dim=1, num_results=initial_sample_num, seed=key, dtype=jnp.float64
).reshape(-1, 1)
initial_y = standardised_forrester(initial_x)
D = gpx.Dataset(X=initial_x, y=initial_y)

for i in range(bo_iters):
    key, subkey = jr.split(key)

    # Generate optimised posterior using previously observed data
    mean = gpx.mean_functions.Zero()
    kernel = gpx.kernels.Matern52()
    prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
    opt_posterior = return_optimised_posterior(D, prior, subkey)

    # Draw a sample from the posterior, and find the minimiser of it
    approx_sample = opt_posterior.sample_approx(
        num_samples=1, train_data=D, key=subkey, num_features=500
    )
    x_star = optimise_sample(
        approx_sample, subkey, lower_bound, upper_bound, num_initial_sample_points=100
    )

    plot_bayes_opt(opt_posterior, approx_sample, D, x_star)

    # Evaluate the black-box function at the best point observed so far, and add it to the dataset
    y_star = standardised_forrester(x_star)
    print(f"Queried Point: {x_star}, Black-Box Function Value: {y_star}")
    D = D + gpx.Dataset(X=x_star, y=y_star)

fig, ax = plt.subplots()
fn_evaluations = jnp.arange(1, bo_iters + initial_sample_num + 1)
cumulative_best_y = jax.lax.associative_scan(jax.numpy.minimum, D.y)
ax.plot(fn_evaluations, cumulative_best_y)
ax.axvline(x=initial_sample_num, linestyle=":")
ax.axhline(y=-1.463, linestyle="--", label="True Minimum")
ax.set_xlabel("Number of Black-Box Function Evaluations")
ax.set_ylabel("Best Observed Value")
ax.legend()
plt.show()

def standardised_six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
    mean = 1.12767
    std = 1.17500
    x1 = x[..., :1]
    x2 = x[..., 1:]
    term1 = (4 - 2.1 * x1**2 + x1**4 / 3) * x1**2
    term2 = x1 * x2
    term3 = (-4 + 4 * x2**2) * x2**2
    return (term1 + term2 + term3 - mean) / std
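# Quick check (not part of the original example): the six-hump camel function
# has two global minima at (0.0898, -0.7126) and (-0.0898, 0.7126); after the
# standardisation above their value is roughly -1.84, the figure used as the
# global minimum when computing regret below.
print(standardised_six_hump_camel(jnp.array([[0.0898, -0.7126]])))  # ~ -1.84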

x1 = jnp.linspace(-2, 2, 100)
x2 = jnp.linspace(-1, 1, 100)
x1, x2 = jnp.meshgrid(x1, x2)
x = jnp.stack([x1.flatten(), x2.flatten()], axis=1)
y = standardised_six_hump_camel(x)

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
surf = ax.plot_surface(
    x1,
    x2,
    y.reshape(x1.shape[0], x2.shape[0]),
    linewidth=0,
    cmap=cm.coolwarm,
    antialiased=False,
)
ax.set_xlabel("x1")
ax.set_ylabel("x2")
plt.show()

x_star_one = jnp.array([[0.0898, -0.7126]])
x_star_two = jnp.array([[-0.0898, 0.7126]])
fig, ax = plt.subplots()
contour_plot = ax.contourf(
    x1, x2, y.reshape(x1.shape[0], x2.shape[0]), cmap=cm.coolwarm, levels=40
)
ax.scatter(
    x_star_one[0][0], x_star_one[0][1], marker="*", color=cols[2], label="Global Minima"
)
ax.scatter(x_star_two[0][0], x_star_two[0][1], marker="*", color=cols[2])
ax.set_xlabel("x1")
ax.set_ylabel("x2")
fig.colorbar(contour_plot)
ax.legend()
plt.show()

lower_bound = jnp.array([-2.0, -1.0])
upper_bound = jnp.array([2.0, 1.0])
initial_sample_num = 5
bo_iters = 12
num_experiments = 5
bo_experiment_results = []

for experiment in range(num_experiments):
    print(f"Starting Experiment: {experiment + 1}")
    # Set up initial dataset
    initial_x = tfp.mcmc.sample_halton_sequence(
        dim=2, num_results=initial_sample_num, seed=key, dtype=jnp.float64
    )
    initial_x = jnp.array(lower_bound + (upper_bound - lower_bound) * initial_x)
    initial_y = standardised_six_hump_camel(initial_x)
    D = gpx.Dataset(X=initial_x, y=initial_y)

    for i in range(bo_iters):
        key, subkey = jr.split(key)

        # Generate optimised posterior
        mean = gpx.mean_functions.Zero()
        kernel = gpx.kernels.Matern52(
            active_dims=[0, 1], lengthscale=jnp.array([1.0, 1.0]), variance=2.0
        )
        prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
        opt_posterior = return_optimised_posterior(D, prior, subkey)

        # Draw a sample from the posterior, and find the minimiser of it
        approx_sample = opt_posterior.sample_approx(
            num_samples=1, train_data=D, key=subkey, num_features=500
        )
        x_star = optimise_sample(
            approx_sample,
            subkey,
            lower_bound,
            upper_bound,
            num_initial_sample_points=1000,
        )

        # Evaluate the black-box function at the best point observed so far, and add it to the dataset
        y_star = standardised_six_hump_camel(x_star)
        print(
            f"BO Iteration: {i + 1}, Queried Point: {x_star}, Black-Box Function Value:"
            f" {y_star}"
        )
        D = D + gpx.Dataset(X=x_star, y=y_star)
    bo_experiment_results.append(D)

random_experiment_results = []
for i in range(num_experiments):
    key, subkey = jr.split(key)
    initial_x = bo_experiment_results[i].X[:5]
    initial_y = bo_experiment_results[i].y[:5]
    final_x = jr.uniform(
        key,
        shape=(bo_iters, 2),
        dtype=jnp.float64,
        minval=lower_bound,
        maxval=upper_bound,
    )
    final_y = standardised_six_hump_camel(final_x)
    random_x = jnp.concatenate([initial_x, final_x], axis=0)
    random_y = jnp.concatenate([initial_y, final_y], axis=0)
    random_experiment_results.append(gpx.Dataset(X=random_x, y=random_y))

def obtain_log_regret_statistics(
    experiment_results: List[gpx.Dataset],
    global_minimum: ScalarFloat,
) -> Tuple[Float[Array, "N 1"], Float[Array, "N 1"]]:
    log_regret_results = []
    for exp_result in experiment_results:
        observations = exp_result.y
        cumulative_best_observations = jax.lax.associative_scan(
            jax.numpy.minimum, observations
        )
        regret = cumulative_best_observations - global_minimum
        log_regret = jnp.log(regret)
        log_regret_results.append(log_regret)

    log_regret_results = jnp.array(log_regret_results)
    log_regret_mean = jnp.mean(log_regret_results, axis=0)
    log_regret_std = jnp.std(log_regret_results, axis=0)
    return log_regret_mean, log_regret_std


bo_log_regret_mean, bo_log_regret_std = obtain_log_regret_statistics(
    bo_experiment_results, -1.8377
)
(
    random_log_regret_mean,
    random_log_regret_std,
) = obtain_log_regret_statistics(random_experiment_results, -1.8377)

fig, ax = plt.subplots()
fn_evaluations = jnp.arange(1, bo_iters + initial_sample_num + 1)
ax.plot(fn_evaluations, bo_log_regret_mean, label="Bayesian Optimisation")
ax.fill_between(
    fn_evaluations,
    bo_log_regret_mean[:, 0] - bo_log_regret_std[:, 0],
    bo_log_regret_mean[:, 0] + bo_log_regret_std[:, 0],
    alpha=0.2,
)
ax.plot(fn_evaluations, random_log_regret_mean, label="Random Search")
ax.fill_between(
    fn_evaluations,
    random_log_regret_mean[:, 0] - random_log_regret_std[:, 0],
    random_log_regret_mean[:, 0] + random_log_regret_std[:, 0],
    alpha=0.2,
)
ax.axvline(x=initial_sample_num, linestyle=":")
ax.set_xlabel("Number of Black-Box Function Evaluations")
ax.set_ylabel("Log Regret")
ax.legend()
plt.show()

fig, ax = plt.subplots()
contour_plot = ax.contourf(
    x1, x2, y.reshape(x1.shape[0], x2.shape[0]), cmap=cm.coolwarm, levels=40
)
ax.scatter(
    x_star_one[0][0],
    x_star_one[0][1],
    marker="*",
    color=cols[2],
    label="Global Minimum",
    zorder=2,
)
ax.scatter(x_star_two[0][0], x_star_two[0][1], marker="*", color=cols[2], zorder=2)
ax.scatter(
    bo_experiment_results[1].X[:, 0],
    bo_experiment_results[1].X[:, 1],
    marker="x",
    color=cols[1],
    label="Bayesian Optimisation Queries",
)
ax.set_xlabel("x1")
ax.set_ylabel("x2")
fig.colorbar(contour_plot)
ax.legend()
plt.show()

%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Thomas Christie'

Kernel Guide

# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)

from dataclasses import dataclass
from typing import Dict

from jax import jit
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
    Array,
    Float,
    install_import_hook,
)
import matplotlib.pyplot as plt
import numpy as np
from simple_pytree import static_field
import tensorflow_probability.substrates.jax as tfp

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
    from gpjax.base.param import param_field

key = jr.key(123)
tfb = tfp.bijectors
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = plt.rcParams["axes.prop_cycle"].by_key()["color"]

kernels = [
    gpx.kernels.Matern12(),
    gpx.kernels.Matern32(),
    gpx.kernels.Matern52(),
    gpx.kernels.RBF(),
    gpx.kernels.Polynomial(),
    gpx.kernels.Polynomial(degree=2),
]
fig, axes = plt.subplots(ncols=3, nrows=2, figsize=(10, 6), tight_layout=True)

x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1)

meanf = gpx.mean_functions.Zero()

for k, ax in zip(kernels, axes.ravel()):
    prior = gpx.gps.Prior(mean_function=meanf, kernel=k)
    rv = prior(x)
    y = rv.sample(seed=key, sample_shape=(10,))
    ax.plot(x, y.T, alpha=0.7)
    ax.set_title(k.name)

slice_kernel = gpx.kernels.RBF(active_dims=[0, 1, 3], lengthscale=jnp.ones((3,)))

print(f"Lengthscales: {slice_kernel.lengthscale}")

# Inputs
x_matrix = jr.normal(key, shape=(50, 5))

# Compute the Gram matrix
K = slice_kernel.gram(x_matrix)
print(K.shape)

k1 = gpx.kernels.RBF()
k2 = gpx.kernels.Polynomial()
sum_k = gpx.kernels.SumKernel(kernels=[k1, k2])

fig, ax = plt.subplots(ncols=3, figsize=(9, 3))
im0 = ax[0].matshow(k1.gram(x).to_dense())
im1 = ax[1].matshow(k2.gram(x).to_dense())
im2 = ax[2].matshow(sum_k.gram(x).to_dense())

fig.colorbar(im0, ax=ax[0], fraction=0.05)
fig.colorbar(im1, ax=ax[1], fraction=0.05)
fig.colorbar(im2, ax=ax[2], fraction=0.05)

k3 = gpx.kernels.Matern32()

prod_k = gpx.kernels.ProductKernel(kernels=[k1, k2, k3])

fig, ax = plt.subplots(ncols=4, figsize=(12, 3))
im0 = ax[0].matshow(k1.gram(x).to_dense())
im1 = ax[1].matshow(k2.gram(x).to_dense())
im2 = ax[2].matshow(k3.gram(x).to_dense())
im3 = ax[3].matshow(prod_k.gram(x).to_dense())

fig.colorbar(im0, ax=ax[0], fraction=0.05)
fig.colorbar(im1, ax=ax[1], fraction=0.05)
fig.colorbar(im2, ax=ax[2], fraction=0.05)
fig.colorbar(im3, ax=ax[3], fraction=0.05)

def angular_distance(x, y, c):
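    # Shortest distance between angles x and y on a circle of circumference 2 * c (value in [0, c])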
    return jnp.abs((x - y + c) % (c * 2) - c)


bij = tfb.SoftClip(low=jnp.array(4.0, dtype=jnp.float64))


@dataclass
class Polar(gpx.kernels.AbstractKernel):
    period: float = static_field(2 * jnp.pi)
    tau: float = param_field(jnp.array([5.0]), bijector=bij)

    def __call__(
        self, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
    ) -> Float[Array, "1"]:
        c = self.period / 2.0
        t = angular_distance(x, y, c)
        K = (1 + self.tau * t / c) * jnp.clip(1 - t / c, 0, jnp.inf) ** self.tau
        return K.squeeze()
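
# A quick, illustrative check of the Polar kernel (not part of the original listing):
# angles equidistant from 0 on either side of the 0 / 2*pi boundary should receive the
# same covariance, since distances are measured on the circle.
illustrative_kernel = Polar()
print(illustrative_kernel(jnp.array([0.0]), jnp.array([0.1])))
print(illustrative_kernel(jnp.array([0.0]), jnp.array([2 * jnp.pi - 0.1])))  # same value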

# Simulate data
angles = jnp.linspace(0, 2 * jnp.pi, num=200).reshape(-1, 1)
n = 20
noise = 0.2

X = jnp.sort(jr.uniform(key, minval=0.0, maxval=jnp.pi * 2, shape=(n, 1)), axis=0)
y = 4 + jnp.cos(2 * X) + jr.normal(key, shape=X.shape) * noise

D = gpx.Dataset(X=X, y=y)

# Define polar Gaussian process
PKern = Polar()
meanf = gpx.mean_functions.Zero()
likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
circular_posterior = gpx.gps.Prior(mean_function=meanf, kernel=PKern) * likelihood

# Optimise GP's marginal log-likelihood using BFGS
opt_posterior, history = gpx.fit_scipy(
    model=circular_posterior,
    objective=jit(gpx.objectives.ConjugateMLL(negative=True)),
    train_data=D,
)

posterior_rv = opt_posterior.likelihood(opt_posterior.predict(angles, train_data=D))
mu = posterior_rv.mean()
one_sigma = posterior_rv.stddev()

fig = plt.figure(figsize=(7, 3.5))
gridspec = fig.add_gridspec(1, 1)
ax = plt.subplot(gridspec[0], polar=True)

ax.fill_between(
    angles.squeeze(),
    mu - one_sigma,
    mu + one_sigma,
    alpha=0.3,
    label=r"1 Posterior s.d.",
    color=cols[1],
    lw=0,
)
ax.fill_between(
    angles.squeeze(),
    mu - 3 * one_sigma,
    mu + 3 * one_sigma,
    alpha=0.15,
    label=r"3 Posterior s.d.",
    color=cols[1],
    lw=0,
)
ax.plot(angles, mu, label="Posterior mean")
ax.scatter(D.X, D.y, alpha=1, label="Observations")
ax.legend()

%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Thomas Pinder'

Sparse Gaussian Process Regression

# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)

from jax import jit
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox
from docs.examples.utils import clean_legend

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx

key = jr.key(123)
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]

n = 2500
noise = 0.5

key, subkey = jr.split(key)
x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,)).reshape(-1, 1)
f = lambda x: jnp.sin(2 * x) + x * jnp.cos(5 * x)
signal = f(x)
y = signal + jr.normal(subkey, shape=signal.shape) * noise

D = gpx.Dataset(X=x, y=y)

xtest = jnp.linspace(-3.1, 3.1, 500).reshape(-1, 1)
ytest = f(xtest)

n_inducing = 50
z = jnp.linspace(-3.0, 3.0, n_inducing).reshape(-1, 1)

fig, ax = plt.subplots()
ax.scatter(x, y, alpha=0.25, label="Observations", color=cols[0])
ax.plot(xtest, ytest, label="Latent function", linewidth=2, color=cols[1])
ax.vlines(
    x=z,
    ymin=y.min(),
    ymax=y.max(),
    alpha=0.3,
    linewidth=0.5,
    label="Inducing point",
    color=cols[2],
)
ax.legend(loc="best")
plt.show()

meanf = gpx.mean_functions.Constant()
kernel = gpx.kernels.RBF()
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
posterior = prior * likelihood

q = gpx.variational_families.CollapsedVariationalGaussian(
    posterior=posterior, inducing_inputs=z
)
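# The collapsed variational family analytically marginalises the inducing outputs, so only
# the inducing inputs (and the GP hyperparameters) remain to be optimised.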

elbo = gpx.objectives.CollapsedELBO(negative=True)

print(gpx.cite(elbo))


elbo = jit(elbo)

opt_posterior, history = gpx.fit(
    model=q,
    objective=elbo,
    train_data=D,
    optim=ox.adamw(learning_rate=1e-2),
    num_iters=500,
    key=key,
)

fig, ax = plt.subplots()
ax.plot(history, color=cols[1])
ax.set(xlabel="Training iterate", ylabel="ELBO")

latent_dist = opt_posterior(xtest, train_data=D)
predictive_dist = opt_posterior.posterior.likelihood(latent_dist)

inducing_points = opt_posterior.inducing_inputs

samples = latent_dist.sample(seed=key, sample_shape=(20,))

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()

fig, ax = plt.subplots()

ax.plot(x, y, "x", label="Observations", color=cols[0], alpha=0.1)
ax.plot(
    xtest,
    ytest,
    label="Latent function",
    color=cols[1],
    linestyle="-",
    linewidth=1,
)
ax.plot(xtest, predictive_mean, label="Predictive mean", color=cols[1])

ax.fill_between(
    xtest.squeeze(),
    predictive_mean - 2 * predictive_std,
    predictive_mean + 2 * predictive_std,
    alpha=0.2,
    color=cols[1],
    label="Two sigma",
)
ax.plot(
    xtest,
    predictive_mean - 2 * predictive_std,
    color=cols[1],
    linestyle="--",
    linewidth=0.5,
)
ax.plot(
    xtest,
    predictive_mean + 2 * predictive_std,
    color=cols[1],
    linestyle="--",
    linewidth=0.5,
)


ax.vlines(
    x=inducing_points,
    ymin=ytest.min(),
    ymax=ytest.max(),
    alpha=0.3,
    linewidth=0.5,
    label="Inducing point",
    color=cols[2],
)
ax.legend()
ax.set(xlabel=r"$x$", ylabel=r"$f(x)$")
plt.show()

full_rank_model = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
) * gpx.likelihoods.Gaussian(num_datapoints=D.n)
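# Rough cost comparison of a single objective evaluation: the exact marginal log-likelihood
# scales cubically in the n = 2500 datapoints, whereas the collapsed ELBO scales linearly
# in n and quadratically in the 50 inducing points.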
negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True).step)
%timeit negative_mll(full_rank_model, D).block_until_ready()

negative_elbo = jit(gpx.objectives.CollapsedELBO(negative=True).step)
%timeit negative_elbo(q, D).block_until_ready()

%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Daniel Dodd'

Likelihood Guide

# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)

import gpjax as gpx
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = plt.rcParams["axes.prop_cycle"].by_key()["color"]
key = jr.key(123)


n = 50
x = jnp.sort(jr.uniform(key=key, shape=(n, 1), minval=-3.0, maxval=3.0), axis=0)
xtest = jnp.linspace(-3, 3, 100)[:, None]
f = lambda x: jnp.sin(x)
y = f(x) + 0.1 * jr.normal(key, shape=x.shape)
D = gpx.Dataset(x, y)

fig, ax = plt.subplots()
ax.plot(x, y, "o", label="Observations")
ax.plot(x, f(x), label="Latent function")
ax.legend()

gpx.likelihoods.Gaussian(num_datapoints=D.n)

gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=0.5)

kernel = gpx.kernels.Matern32()
meanf = gpx.mean_functions.Zero()
prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)

likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=0.1)

posterior = prior * likelihood

latent_dist = posterior.predict(xtest, D)

fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(9, 2))
key, subkey = jr.split(key)

for ax in axes.ravel():
    subkey, _ = jr.split(subkey)
    ax.plot(
        latent_dist.sample(sample_shape=(1,), seed=subkey).T,
        lw=1,
        color=cols[0],
        label="Latent samples",
    )
    ax.plot(
        likelihood.predict(latent_dist).sample(sample_shape=(1,), seed=subkey).T,
        "o",
        markersize=5,
        alpha=0.3,
        color=cols[1],
        label="Predictive samples",
    )

likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n)


fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(9, 2))
key, subkey = jr.split(key)

for ax in axes.ravel():
    subkey, _ = jr.split(subkey)
    ax.plot(
        latent_dist.sample(sample_shape=(1,), seed=subkey).T,
        lw=1,
        color=cols[0],
        label="Latent samples",
    )
    ax.plot(
        likelihood.predict(latent_dist).sample(sample_shape=(1,), seed=subkey).T,
        "o",
        markersize=3,
        alpha=0.5,
        color=cols[1],
        label="Predictive samples",
    )

z = jnp.linspace(-3.0, 3.0, 10).reshape(-1, 1)
q = gpx.variational_families.VariationalGaussian(posterior=posterior, inducing_inputs=z)


def q_moments(x):
    qx = q(x)
    return qx.mean(), qx.variance()


mean, variance = jax.vmap(q_moments)(x[:, None])

jnp.sum(likelihood.expected_log_likelihood(y=y, mean=mean, variance=variance))

lquad = gpx.likelihoods.Gaussian(
    num_datapoints=D.n,
    obs_stddev=jnp.array([0.1]),
    integrator=gpx.integrators.GHQuadratureIntegrator(num_points=20),
)
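# The integrator above swaps the closed-form Gaussian expectation for 20-point
# Gauss-Hermite quadrature; its expected log-likelihood should closely match the
# analytical value computed above.
print(jnp.sum(lquad.expected_log_likelihood(y=y, mean=mean, variance=variance)))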

%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Thomas Pinder'

Regression

# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)

from jax import jit
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox
from docs.examples.utils import clean_legend

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx

key = jr.key(123)
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]

n = 100
noise = 0.3

key, subkey = jr.split(key)
x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,)).reshape(-1, 1)
f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)
signal = f(x)
y = signal + jr.normal(subkey, shape=signal.shape) * noise

D = gpx.Dataset(X=x, y=y)

xtest = jnp.linspace(-3.5, 3.5, 500).reshape(-1, 1)
ytest = f(xtest)

fig, ax = plt.subplots()
ax.plot(x, y, "o", label="Observations", color=cols[0])
ax.plot(xtest, ytest, label="Latent function", color=cols[1])
ax.legend(loc="best")

kernel = gpx.kernels.RBF()
meanf = gpx.mean_functions.Zero()
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)

prior_dist = prior.predict(xtest)

prior_mean = prior_dist.mean()
prior_std = prior_dist.stddev()
samples = prior_dist.sample(seed=key, sample_shape=(20,))


fig, ax = plt.subplots()
ax.plot(xtest, samples.T, alpha=0.5, color=cols[0], label="Prior samples")
ax.plot(xtest, prior_mean, color=cols[1], label="Prior mean")
ax.fill_between(
    xtest.flatten(),
    prior_mean - prior_std,
    prior_mean + prior_std,
    alpha=0.3,
    color=cols[1],
    label="Prior variance",
)
ax.legend(loc="best")
ax = clean_legend(ax)

likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)

posterior = prior * likelihood

negative_mll = gpx.objectives.ConjugateMLL(negative=True)
negative_mll(posterior, train_data=D)


# static_tree = jax.tree_map(lambda x: not(x), posterior.trainables)
# optim = ox.chain(
#     ox.adam(learning_rate=0.01),
#     ox.masked(ox.set_to_zero(), static_tree)
#     )

print(gpx.cite(negative_mll))

negative_mll = jit(negative_mll)

opt_posterior, history = gpx.fit_scipy(
    model=posterior,
    objective=negative_mll,
    train_data=D,
)

latent_dist = opt_posterior.predict(xtest, train_data=D)
predictive_dist = opt_posterior.likelihood(latent_dist)

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()

fig, ax = plt.subplots(figsize=(7.5, 2.5))
ax.plot(x, y, "x", label="Observations", color=cols[0], alpha=0.5)
ax.fill_between(
    xtest.squeeze(),
    predictive_mean - 2 * predictive_std,
    predictive_mean + 2 * predictive_std,
    alpha=0.2,
    label="Two sigma",
    color=cols[1],
)
ax.plot(
    xtest,
    predictive_mean - 2 * predictive_std,
    linestyle="--",
    linewidth=1,
    color=cols[1],
)
ax.plot(
    xtest,
    predictive_mean + 2 * predictive_std,
    linestyle="--",
    linewidth=1,
    color=cols[1],
)
ax.plot(
    xtest, ytest, label="Latent function", color=cols[0], linestyle="--", linewidth=2
)
ax.plot(xtest, predictive_mean, label="Predictive mean", color=cols[1])
ax.legend(loc="center left", bbox_to_anchor=(0.975, 0.5))

%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Thomas Pinder & Daniel Dodd'

Graph Kernels

# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)

import random

from jax import jit
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook
import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
import optax as ox

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx

key = jr.key(123)
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]

vertex_per_side = 20
n_edges_to_remove = 30
p = 0.8

G = nx.barbell_graph(vertex_per_side, 0)

random.seed(123)
for edge in random.sample(list(G.edges), n_edges_to_remove):
    G.remove_edge(*edge)

pos = nx.spring_layout(G, seed=123)  # positions for all nodes

nx.draw(
    G, pos, node_size=100, node_color=cols[1], edge_color="black", with_labels=False
)

L = nx.laplacian_matrix(G).toarray()

x = jnp.arange(G.number_of_nodes()).reshape(-1, 1)
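
# A graph Matérn kernel is defined below through the graph Laplacian; its smoothness,
# lengthscale and variance hyperparameters play roles analogous to their Euclidean
# counterparts (see the gpx.cite(kernel) call further below for the reference).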

true_kernel = gpx.kernels.GraphKernel(
    laplacian=L,
    lengthscale=2.3,
    variance=3.2,
    smoothness=6.1,
)
prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=true_kernel)

fx = prior(x)
y = fx.sample(seed=key, sample_shape=(1,)).reshape(-1, 1)

D = gpx.Dataset(X=x, y=y)

nx.draw(G, pos, node_color=y, with_labels=False, alpha=0.5)

vmin, vmax = y.min(), y.max()
sm = plt.cm.ScalarMappable(
    cmap=plt.cm.inferno, norm=plt.Normalize(vmin=vmin, vmax=vmax)
)
sm.set_array([])
ax = plt.gca()
cbar = plt.colorbar(sm, ax=ax)

likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
kernel = gpx.kernels.GraphKernel(laplacian=L)
prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=kernel)
posterior = prior * likelihood

print(gpx.cite(kernel))

opt_posterior, training_history = gpx.fit_scipy(
    model=posterior,
    objective=gpx.objectives.ConjugateMLL(negative=True),
    train_data=D,
)

initial_dist = likelihood(posterior(x, D))
predictive_dist = opt_posterior.likelihood(opt_posterior(x, D))

initial_mean = initial_dist.mean()
learned_mean = predictive_dist.mean()

rmse = lambda ytrue, ypred: jnp.sqrt(jnp.mean(jnp.square(ytrue - ypred)))

initial_rmse = rmse(y.squeeze(), initial_mean)
learned_rmse = rmse(y.squeeze(), learned_mean)
print(
    f"RMSE with initial parameters: {initial_rmse: .2f}\nRMSE with learned parameters:"
    f" {learned_rmse: .2f}"
)

error = jnp.abs(learned_mean - y.squeeze())

nx.draw(G, pos, node_color=error, with_labels=False, alpha=0.5)

vmin, vmax = error.min(), error.max()
sm = plt.cm.ScalarMappable(
    cmap=plt.cm.inferno, norm=plt.Normalize(vmin=vmin, vmax=vmax)
)
ax = plt.gca()
cbar = plt.colorbar(sm, ax=ax)

%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Thomas Pinder'

Gaussian Processes for Vector Fields and Ocean Current Modelling

# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)
from dataclasses import dataclass, field

from jax import hessian
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
    Array,
    Float,
    install_import_hook,
)
from matplotlib import rcParams
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow_probability as tfp

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx

key = jr.key(123)
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
colors = rcParams["axes.prop_cycle"].by_key()["color"]

# function to place data from csv into correct array shape
def prepare_data(df):
    pos = jnp.array([df["lon"], df["lat"]])
    vel = jnp.array([df["ubar"], df["vbar"]])
    # extract shape stored as 'metadata' in the test data
    try:
        shape = (int(df["shape"][1]), int(df["shape"][0]))  # shape = (34,16)
        return pos, vel, shape
    except KeyError:
        return pos, vel


# loading in data

gulf_data_train = pd.read_csv(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/static/main/data/gulfdata_train.csv"
)
gulf_data_test = pd.read_csv(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/static/main/data/gulfdata_test.csv"
)


pos_test, vel_test, shape = prepare_data(gulf_data_test)
pos_train, vel_train = prepare_data(gulf_data_train)

fig, ax = plt.subplots(1, 1, figsize=(6, 3))
ax.quiver(
    pos_test[0],
    pos_test[1],
    vel_test[0],
    vel_test[1],
    color=colors[0],
    label="Ocean Current",
    angles="xy",
    scale=10,
)
ax.quiver(
    pos_train[0],
    pos_train[1],
    vel_train[0],
    vel_train[1],
    color=colors[1],
    alpha=0.7,
    label="Drifter",
    angles="xy",
    scale=10,
)

ax.set(
    xlabel="Longitude",
    ylabel="Latitude",
)
ax.legend(
    framealpha=0.0,
    ncols=2,
    fontsize="medium",
    bbox_to_anchor=(0.5, -0.3),
    loc="lower center",
)
plt.show()

# Change vectors x -> X = (x,z), and vectors y -> Y = (y,z) via the artificial z label
def label_position(data):
    # introduce alternating z label
    n_points = len(data[0])
    label = jnp.tile(jnp.array([0.0, 1.0]), n_points)
    return jnp.vstack((jnp.repeat(data, repeats=2, axis=1), label)).T


# change vectors y -> Y by reshaping the velocity measurements
def stack_velocity(data):
    return data.T.flatten().reshape(-1, 1)


def dataset_3d(pos, vel):
    return gpx.Dataset(label_position(pos), stack_velocity(vel))
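
# Illustrative example (not part of the original pipeline): two positions with velocities
# (u, v) become four rows, alternating between the artificial z = 0 (x-component) and
# z = 1 (y-component) labels, with the matching velocity components stacked as a column.
toy_pos = jnp.array([[0.0, 1.0], [10.0, 11.0]])  # longitudes and latitudes of two points
toy_vel = jnp.array([[0.5, -0.5], [0.2, 0.3]])  # u and v components at those two points
toy_dataset = dataset_3d(toy_pos, toy_vel)
print(toy_dataset.X)  # [[0, 10, 0], [0, 10, 1], [1, 11, 0], [1, 11, 1]]
print(toy_dataset.y)  # [[0.5], [0.2], [-0.5], [0.3]]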


# label and place the training data into a Dataset object to be used by GPJax
dataset_train = dataset_3d(pos_train, vel_train)

# the testing data must also be relabelled so that the GP (whose Gram matrix is 2N x 2N) can be queried at the test points
dataset_ground_truth = dataset_3d(pos_test, vel_test)



@dataclass
class VelocityKernel(gpx.kernels.AbstractKernel):
    kernel0: gpx.kernels.AbstractKernel = field(
        default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1])
    )
    kernel1: gpx.kernels.AbstractKernel = field(
        default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1])
    )

    def __call__(
        self, X: Float[Array, "1 D"], Xp: Float[Array, "1 D"]
    ) -> Float[Array, "1"]:
        # standard RBF-SE kernel if x and x' are on the same output, otherwise returns 0

        z = jnp.array(X[2], dtype=int)
        zp = jnp.array(Xp[2], dtype=int)

        # achieve the correct value via 'switches' that are either 1 or 0
        k0_switch = ((z + 1) % 2) * ((zp + 1) % 2)
        k1_switch = z * zp

        return k0_switch * self.kernel0(X, Xp) + k1_switch * self.kernel1(X, Xp)
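
# Quick check of the switching logic (assuming the default unit-variance RBF sub-kernels):
# the covariance is zero between different output components and non-zero within one.
illustrative_vk = VelocityKernel()
x_component = jnp.array([0.0, 0.0, 0.0])  # x-component of the field at position (0, 0)
y_component = jnp.array([0.0, 0.0, 1.0])  # y-component of the field at position (0, 0)
print(illustrative_vk(x_component, y_component))  # 0.0
print(illustrative_vk(x_component, x_component))  # 1.0 for a unit-variance RBF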

def initialise_gp(kernel, mean, dataset):
    prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
    likelihood = gpx.likelihoods.Gaussian(
        num_datapoints=dataset.n, obs_stddev=jnp.array([1.0e-3], dtype=jnp.float64)
    )
    posterior = prior * likelihood
    return posterior


# Define the velocity GP
mean = gpx.mean_functions.Zero()
kernel = VelocityKernel()
velocity_posterior = initialise_gp(kernel, mean, dataset_train)

def optimise_mll(posterior, dataset):
    # define the MLL using dataset_train
    objective = gpx.objectives.ConjugateMLL(negative=True)
    # Optimise to minimise the MLL
    opt_posterior, history = gpx.fit_scipy(
        model=posterior,
        objective=objective,
        train_data=dataset,
    )
    return opt_posterior


opt_velocity_posterior = optimise_mll(velocity_posterior, dataset_train)

def latent_distribution(opt_posterior, pos_3d, dataset_train):
    latent = opt_posterior.predict(pos_3d, train_data=dataset_train)
    latent_mean = latent.mean()
    latent_std = latent.stddev()
    return latent_mean, latent_std


# extract latent mean and std of g, redistribute into vectors to model F
velocity_mean, velocity_std = latent_distribution(
    opt_velocity_posterior, dataset_ground_truth.X, dataset_train
)

dataset_latent_velocity = dataset_3d(pos_test, velocity_mean)

# Residuals between ground truth and estimate


def plot_vector_field(ax, dataset, **kwargs):
    ax.quiver(
        dataset.X[::2][:, 0],
        dataset.X[::2][:, 1],
        dataset.y[::2],
        dataset.y[1::2],
        **kwargs,
    )


def prepare_ax(ax, X, Y, title, **kwargs):
    ax.set(
        xlim=[X.min() - 0.1, X.max() + 0.1],
        ylim=[Y.min() - 0.1, Y.max() + 0.1],
        aspect="equal",
        title=title,
        ylabel="latitude",
        **kwargs,
    )


def residuals(dataset_latent, dataset_ground_truth):
    return jnp.sqrt(
        (dataset_latent.y[::2] - dataset_ground_truth.y[::2]) ** 2
        + (dataset_latent.y[1::2] - dataset_ground_truth.y[1::2]) ** 2
    )


def plot_fields(
    dataset_ground_truth, dataset_trajectory, dataset_latent, shape=shape, scale=10
):
    X = dataset_ground_truth.X[:, 0][::2]
    Y = dataset_ground_truth.X[:, 1][::2]
    # make figure
    fig, ax = plt.subplots(1, 3, figsize=(12.0, 3.0), sharey=True)

    # ground truth
    plot_vector_field(
        ax[0],
        dataset_ground_truth,
        color=colors[0],
        label="Ocean Current",
        angles="xy",
        scale=scale,
    )
    plot_vector_field(
        ax[0],
        dataset_trajectory,
        color=colors[1],
        label="Drifter",
        angles="xy",
        scale=scale,
    )
    prepare_ax(ax[0], X, Y, "Ground Truth", xlabel="Longitude")

    # Latent estimate of vector field F
    plot_vector_field(ax[1], dataset_latent, color=colors[0], angles="xy", scale=scale)
    plot_vector_field(
        ax[1], dataset_trajectory, color=colors[1], angles="xy", scale=scale
    )
    prepare_ax(ax[1], X, Y, "GP Estimate", xlabel="Longitude")

    # residuals
    residuals_vel = jnp.flip(
        residuals(dataset_latent, dataset_ground_truth).reshape(shape), axis=0
    )
    im = ax[2].imshow(
        residuals_vel,
        extent=[X.min(), X.max(), Y.min(), Y.max()],
        cmap="jet",
        vmin=0,
        vmax=1.0,
        interpolation="spline36",
    )
    plot_vector_field(
        ax[2], dataset_trajectory, color=colors[1], angles="xy", scale=scale
    )
    prepare_ax(ax[2], X, Y, "Residuals", xlabel="Longitude")
    fig.colorbar(im, fraction=0.027, pad=0.04, orientation="vertical")

    fig.legend(
        framealpha=0.0,
        ncols=2,
        fontsize="medium",
        bbox_to_anchor=(0.5, -0.03),
        loc="lower center",
    )
    plt.show()


plot_fields(dataset_ground_truth, dataset_train, dataset_latent_velocity)
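
# The Helmholtz decomposition writes the current as the sum of the gradient of a scalar
# potential Phi and the rotated gradient of a stream function Psi; placing independent GP
# priors on Phi and Psi and differentiating them yields the kernel below.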

@dataclass
class HelmholtzKernel(gpx.kernels.AbstractKernel):
    # initialise the Phi and Psi kernels as any stationary kernel in GPJax
    potential_kernel: gpx.kernels.AbstractKernel = field(
        default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1])
    )
    stream_kernel: gpx.kernels.AbstractKernel = field(
        default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1])
    )

    def __call__(
        self, X: Float[Array, "1 D"], Xp: Float[Array, "1 D"]
    ) -> Float[Array, "1"]:
        # obtain indices for k_helm, implement in the correct sign between the derivatives
        z = jnp.array(X[2], dtype=int)
        zp = jnp.array(Xp[2], dtype=int)
        sign = (-1) ** (z + zp)

        # convert to array to correctly index, -ve sign due to exchange symmetry (only true for stationary kernels)
        potential_dvtve = -jnp.array(
            hessian(self.potential_kernel)(X, Xp), dtype=jnp.float64
        )[z][zp]
        stream_dvtve = -jnp.array(
            hessian(self.stream_kernel)(X, Xp), dtype=jnp.float64
        )[1 - z][1 - zp]

        return potential_dvtve + sign * stream_dvtve

# Redefine Gaussian process with Helmholtz kernel
kernel = HelmholtzKernel()
helmholtz_posterior = initialise_gp(kernel, mean, dataset_train)
# Optimise hyperparameters using BFGS
opt_helmholtz_posterior = optimise_mll(helmholtz_posterior, dataset_train)

# obtain latent distribution, extract x and y values over g
helmholtz_mean, helmholtz_std = latent_distribution(
    opt_helmholtz_posterior, dataset_ground_truth.X, dataset_train
)
dataset_latent_helmholtz = dataset_3d(pos_test, helmholtz_mean)

plot_fields(dataset_ground_truth, dataset_train, dataset_latent_helmholtz)

# ensure testing data alternates between x0 and x1 components
def nlpd(mean, std, vel_test):
    vel_query = jnp.column_stack((vel_test[0], vel_test[1])).flatten()
    normal = tfp.substrates.jax.distributions.Normal(loc=mean, scale=std)
    return -jnp.sum(normal.log_prob(vel_query))


# compute nlpd for velocity and helmholtz
nlpd_vel = nlpd(velocity_mean, velocity_std, vel_test)
nlpd_helm = nlpd(helmholtz_mean, helmholtz_std, vel_test)

print("NLPD for Velocity: %.2f \nNLPD for Helmholtz: %.2f" % (nlpd_vel, nlpd_helm))

%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Ivan Shalashilin'