Welcome to GPJax!

GPJax is a didactic Gaussian process (GP) library in JAX, supporting GPU acceleration and just-in-time compilation. We seek to provide a flexible API to enable researchers to rapidly prototype and develop new ideas.

"Hello, GP!"

Writing a GP model in GPJax closely mirrors the maths we would write on paper, as shown below.

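import gpjax as gpx

# Zero-mean GP prior with an RBF (squared exponential) kernel.
mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.RBF()
prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)

# Gaussian observation model; num_datapoints is the size of the training set.
likelihood = gpx.likelihoods.Gaussian(num_datapoints=123)

# Posterior is proportional to prior times likelihood.
posterior = prior * likelihood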

\begin{align}
k(\cdot, \cdot') & = \sigma^2\exp\left(-\frac{\lVert \cdot - \cdot'\rVert_2^2}{2\ell^2}\right) \\
p(f(\cdot)) & = \mathcal{GP}(\mathbf{0}, k(\cdot, \cdot')) \\
p(y\,|\, f(\cdot)) & = \mathcal{N}(y\,|\, f(\cdot), \sigma_n^2) \\
p(f(\cdot) \,|\, y) & \propto p(f(\cdot))\,p(y\,|\, f(\cdot))\,.
\end{align}
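To make the kernel formula concrete, here is a minimal sketch that evaluates it by hand. It uses only the Python standard library so that it runs without JAX or GPJax installed. The names `rbf_kernel`, `gram_matrix`, `variance` and `lengthscale` are illustrative assumptions and are not GPJax's API; `variance` plays the role of σ² and `lengthscale` the role of ℓ.

```python
import math


def rbf_kernel(x, x_prime, variance=1.0, lengthscale=1.0):
    """Evaluate k(x, x') = variance * exp(-||x - x'||^2 / (2 * lengthscale^2)).

    Hypothetical helper for illustration only; not part of GPJax.
    """
    # Squared Euclidean distance between the two input points.
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, x_prime))
    return variance * math.exp(-sq_dist / (2.0 * lengthscale ** 2))


def gram_matrix(points, variance=1.0, lengthscale=1.0):
    """Build the matrix of pairwise kernel values for a list of points."""
    return [
        [rbf_kernel(p, q, variance, lengthscale) for q in points]
        for p in points
    ]


# Three one-dimensional inputs, spaced one unit apart.
points = [(0.0,), (1.0,), (2.0,)]
K = gram_matrix(points)

# The diagonal equals the variance; entries decay with squared distance.
for row in K:
    print([round(value, 4) for value in row])
```

With unit variance and lengthscale, the diagonal entries are 1 and the entry for two points one unit apart is exp(-0.5). The matrix is symmetric, as a covariance matrix must be.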

Quick start


GPJax can be installed via pip. See our installation guide for further details.

pip install gpjax


New to GPs? Then why not check out our introductory notebook that starts from Bayes' theorem and univariate Gaussian distributions.


Looking for a good place to start? Then why not begin with our regression notebook.

Citing GPJax

If you use GPJax in your research, please cite our JOSS paper.

