# Migration guide: 0.13.x → 0.14.0
GPJax 0.14 replaces the Flax NNX backend with Equinox + paramax, and introduces a linear-algebra layer via Lineax. It also removes the custom bijector stack in favour of numpyro constraints.
The changes are mostly internal. They surface in three places:
- How you define custom modules and custom parameter classes.
- How you read a parameter value (`param.unwrap()` / `paramax.unwrap(model)` instead of `param.value`).
- How you freeze parameters (`paramax.non_trainable(...)` instead of the `trainable=` filter argument on `fit`).
If you only use the high-level API (`gpx.Prior`, `gpx.Posterior`, `gpx.fit`, etc.), most code keeps working once you update the two or three call sites below.
## Installation
New dependencies (pulled in automatically): `equinox>=0.11`, `paramax>=0.0.5`. Flax is no longer a runtime dependency.
## Breaking changes
### 1. Backend: `flax.nnx.Module` → `equinox.Module`
If you subclassed `nnx.Module` to build a custom model, kernel, mean function, likelihood, or variational family, change the base class:
```python
# Before (0.13.x)
from flax import nnx

class MyKernel(nnx.Module):
    def __init__(self, lengthscale):
        self.lengthscale = gpx.parameters.PositiveReal(lengthscale)

# After (0.14.0)
import equinox as eqx

class MyKernel(eqx.Module):
    lengthscale: gpx.parameters.PositiveReal

    def __init__(self, lengthscale):
        self.lengthscale = gpx.parameters.PositiveReal(lengthscale)
```
Equinox requires class-level field annotations for every attribute, and static configuration fields should be marked with `eqx.field(static=True)`.
### 2. Parameter classes are now `paramax.AbstractUnwrappable`
`PositiveReal`, `NonNegativeReal`, `Real`, `SigmoidBounded`, and `LowerTriangular` all live in `gpjax.parameters` with the same names, but they now inherit from `paramax.AbstractUnwrappable` and store their value in an unconstrained internal field. The constraining bijection is applied at read time through `unwrap()`.
```python
# Before (0.13.x): nnx.Variable-based
length = gpx.parameters.PositiveReal(0.5)
length.value        # -> 0.5
length.value = 1.0  # in-place mutation (nnx)

# After (0.14.0): paramax.AbstractUnwrappable
length = gpx.parameters.PositiveReal(0.5)
length.unwrap()  # -> 0.5 (applies softplus to the stored unconstrained value)

# Unwrap an entire model tree in one call:
import paramax
model_resolved = paramax.unwrap(model)
```
`Parameter` (the old generic base class), `DEFAULT_BIJECTION`, the `transform(...)` function, and `FillTriangularTransform` have been removed. `numpyro.distributions.biject_to` now handles every constraint → bijection mapping.
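The store-unconstrained / transform-at-read pattern is easy to see without any of the libraries. A plain-NumPy sketch (my helper names, not the GPJax implementation) of what `PositiveReal` does with its softplus bijection:

```python
import numpy as np

def softplus(x):
    # R -> (0, inf): the constraining bijection
    return np.log1p(np.exp(x))

def inv_softplus(y):
    # (0, inf) -> R: applied once, at construction time
    return np.log(np.expm1(y))

# Construction stores the unconstrained value...
unconstrained = inv_softplus(0.5)

# ...and unwrap() applies the bijection at read time.
value = softplus(unconstrained)
print(np.isclose(value, 0.5))  # -> True
```

Because the stored value lives in unconstrained space, gradient-based optimisers can update it freely without ever producing an invalid (non-positive) parameter.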
### 3. `LowerTriangular` now requires a valid Cholesky factor
Previously `LowerTriangular` accepted any lower-triangular matrix (the diagonal was unconstrained). It is now parameterised via `numpyro.distributions.constraints.softplus_lower_cholesky`, so the diagonal must be strictly positive.
- Passing a matrix with zero or negative diagonal entries produces non-finite values during construction (`inv_softplus` of a non-positive number).
- In-library usage is unaffected: the only consumer is `VariationalGaussian.variational_root_covariance`, which is initialised to the identity by default.
- If you previously supplied a custom `variational_root_covariance`, ensure its diagonal is strictly positive. Under the old parameterisation, zero or negative diagonals produced singular or sign-ambiguous variational covariances.
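To see why a non-positive diagonal fails, note that the inverse softplus is undefined there. A plain-NumPy sketch (my helper, not a GPJax function):

```python
import numpy as np

def inv_softplus(y):
    # log(expm1(y)): NaN for y < 0, -inf at y == 0
    with np.errstate(invalid="ignore", divide="ignore"):
        return np.log(np.expm1(y))

L_bad = np.array([[-0.5, 0.0],
                  [ 0.3, 1.0]])  # negative diagonal entry
print(np.isnan(inv_softplus(np.diagonal(L_bad))).any())  # -> True

L_ok = np.eye(2)                 # strictly positive diagonal: fine
print(np.isfinite(inv_softplus(np.diagonal(L_ok))).all())  # -> True
```

A cheap pre-flight check before constructing `LowerTriangular` from a custom matrix is therefore `(np.diagonal(L) > 0).all()`.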
### 4. `gpx.fit` / `fit_scipy` / `fit_lbfgs`: `params_bijection` and `trainable` removed
Bijection handling is now automatic via `paramax.unwrap` inside the loss function, and freezing parameters is expressed by wrapping them in `paramax.non_trainable`:
```python
# Before (0.13.x)
opt_model, history = gpx.fit(
    model=posterior,
    objective=gpx.objectives.conjugate_mll,
    train_data=D,
    optim=ox.adam(1e-2),
    params_bijection=gpx.parameters.DEFAULT_BIJECTION,
    trainable=gpx.parameters.Parameter,  # filter-based trainability
)

# After (0.14.0)
import paramax

# Freeze specific parameters up-front by wrapping them:
posterior = eqx.tree_at(
    lambda m: m.prior.kernel.lengthscale,
    posterior,
    replace_fn=paramax.non_trainable,
)

opt_model, history = gpx.fit(
    model=posterior,
    objective=gpx.objectives.conjugate_mll,
    train_data=D,
    optim=ox.adam(1e-2),
)
```
Internally, `fit` now splits the model with `eqx.partition(model, eqx.is_array)` so only concrete JAX arrays participate in the gradient update; everything wrapped in `paramax.non_trainable` is held constant.
### 5. `register_parameters` removed
The `gpx.parameters.register_parameters` decorator (added in 0.13.x to mark NNX variables as GPJax parameters) is gone. With Equinox, GPJax identifies parameter classes through `isinstance` checks on `AbstractUnwrappable`, so registration is unnecessary.
### 6. `gpjax.linalg` rewrite: cola → Lineax
Kernel `gram()` now returns a `lineax.AbstractLinearOperator` (typically `lineax.MatrixLinearOperator`) instead of a `cola.LinearOperator`. Materialise with `.as_matrix()`.
The following names have been removed from `gpjax.linalg`:

| Removed | Replacement |
|---|---|
| `PSD`, `psd` | Not needed: Lineax operators carry tags directly. |
| `Dense`, `Diagonal`, `Identity`, `Triangular` | `lineax.MatrixLinearOperator`, `lineax.DiagonalLinearOperator`, `lineax.IdentityLinearOperator`, `lineax.TriangularLinearOperator` |
| `LinearOperator` | `lineax.AbstractLinearOperator` |
| `diag`, `solve` | `operator.diagonal()`, `lineax.linear_solve(...)` |
| `lower_cholesky` | `gpjax.linalg.cholesky_factor` (singledispatch, returns a lower-triangular operator) |
`BlockDiag`, `Kronecker`, and `logdet` are unchanged. Use `gpjax.linalg.add_jitter` to add a jitter term to a covariance operator.
### 7. Custom bijectors replaced with numpyro constraints
If you had a custom `Parameter` subclass that declared a bijection, replace the bijection with a numpyro constraint and use `biject_to`:
```python
# Before (0.13.x)
class MyParam(gpx.parameters.Parameter):
    # Custom bijection registered via DEFAULT_BIJECTION
    ...

# After (0.14.0)
from numpyro.distributions import biject_to, constraints
from paramax import AbstractUnwrappable
import jax

class MyParam(AbstractUnwrappable):
    _constraint = constraints.positive
    _unconstrained: jax.Array

    def __init__(self, value):
        self._unconstrained = biject_to(self._constraint).inv(value)

    def unwrap(self):
        return biject_to(self._constraint)(self._unconstrained)
```
## Non-breaking cleanup
- `__description__` changed from `"Gaussian processes in JAX and Flax"` to `"Gaussian processes in JAX"`, since Flax is no longer a dependency.
- Many kernel `compute_engine` internals moved; the public `kernel(x, y)`, `kernel.gram(x)`, `kernel.cross_covariance(x, y)`, and `kernel.diagonal(x)` methods are unchanged.
## Upgrade checklist
- [ ] Replace `nnx.Module` base classes with `eqx.Module`, and add class-level type annotations for every field.
- [ ] Replace `param.value` reads with `param.unwrap()`, or call `paramax.unwrap(model)` once at the top of your loss / prediction function.
- [ ] Drop any `params_bijection=` / `trainable=` arguments passed to `gpx.fit`. To freeze parameters, wrap them with `paramax.non_trainable` using `eqx.tree_at`.
- [ ] Remove any `gpx.parameters.register_parameters` decorator calls.
- [ ] If you construct `LowerTriangular` from a custom matrix, verify the diagonal is strictly positive.
- [ ] If you used `gpjax.linalg` operators directly, switch to the Lineax equivalents listed above.
## Reporting issues
This is a pre-release (0.14.0rc1). Please file migration issues at https://github.com/thomaspinder/GPJax/issues with the `0.14-migration` label so they can be triaged before the stable 0.14.0 release.