Since v0.9, GPJax is built upon Flax's
NNX module. This transition
allows for more efficient parameter handling, improved integration with Flax and
Flax-based libraries, and enhanced flexibility in model design. This notebook provides
a high-level overview of the backend module design in GPJax. For an introduction to
NNX, please refer to the official
documentation.
import typing as tp

from flax import nnx
from jax import (
    config,
    grad,
)
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import (
    Float,
    Num,
    install_import_hook,
)
import matplotlib as mpl
import matplotlib.pyplot as plt

from examples.utils import use_mpl_style
from gpjax.mean_functions import (
    AbstractMeanFunction,
    Constant,
)
from gpjax.parameters import (
    DEFAULT_BIJECTION,
    Parameter,
    PositiveReal,
    Real,
    transform,
)
from gpjax.typing import (
    Array,
    ScalarFloat,
)

# Enable Float64 for more stable matrix inversions.
config.update("jax_enable_x64", True)

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx

# set the default style for plotting
use_mpl_style()

cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
Parameters
The biggest change brought about by the transition to an NNX backend is the increased
support we now provide for handling parameters. As discussed in our Sharp Bits -
Bijectors Doc, GPJax
uses bijectors to transform constrained parameters to unconstrained parameters during
optimisation. You may now register the support of a parameter using our Parameter
class. To see this, consider the constant mean function, which contains a single constant
parameter whose value ordinarily lies on the real line. We can register this
parameter as follows:
However, suppose you wish your mean function's constant parameter to be strictly
positive. This is easy to achieve by using the correct Parameter type which, in this
case, will be PositiveReal. Because PositiveReal subclasses Parameter, it will be
transformed by GPJax just like any other Parameter.
issubclass(PositiveReal, Parameter)
True
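Injecting this newly constrained parameter into our mean function is then identical to before.
The constraint is also validated. Attempting to create a PositiveReal with a negative value,
such as -1.0, fails with the following error:
value needs to be positive, got -1.0 (`check` failed)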
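The format of this message appears to match a failed `jax.experimental.checkify` check. Below is a minimal sketch of such a positivity check. It assumes, without confirmation from this notebook, that GPJax validates values in a similar way. `make_positive` is a hypothetical helper.

```python
import jax.numpy as jnp
from jax.experimental import checkify


def _check_positive(value):
    # Record a functional error if any entry is non-positive.
    checkify.check(
        jnp.all(value > 0), "value needs to be positive, got {value}", value=value
    )
    return value


# checkify.checkify turns the check into an explicit (error, output) pair.
checked_positive = checkify.checkify(_check_positive)


def make_positive(value):
    """Hypothetical helper: validate eagerly and raise if the check failed."""
    err, out = checked_positive(jnp.asarray(value))
    err.throw()  # raises only when the positivity check failed
    return out


print(make_positive(1.0))
```

Using checkify rather than a plain Python `if` keeps the check expressible inside traced code. A bare `if` on an array value cannot be evaluated under `jit`.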
Parameter Transforms
With a parameter instantiated, you likely wish to transform the parameter's value from
its constrained support onto the entire real line. To do this, you can apply the
transform function to the parameter. To control the bijector used to transform the
parameter, you may pass a set of bijectors into the transform function.
Under the hood, the transform function looks up the bijector of each parameter
using its tag as a key into the bijector dictionary, and then applies that bijector to
the parameter's value using a tree map operation.
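This lookup-and-map pattern can be sketched in a few lines of JAX. `TOY_BIJECTION` and `toy_transform` are hypothetical simplifications. They pair each tag with a (forward, inverse) function pair and keep the tags in a parallel pytree. GPJax, by contrast, stores the tag on each Parameter and uses numpyro transforms.

```python
import jax
import jax.numpy as jnp


def inverse_softplus(y):
    # Inverse of softplus(x) = log(1 + exp(x)), valid for y > 0.
    return jnp.log(jnp.expm1(y))


def identity(x):
    return x


# Hypothetical bijection table: tag -> (forward, inverse).
TOY_BIJECTION = {
    "positive": (jax.nn.softplus, inverse_softplus),
    "real": (identity, identity),
}


def toy_transform(values, tags, bijection, inverse=False):
    """Look up each leaf's bijector by its tag and apply it with a tree map."""

    def _apply(tag, value):
        forward, backward = bijection[tag]
        return backward(value) if inverse else forward(value)

    # tags and values share one tree structure; the tag strings are leaves.
    return jax.tree_util.tree_map(_apply, tags, values)


values = {"constant": jnp.array(1.0), "shift": jnp.array(-0.5)}
tags = {"constant": "positive", "shift": "real"}

# Constrained -> unconstrained, then back again.
unconstrained = toy_transform(values, tags, TOY_BIJECTION, inverse=True)
restored = toy_transform(unconstrained, tags, TOY_BIJECTION)
print(unconstrained, restored)
```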
print(constant_param.tag)
positive
For most users, you will not need to worry about this as we provide a set of default
bijectors that are defined for all the parameter types we support. However, our
Kernel Guide notebook shows how you can define your own bijectors and parameter types.
print(DEFAULT_BIJECTION[constant_param.tag])
<numpyro.distributions.transforms.SoftplusTransform object at 0x7fe8e5e69590>
We see here that the Softplus bijector is specified as the default for strictly
positive parameters. To apply this, we must first realise the state of our model.
This is achieved using the split function provided by nnx.
The parameter's value was changed here from 1. to 0.54132485. This is the result of
applying the inverse of the Softplus bijector to the parameter's value, which projects it
from the positive reals onto the whole real line. Were the parameter's value closer to 0,
the transformation would be more pronounced.
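The size of this effect can be checked numerically. The sketch below evaluates the inverse softplus, written here as `log(expm1(y))`, which is assumed to match the inverse of the default positive bijector. It recovers 0.54132485 at 1.0 and shows the change growing as the value approaches 0.

```python
import jax.numpy as jnp


def inverse_softplus(y):
    # Maps (0, inf) onto the whole real line.
    return jnp.log(jnp.expm1(y))


constrained = jnp.array([10.0, 1.0, 0.1, 0.01])
unconstrained = inverse_softplus(constrained)
shift = jnp.abs(unconstrained - constrained)

for c, u, s in zip(constrained, unconstrained, shift):
    print(f"{float(c):>6.2f} -> {float(u):>8.4f}  (|change| = {float(s):.4f})")
```

For large values the inverse softplus is nearly the identity. Near 0 it diverges towards negative infinity, which is why small positive values move the most.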
In the above, we transformed a single parameter. However, in practice your parameters
may be nested within several functions e.g., a kernel function within a GP model.
Fortunately, transforming several parameters is a simple operation that we here
demonstrate for a conjugate GP posterior (see our Regression
notebook for a detailed explanation of this model).
Now contained within the posterior PyGraph here there are four parameters: the
kernel's lengthscale and variance, the noise variance of the likelihood, and the
constant of the mean function. Using NNX, we may realise these parameters through the
nnx.split function. The split function decomposes a PyGraph into a GraphDef and
State object. As the name suggests, State contains information on the parameters'
state, whilst GraphDef contains the information required to reconstruct a PyGraph
from a given State.
The State object behaves just like a PyTree and, consequently, we may use JAX's
tree_map function to alter the values of the State. The updated State can then
be used to reconstruct our posterior. Below, we simply increment each
parameter's value by 1.
However, we began this discussion with bijectors in mind, so let us now see
how bijectors may be applied to a collection of parameters in GPJax. Fortunately, this
is very straightforward, and we may simply use the transform function as before.
One of the advantages of being able to split and re-merge the PyGraph is that we
gain fine-grained control over which parameters' state we wish to realise.
This is by virtue of the fact that each of our parameters now inherits from
gpjax.parameters.Parameter. Previously, we were simply extracting every
Parameter subclass from the posterior. However, suppose we only wish to extract those
parameters whose support is the positive real line. This is easily achieved by
altering the way in which we invoke nnx.split.
Now we see that we have two state objects: one containing the positive real parameters
and the other containing the remaining parameters. This functionality is exceptionally
useful as it allows us to efficiently operate on a subset of the parameters whilst
leaving the others untouched. Looking forward, we hope to use this functionality in
our Variational Inference
Approximations to
perform more efficient updates of the variational parameters and then the model's
hyperparameters.
NNX Modules
To conclude this notebook, we will now demonstrate the ease of use and flexibility
offered by NNX modules. To do this, we will implement a linear mean function using the
existing abstractions in GPJax.
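For inputs $x_n \in \mathbb{R}^d$, the linear mean function $m(x): \mathbb{R}^d \to \mathbb{R}$ is defined as:
$$m(x) = \alpha + \sum_{i=1}^{d} \beta_i x_i$$
where $\alpha \in \mathbb{R}$ and $\beta_i \in \mathbb{R}$ are the parameters of the
mean function. Let's now implement that using the new NNX backend.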
class LinearMeanFunction(AbstractMeanFunction):
    def __init__(
        self,
        intercept: tp.Union[ScalarFloat, Float[Array, " O"], Parameter] = 0.0,
        slope: tp.Union[ScalarFloat, Float[Array, " D O"], Parameter] = 0.0,
    ):
        # Accept a ready-made Parameter, otherwise register the value as Real.
        if isinstance(intercept, Parameter):
            self.intercept = intercept
        else:
            self.intercept = Real(jnp.array(intercept))

        if isinstance(slope, Parameter):
            self.slope = slope
        else:
            self.slope = Real(jnp.array(slope))

    def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]:
        return self.intercept[...] + jnp.dot(x, self.slope[...])
As we can see, the implementation is straightforward and concise. The
AbstractMeanFunction module is a subclass of nnx.Module and may, therefore, be
used in any split or merge call. Further, we have registered the intercept and
slope parameters as Real parameter types. This registers their value in the PyGraph
and means that they will be part of any operation applied to the PyGraph, e.g.,
transformation and differentiation.
To check our implementation worked, let's now plot the value of our mean function for
a linearly spaced set of inputs.
We'll compute derivatives of the conjugate marginal log-likelihood, with respect to
the unconstrained state of the kernel, mean function, and likelihood parameters.
In practice, you would wish to perform multiple iterations of gradient descent to
learn the optimal parameter values. However, for the purposes of illustration, we use
another tree_map below to update the parameters' state using their previously
computed gradients. As you can see, the real beauty of having access to the model's
state is that we have full control over the operations that we perform on the state.
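A gradient step on an unconstrained state can be sketched with plain JAX pytrees. The loss below is a hypothetical least-squares objective standing in for the conjugate marginal log-likelihood. The positive scale is optimised on the unconstrained scale and mapped back through softplus inside the loss, mirroring the bijector approach described earlier.

```python
import jax
import jax.numpy as jnp


def loss(unconstrained):
    """Hypothetical objective: fit a positive scale and a real shift to targets."""
    scale = jax.nn.softplus(unconstrained["scale"])  # map back to (0, inf)
    shift = unconstrained["shift"]
    prediction = scale * jnp.array([1.0, 2.0, 3.0]) + shift
    return jnp.sum((prediction - jnp.array([2.5, 4.5, 6.5])) ** 2)


params = {"scale": jnp.array(0.0), "shift": jnp.array(0.0)}
grads = jax.grad(loss)(params)

# One step of gradient descent, written as a tree map over (params, grads).
learning_rate = 0.01
updated = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
print(loss(params), loss(updated))
```

Because the update is just a tree map, the same line works unchanged for any nesting of parameters, including an NNX State.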
Now we will plot the updated mean function alongside its initial form. To achieve
this, we first merge the state back into the model using merge, and we then simply
invoke the model as normal.
optimised_posterior = nnx.merge(graphdef, optimised_params, *others)

fig, ax = plt.subplots()
ax.plot(X, optimised_posterior.prior.mean_function(X), label="Updated mean function")
ax.plot(X, meanf(X), label="Initial mean function")
ax.legend()
ax.set(xlabel="x", ylabel="m(x)")
Conclusions
In this notebook we have explored how GPJax's Flax-based backend may be easily
manipulated and extended. For a more applied look at this, see how we construct a
kernel on polar coordinates in our Kernel
Guide
notebook.