
Joint Inference with Numpyro

In this notebook, we demonstrate how to use Numpyro to perform fully Bayesian inference over the hyperparameters of a Gaussian process model. We will look at a scenario where we have a structured mean function in the form of a linear model, and a GP capturing the residuals. We will infer the parameters of both the linear model and the GP jointly.

from examples.utils import use_mpl_style
import gpjax as gpx
from jax import config
import jax.numpy as jnp
import jax.random as jr
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
from numpyro.infer import (
    MCMC,
    NUTS,
    Predictive,
)

config.update("jax_enable_x64", True)

use_mpl_style()
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]

key = jr.key(123)
keys = jr.split(key, 4)

Data Generation

We generate a synthetic dataset consisting of a linear trend plus a structured residual signal. The residual is made up of a locally periodic component whose amplitude varies over time, an additional high-frequency component, and a localised bump. This data-generating process is deliberately constructed to illustrate the benefit of embedding a Gaussian process in a larger Bayesian model, although structures like this are common in practice.

N = 200

x = jnp.sort(jr.uniform(keys[0], shape=(N, 1), minval=0.0, maxval=10.0), axis=0)

# True parameters for the linear trend
true_slope = 0.45
true_intercept = 1.5

# Structured residual signal captured by the GP
slow_period = 6.0
fast_period = 0.8
amplitude_envelope = 1.0 + 0.5 * jnp.sin(2 * jnp.pi * x / slow_period)
modulated_periodic = amplitude_envelope * jnp.sin(2 * jnp.pi * x / fast_period)
high_frequency_component = 0.3 * jnp.cos(2 * jnp.pi * x / 0.35)
localised_bump = 1.2 * jnp.exp(-0.5 * ((x - 7.0) / 0.45) ** 2)

linear_trend = true_slope * x + true_intercept
residual_signal = modulated_periodic + high_frequency_component + localised_bump
signal = linear_trend + residual_signal

# Observations with homoscedastic noise
observation_noise = 0.7
y = signal + observation_noise * jr.normal(keys[1], shape=x.shape)

D = gpx.Dataset(X=x, y=y)

fig, ax = plt.subplots()
ax.plot(x, y, "o", label="Observations", color=cols[0])
ax.plot(x, signal, "--", label="True Signal", color=cols[1])
ax.legend()

Model Definition

We define a GP with a zero mean function, because the linear trend is handled explicitly in the NumPyro model. One could instead give the GP a linear mean function; separating the two here is purely a pedagogical choice. For the kernel, we take the product of an RBF kernel and a periodic kernel, which in this model share a single lengthscale. This choice encodes our prior belief that the signal is locally periodic: the periodic factor captures the repeating pattern, while the RBF factor allows that pattern to change gradually over longer distances. For a more in-depth look at how complex kernels can be designed, see our Introduction to Kernels notebook.
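To make the product kernel concrete, here is a minimal pure-JAX sketch of a locally periodic kernel evaluated at pairs of scalar inputs. It uses the textbook forms of the RBF and periodic kernels and, like the model below, shares one lengthscale between them. GPJax's own RBF and Periodic classes may use a different scaling convention, so the helpers rbf, periodic, and locally_periodic are hypothetical illustrations and not the library's implementation.

```python
import jax.numpy as jnp


def rbf(x, y, lengthscale, variance):
    # Squared-exponential kernel: variance * exp(-(x - y)^2 / (2 * lengthscale^2)).
    return variance * jnp.exp(-0.5 * ((x - y) / lengthscale) ** 2)


def periodic(x, y, lengthscale, period):
    # Textbook periodic kernel with unit variance:
    # exp(-2 * sin^2(pi * |x - y| / period) / lengthscale^2).
    s = jnp.sin(jnp.pi * jnp.abs(x - y) / period)
    return jnp.exp(-2.0 * s**2 / lengthscale**2)


def locally_periodic(x, y, lengthscale, variance, period):
    # Product kernel: the periodic factor repeats every `period`, while the
    # RBF factor lets the correlation decay as points move further apart.
    return rbf(x, y, lengthscale, variance) * periodic(x, y, lengthscale, period)


lengthscale, variance, period = 2.0, 1.5, 0.8
k_same = locally_periodic(0.0, 0.0, lengthscale, variance, period)
k_one_period = locally_periodic(0.0, period, lengthscale, variance, period)
k_five_periods = locally_periodic(0.0, 5 * period, lengthscale, variance, period)
```

At zero distance the kernel equals the variance. Whole periods apart, the periodic factor returns to one, so the covariance is limited only by the RBF decay; this is what allows the repeating pattern to drift over time.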

In the model below, priors are specified on each parameter's constrained space. For example, the lengthscale must be strictly positive, so a unit Gaussian, which places half of its mass on negative values, would be a poor choice of prior. Instead, we use a log-normal prior, whose support matches that of the lengthscale. Priors are standard NumPyro distributions sampled directly inside the model function with numpyro.sample.

# Priors are defined as NumPyro distributions and sampled directly inside the model
# function below. GPJax parameter constructors accept raw JAX arrays from
# numpyro.sample, so no special registration step is needed.

We'll construct the Gaussian process inside the NumPyro model function, passing sampled hyperparameters directly to the GPJax constructors. For a deeper look at how GP construction works, see our Regression notebook.

Joint Inference Loop

We define a NumPyro model that samples all parameters with numpyro.sample, subtracts the sampled linear trend from the observations to form residuals, builds the GPJax posterior from the sampled hyperparameters, and scores the residuals with the conjugate marginal log-likelihood via numpyro.factor. No special registration step is needed, because GPJax constructors accept the raw JAX arrays returned by numpyro.sample.

def model(X, Y, X_new=None):
    # Linear trend: the structured part of the mean, inferred jointly with the GP.
    slope = numpyro.sample("slope", dist.Normal(0.0, 2.0))
    intercept = numpyro.sample("intercept", dist.Normal(0.0, 2.0))
    linear_component = slope * X + intercept

    # The GP models whatever the linear trend leaves unexplained.
    residuals = Y - linear_component

    # Log-normal priors keep the positive-valued hyperparameters positive.
    lengthscale = numpyro.sample("lengthscale", dist.LogNormal(0.0, 1.0))
    variance = numpyro.sample("variance", dist.LogNormal(0.0, 1.0))
    period = numpyro.sample("period", dist.LogNormal(0.0, 0.5))
    obs_noise = numpyro.sample("obs_noise", dist.LogNormal(0.0, 1.0))

    # Locally periodic kernel: RBF x Periodic, sharing a single lengthscale.
    stationary_component = gpx.kernels.RBF(
        lengthscale=lengthscale, variance=variance
    )
    periodic_component = gpx.kernels.Periodic(
        lengthscale=lengthscale, period=period
    )
    kernel = stationary_component * periodic_component

    # Zero mean: the trend is already captured by the linear component above.
    meanf = gpx.mean_functions.Zero()
    prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
    # Use the size of the data passed in, rather than the global N.
    likelihood = gpx.likelihoods.Gaussian(
        num_datapoints=X.shape[0], obs_stddev=obs_noise
    )
    posterior = prior * likelihood

    # Score the residuals with the conjugate marginal log-likelihood and add
    # that scalar to the model's joint log-density.
    D_resid = gpx.Dataset(X=X, y=residuals)
    mll = gpx.objectives.conjugate_mll(posterior, D_resid)
    numpyro.factor("gp_log_lik", mll)

    if X_new is not None:
        # Latent GP prediction at the new inputs, conditioned on the residuals.
        latent_dist = posterior.predict(X_new, train_data=D_resid)
        f_new = numpyro.sample("f_new", latent_dist)
        f_new = f_new.reshape((-1, 1))

        # Add Gaussian observation noise to predict new observations.
        y_noise = numpyro.sample(
            "y_noise",
            dist.Normal(0.0, obs_noise).expand(f_new.shape).to_event(f_new.ndim),
        )

        total_prediction = slope * X_new + intercept + f_new + y_noise
        numpyro.deterministic("y_pred", total_prediction)
        return total_prediction
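For a conjugate Gaussian likelihood, the quantity passed to numpyro.factor is the log marginal likelihood of the residuals, log N(r | 0, K + σ² I). The sketch below computes this directly with a Cholesky factorisation in plain JAX. It then checks that scoring the residuals under a zero-mean GP equals scoring y under a GP whose mean is the linear trend. The inputs, kernel matrix, and parameter values are arbitrary assumptions. gpx.objectives.conjugate_mll may differ in details, such as the diagonal jitter it adds for numerical stability.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve
from jax.scipy.stats import multivariate_normal


def gaussian_mll(residuals, K, obs_stddev):
    """log N(residuals | 0, K + obs_stddev^2 I), computed via a Cholesky factor."""
    n = residuals.shape[0]
    Sigma = K + obs_stddev**2 * jnp.eye(n)
    L = jnp.linalg.cholesky(Sigma)  # lower-triangular, Sigma = L @ L.T
    alpha = cho_solve((L, True), residuals)  # Sigma^{-1} @ residuals
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))  # log|Sigma|
    return -0.5 * (residuals @ alpha + log_det + n * jnp.log(2.0 * jnp.pi))


# Arbitrary illustrative inputs, kernel matrix, and parameters.
x_demo = jnp.linspace(0.0, 10.0, 20)
K_demo = 1.5 * jnp.exp(-0.5 * (x_demo[:, None] - x_demo[None, :]) ** 2 / 2.0**2)
slope, intercept, obs_stddev = 0.45, 1.5, 0.7
trend = slope * x_demo + intercept
y_demo = jnp.sin(x_demo) + trend

# Scoring residuals under a zero-mean GP, as the model above does ...
mll_residual = gaussian_mll(y_demo - trend, K_demo, obs_stddev)
# ... equals scoring y under a GP whose mean function is the linear trend.
mll_linear_mean = multivariate_normal.logpdf(
    y_demo, trend, K_demo + obs_stddev**2 * jnp.eye(x_demo.shape[0])
)
```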

Running MCMC

Using NumPyro's NUTS sampler, we can now draw samples from the posterior. To keep the documentation quick to build, we limit the number of samples and the length of the warm-up (burn-in) phase below. In practice, one should draw more samples and run several chains using the num_chains argument of the MCMC constructor.

nuts_kernel = NUTS(model)
# In practice, one should run more samples from multiple chains.
mcmc = MCMC(
    nuts_kernel,
    num_warmup=500,
    num_samples=500,
)
mcmc.run(keys[2], x, y)
mcmc.print_summary()
                   mean       std    median      5.0%     95.0%     n_eff     r_hat
    intercept      0.58      1.50      0.59     -1.71      3.11    229.89      1.00
  lengthscale      2.46      0.79      2.27      1.37      3.74    199.46      1.00
    obs_noise      0.79      0.05      0.79      0.71      0.86    425.82      1.00
       period      0.80      0.02      0.80      0.77      0.83    371.64      1.00
        slope      0.51      0.31      0.52     -0.03      0.97    224.30      1.00
     variance      7.86      6.07      6.21      1.26     15.67    223.51      1.00

Number of divergences: 0

Analysis and Plotting

Having obtained samples from the posterior, we now evaluate the posterior predictive distribution at the test inputs. In our Poisson Regression notebook, this step is carried out manually. Because we are using NumPyro here, we can instead let NumPyro's Predictive object handle it. Once samples are drawn from the posterior predictive distribution, we compute their mean and a 95% interval, then compare the model's predictions to the underlying data. Because y_pred adds observation noise (the y_noise site) to each latent draw, this band is a 95% posterior predictive interval for new observations rather than a credible interval for the latent function alone.

samples = mcmc.get_samples()
predictive = Predictive(
    model,
    posterior_samples=samples,
    return_sites=["y_pred"],
)

x_test = jnp.linspace(-0.5, 10.5, 200).reshape(-1, 1)
predictions = predictive(keys[3], x, y, X_new=x_test)
y_pred = predictions["y_pred"]

mean_prediction = jnp.mean(y_pred, axis=0)
lower, upper = jnp.percentile(y_pred, jnp.array([2.5, 97.5]), axis=0)

fig, ax = plt.subplots()
ax.scatter(x, y, alpha=0.5, label="Observations", color=cols[0])
ax.plot(x, signal, "--", label="True Signal", color=cols[2])

ax.plot(x_test, mean_prediction, "-", label="Posterior Mean", color=cols[1])
ax.fill_between(
    x_test.flatten(),
    lower.flatten(),
    upper.flatten(),
    color=cols[1],
    alpha=0.2,
    label="95% Predictive Interval",
)
ax.legend()

Conclusions

This concludes our introduction to the integration of GPJax with Numpyro. The presentation given here is designed to best illustrate how the two libraries integrate. For a closer look at the more complex models that one may build by integrating Numpyro and GPJax, see our Spatial Semi-Linear Model notebook.

System configuration

%load_ext watermark
%watermark -n -u -v -iv -w -a "Thomas Pinder"
Author: Thomas Pinder

Last updated: Fri, 01 May 2026

Python implementation: CPython
Python version       : 3.11.15
IPython version      : 9.9.0

gpjax     : 0.14.0
jax       : 0.9.0
matplotlib: 3.10.8
numpyro   : 0.19.0

Watermark: 2.6.0
