
Welcome to GPJax

GPJax is a didactic Gaussian process (GP) library in JAX, supporting GPU acceleration and just-in-time compilation. We seek to provide a flexible API to enable researchers to rapidly prototype and develop new ideas.

[Figure: Gaussian process posterior.]

"Hello, GP!"

Defining a GP model in GPJax is as simple as writing down the maths on paper, as shown below.

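import gpjax as gpx

# Zero-mean GP prior with an RBF (squared exponential) kernel
mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.RBF()
prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
# Gaussian likelihood for a dataset of 123 observations
likelihood = gpx.likelihoods.Gaussian(num_datapoints=123)

# Multiplying the prior by the likelihood yields the posterior
posterior = prior * likelihood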

\begin{align}
k(\cdot, \cdot') & = \sigma^2 \exp\left(-\frac{\lVert \cdot - \cdot' \rVert_2^2}{2\ell^2}\right) \\
p(f(\cdot)) & = \mathcal{GP}(\mathbf{0}, k(\cdot, \cdot')) \\
p(y \,|\, f(\cdot)) & = \mathcal{N}(y \,|\, f(\cdot), \sigma_n^2) \\
p(f(\cdot) \,|\, y) & \propto p(f(\cdot))\, p(y \,|\, f(\cdot))\,.
\end{align}
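Because the likelihood is Gaussian, this posterior is available in closed form. At test inputs it has mean $K_{*x}(K_{xx} + \sigma_n^2 I)^{-1}\mathbf{y}$ and covariance $K_{**} - K_{*x}(K_{xx} + \sigma_n^2 I)^{-1}K_{x*}$. The following is a minimal, framework-free sketch of those equations. It is not GPJax's implementation: it uses plain NumPy in place of JAX so that it runs without GPJax installed. The function names `rbf_kernel` and `gp_posterior_predict` and their default hyperparameters are hypothetical, chosen for illustration only.

```python
import numpy as np


def rbf_kernel(x1, x2, variance=1.0, lengthscale=1.0):
    """Hypothetical helper: k(x, x') = sigma^2 * exp(-||x - x'||_2^2 / (2 * ell^2))
    evaluated for every pair of rows of x1 (n, d) and x2 (m, d)."""
    # Pairwise squared Euclidean distances, shape (n, m).
    sq_dists = np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return variance * np.exp(-0.5 * sq_dists / lengthscale**2)


def gp_posterior_predict(x_train, y_train, x_test, obs_noise=0.1, **kernel_params):
    """Hypothetical helper: mean and marginal variance of p(f(x_test) | y) for a
    zero-mean GP prior and a Gaussian likelihood with noise variance sigma_n^2 = obs_noise."""
    k_xx = rbf_kernel(x_train, x_train, **kernel_params)
    k_xt = rbf_kernel(x_train, x_test, **kernel_params)
    # Only the diagonal of K_** is needed for marginal variances.
    k_tt_diag = np.diag(rbf_kernel(x_test, x_test, **kernel_params))

    # Lower Cholesky factor L of K_xx + sigma_n^2 I, so that L @ L.T equals that matrix.
    chol = np.linalg.cholesky(k_xx + obs_noise * np.eye(len(x_train)))

    # alpha = (K_xx + sigma_n^2 I)^{-1} y via two linear solves with L and L.T.
    alpha = np.linalg.solve(chol.T, np.linalg.solve(chol, y_train))
    mean = k_xt.T @ alpha

    # v = L^{-1} K_x*, so that sum(v**2) is diag(K_*x (K_xx + sigma_n^2 I)^{-1} K_x*).
    v = np.linalg.solve(chol, k_xt)
    var = k_tt_diag - np.sum(v**2, axis=0)
    return mean, var
```

Solving with the Cholesky factor rather than forming an explicit inverse is the usual choice for numerical stability. For example, `gp_posterior_predict(x, y, x_new, obs_noise=0.01, lengthscale=0.2)` returns the posterior mean and variance at the rows of `x_new`.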

We currently have some availability to consult on integrating Gaussian processes, Bayesian modelling, and GPJax into your team's work. If this sounds relevant, book an introductory call. These calls are for consulting inquiries only; for technical usage questions and free community support, please use GitHub Discussions and the documentation below.

Citing GPJax

If you use GPJax in your research, please cite our JOSS paper.

@article{Pinder2022,
  doi = {10.21105/joss.04455},
  url = {https://doi.org/10.21105/joss.04455},
  year = {2022},
  publisher = {The Open Journal},
  volume = {7},
  number = {75},
  pages = {4455},
  author = {Thomas Pinder and Daniel Dodd},
  title = {GPJax: A Gaussian Process Framework in JAX},
  journal = {Journal of Open Source Software}
}