In this notebook we demonstrate how GPJax can be used in conjunction with
Flax to build deep kernel Gaussian
processes. Modelling data with discontinuities is a challenging task for regular
Gaussian process models. However, as shown in Wilson et al. (2016), transforming the
inputs to our Gaussian process model's kernel through a neural network can offer a
solution to this.
from dataclasses import (
    dataclass,
    field,
)

from flax import nnx
import jax

# Enable Float64 for more stable matrix inversions.
from jax import config
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
    Array,
    Float,
    install_import_hook,
)
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox
from scipy.signal import sawtooth

from examples.utils import use_mpl_style
from gpjax.kernels.computations import (
    AbstractKernelComputation,
    DenseKernelComputation,
)

config.update("jax_enable_x64", True)

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
    from gpjax.kernels.base import AbstractKernel
    from gpjax.parameters import Parameter

# set the default style for plotting
use_mpl_style()

cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]

key = jr.key(42)
Dataset
As previously mentioned, deep kernels are particularly useful when the data has
discontinuities. To highlight this, we will use a sawtooth function as our data.
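As a concrete sketch, one way to generate such a dataset is shown below; the sample size n, the noise level, and the test grid xtest are illustrative choices rather than prescribed values.

# Sketch: noisy observations of a sawtooth signal (illustrative settings).
n = 500
noise = 0.2

key, subkey = jr.split(key)
x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(n, 1))
f = lambda x: jnp.asarray(sawtooth(2 * jnp.pi * x))
signal = f(x)
y = signal + jr.normal(subkey, shape=signal.shape) * noise

D = gpx.Dataset(X=x, y=y)

xtest = jnp.linspace(-2.0, 2.0, 500).reshape(-1, 1)
ytest = f(xtest)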
Instead of applying a kernel $k(\cdot, \cdot')$ directly on some data, we seek to
apply a feature map $\phi(\cdot)$ that projects the data to learn more meaningful
representations beforehand. In deep kernel learning, $\phi$ is a neural network
whose parameters are learned jointly with the GP model's hyperparameters. The
corresponding kernel is then computed by $k(\phi(\cdot), \phi(\cdot'))$. Here
$k(\cdot, \cdot')$ is referred to as the base kernel.
Implementation
Although deep kernels are not currently supported natively in GPJax, defining one is
straightforward as we now demonstrate. Inheriting from the base AbstractKernel
in GPJax, we create the DeepKernelFunction object that allows the
user to supply the neural network and base kernel of their choice. Kernel matrices
are then computed using the regular gram and cross_covariance functions.
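A minimal sketch of such a kernel is given below. It assumes that AbstractKernel subclasses are evaluated pairwise through __call__ and that the base class accepts a compute_engine keyword argument; the exact base-class signature may differ between GPJax versions.

# Sketch of a deep kernel: inputs are passed through a neural network
# before the base kernel is evaluated on the resulting features.
class DeepKernelFunction(AbstractKernel):
    def __init__(
        self,
        base_kernel: AbstractKernel,
        network: nnx.Module,
        compute_engine: AbstractKernelComputation = DenseKernelComputation(),
    ):
        super().__init__(compute_engine=compute_engine)
        self.base_kernel = base_kernel
        self.network = network

    def __call__(
        self, x: Float[Array, " D"], y: Float[Array, " D"]
    ) -> Float[Array, ""]:
        xt = self.network(x)  # phi(x)
        yt = self.network(y)  # phi(x')
        return self.base_kernel(xt, yt)  # k(phi(x), phi(x'))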
With a deep kernel object created, we proceed to define a neural network. Here we
consider a small multi-layer perceptron with two linear hidden layers and ReLU
activation functions between the layers. The first hidden layer contains 64 units,
while the second layer contains 32 units. Finally, we make the output of our
network three units wide. The corresponding kernel that we define will then be of
ARD form
to allow for a different lengthscale in each dimension of the feature space.
Users may wish to design more intricate network structures for more complex tasks,
functionality for which is well supported in Flax.
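A sketch of such a network and the corresponding three-dimensional ARD base kernel follows. The Network class and feature_space_dim are our own names, and we assume one-dimensional inputs and that gpx.kernels.RBF accepts a vector-valued lengthscale for ARD.

feature_space_dim = 3  # width of the network output, i.e. the kernel's input space


class Network(nnx.Module):
    def __init__(self, rngs: nnx.Rngs, *, input_dim: int, feature_dim: int):
        self.layer1 = nnx.Linear(input_dim, 64, rngs=rngs)
        self.layer2 = nnx.Linear(64, 32, rngs=rngs)
        self.output_layer = nnx.Linear(32, feature_dim, rngs=rngs)

    def __call__(self, x):
        x = jax.nn.relu(self.layer1(x))
        x = jax.nn.relu(self.layer2(x))
        return self.output_layer(x)


network = Network(nnx.Rngs(123), input_dim=1, feature_dim=feature_space_dim)

# ARD base kernel: one lengthscale per dimension of the feature space.
base_kernel = gpx.kernels.RBF(
    active_dims=list(range(feature_space_dim)),
    lengthscale=jnp.ones((feature_space_dim,)),
)
kernel = DeepKernelFunction(base_kernel=base_kernel, network=network)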
We train our model by maximising the marginal log-likelihood. The parameters of our
neural network are learned jointly with the model's hyperparameters.
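For the fit call further below, a posterior object must first be constructed from the deep kernel. A minimal sketch, assuming a zero mean function and a Gaussian likelihood, is:

# Sketch: build the GP posterior from the deep kernel defined above.
meanf = gpx.mean_functions.Zero()
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
posterior = prior * likelihood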
With the inclusion of a neural network, we take this opportunity to highlight the
additional benefits gleaned from using
Optax for optimisation. In particular, we
showcase the ability to use a learning rate scheduler that decays the optimiser's
learning rate throughout the inference. We decrease the learning rate according to a
half-cosine curve over 700 iterations, providing us with large step sizes early in
the optimisation procedure before approaching more conservative values, ensuring we
do not step too far. We also consider a linear warmup, where the learning rate is
increased from 0 to its peak value of 0.01 over the first 75 steps to give a
reasonable initial learning rate.
schedule = ox.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.01,
    warmup_steps=75,
    decay_steps=700,
    end_value=0.0,
)

optimiser = ox.chain(
    ox.clip(1.0),
    ox.adamw(learning_rate=schedule),
)

# Train all parameters (default behavior with trainable=Parameter)
# Alternative options for selective training:
# - trainable=PositiveReal  # only train positive parameters
# - trainable=lambda module, path, value: 'kernel' in path  # only kernel params
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
    train_data=D,
    optim=optimiser,
    num_iters=800,
    key=key,
    trainable=Parameter,  # explicitly specify trainable filter (default)
)
Prediction
With a set of learned parameters, the only remaining task is to make predictions with
the model. We can do this by simply applying the trained posterior to a test data set.
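A sketch of this prediction step is shown below. We assume the mean and standard deviation of the returned distribution are exposed as methods; whether they are methods or properties may vary between GPJax versions.

# Sketch: predictive distribution at the test inputs defined earlier.
latent_dist = opt_posterior(xtest, train_data=D)
predictive_dist = opt_posterior.likelihood(latent_dist)

predictive_mean = predictive_dist.mean()   # may be a property in some versions
predictive_std = predictive_dist.stddev()  # may be a property in some versions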