In this guide, we introduce the kernels available in GPJax and demonstrate how to
create custom kernels.
```python
# Enable Float64 for more stable matrix inversions.
from jax import config
from jax.nn import softplus
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
    Array,
    Float,
    install_import_hook,
)
import matplotlib.pyplot as plt
from numpyro.distributions import constraints
import numpyro.distributions.transforms as npt

from examples.utils import use_mpl_style
from gpjax.kernels.computations import DenseKernelComputation
from gpjax.parameters import (
    DEFAULT_BIJECTION,
    PositiveReal,
)

config.update("jax_enable_x64", True)

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
    from gpjax.parameters import Parameter

# set the default style for plotting
use_mpl_style()

key = jr.key(42)

cols = plt.rcParams["axes.prop_cycle"].by_key()["color"]
```
## Supported Kernels
The following kernels are natively supported in GPJax.
While the syntax is consistent, each kernel's type influences the
characteristics of the sample paths drawn. We visualise this below with 10
function draws per kernel.
By default, kernels operate over every dimension of the supplied inputs. In
some use cases, it is desirable to restrict a kernel to specific dimensions of
the input data. We can achieve this via the `active_dims` argument, which
determines which input indices the kernel evaluates.
To see this, consider the following 5-dimensional dataset for which we would
like our RBF kernel to act on the first, second and fourth dimensions.
We'll now simulate some data and evaluate the kernel on the previously selected
input dimensions.
```python
# Inputs
x_matrix = jr.normal(key, shape=(50, 5))

# Compute the Gram matrix
K = slice_kernel.gram(x_matrix)
print(K.shape)
```

```
(50, 50)
```
## Kernel combinations
The sum of two positive definite matrices is positive definite, as is their
elementwise (Schur) product. Consequently, summing or multiplying sets of
kernels is a valid operation that can yield rich kernel functions. In GPJax,
functionality for a sum kernel is provided by the SumKernel class.
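As a quick numerical check of this closure property (a standalone sketch using plain JAX, not GPJax API), we can confirm that the sum and elementwise product of two positive definite Gram-like matrices remain positive definite:

```python
import jax.numpy as jnp
import jax.random as jr

key1, key2 = jr.split(jr.key(0))

# Build two random positive definite matrices of the form X Xᵀ + jitter·I.
X = jr.normal(key1, (10, 3))
Y = jr.normal(key2, (10, 3))
A = X @ X.T + 0.1 * jnp.eye(10)
B = Y @ Y.T + 0.1 * jnp.eye(10)

# The eigenvalues of the sum and of the elementwise (Schur) product are
# strictly positive, so both remain valid kernel matrices.
eig_sum = jnp.linalg.eigvalsh(A + B)
eig_prod = jnp.linalg.eigvalsh(A * B)
```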
GPJax makes the process of implementing kernels of your choice straightforward
with two key steps:
1. Listing the kernel's parameters.
2. Defining the kernel's pairwise operation.
We'll demonstrate this process now for a circular kernel, an adaptation of
the excellent guide given in the PyMC3 documentation. We encourage curious
readers to visit their notebook
here.
## Circular kernel
When the underlying space is polar, typical Euclidean kernels such as Matérn
kernels are insufficient at the boundary, where discontinuities present
themselves.
This is due to the fact that for a polar space $\lvert 0, 2\pi \rvert = 0$, i.e.,
the space wraps. Euclidean kernels have no mechanism to represent this
and will instead treat $0$ and $2\pi$ as elements that are far apart. Circular
kernels do not exhibit this behaviour and instead wrap around the boundary
points to create a smooth function. Such a kernel was given in Padonou &
Roustant (2015), where the covariance between any two angles
$\theta$ and $\theta'$ is written as

$$k(\theta, \theta') = \left(1 + \tau \frac{d(\theta, \theta')}{c}\right)\left(1 - \frac{d(\theta, \theta')}{c}\right)_{+}^{\tau}, \quad \tau \geq 4, \tag{1}$$

where $(\cdot)_{+} = \max(\cdot, 0)$.
Here the hyperparameter $\tau$ is analogous to a lengthscale for Euclidean
stationary kernels, controlling the correlation between pairs of observations,
while $d$ is an angular distance metric

$$d(\theta, \theta') = \lvert (\theta - \theta' + c) \bmod 2c - c \rvert.$$
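As a quick sanity check of this metric's wrapping behaviour (a standalone sketch of the `angular_distance` helper implemented below), two angles just either side of the boundary should be close when $c = \pi$:

```python
import jax.numpy as jnp

def angular_distance(x, y, c):
    # d(θ, θ') = |(θ - θ' + c) mod 2c - c|
    return jnp.abs((x - y + c) % (c * 2) - c)

c = jnp.pi
# Two points straddling the 0 / 2π boundary: the wrapped distance is small,
# while the naive Euclidean distance treats them as far apart.
near_wrap = angular_distance(0.1, 2 * jnp.pi - 0.1, c)
euclidean = jnp.abs(0.1 - (2 * jnp.pi - 0.1))
```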
To implement this, one must write the following class.
```python
def angular_distance(x, y, c):
    return jnp.abs((x - y + c) % (c * 2) - c)


class ShiftedSoftplusTransform(npt.ParameterFreeTransform):
    r"""
    Transform from unconstrained space to the domain [4, infinity) via
    :math:`y = 4 + \log(1 + \exp(x))`. The inverse is computed as
    :math:`x = \log(\exp(y - 4) - 1)`.
    """

    domain = constraints.real
    codomain = constraints.interval(4.0, jnp.inf)  # updated codomain

    def __call__(self, x):
        return 4.0 + softplus(x)  # shift the softplus output by 4

    def _inverse(self, y):
        return npt._softplus_inv(y - 4.0)  # subtract the shift in the inverse

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        return -softplus(-x)


DEFAULT_BIJECTION["polar"] = ShiftedSoftplusTransform()


class Polar(gpx.kernels.AbstractKernel):
    period: float
    tau: PositiveReal

    def __init__(
        self,
        tau: float = 5.0,
        period: float = 2 * jnp.pi,
        active_dims: list[int] | slice | None = None,
        n_dims: int | None = None,
    ):
        super().__init__(active_dims, n_dims, DenseKernelComputation())
        self.period = jnp.array(period)
        self.tau = PositiveReal(jnp.array(tau), tag="polar")

    def __call__(
        self, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
    ) -> Float[Array, "1"]:
        c = self.period / 2.0
        t = angular_distance(x, y, c)
        K = (1 + self.tau[...] * t / c) * jnp.clip(1 - t / c, 0, jnp.inf) ** self.tau[...]
        return K.squeeze()
```
We unpack this now to make better sense of it. In the kernel's initialiser
we specify the length of a single period. As the underlying
domain is a circle, this is $2\pi$. We then define the kernel's `__call__`
function, which is a direct implementation of Equation (1), where we define $c$
as half the value of `period`.
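To sanity-check the `__call__` implementation against Equation (1), here is a standalone sketch of the same computation outside the class, using plain `jax.numpy`:

```python
import jax.numpy as jnp

def angular_distance(x, y, c):
    return jnp.abs((x - y + c) % (c * 2) - c)

def polar_k(theta, theta_p, tau=5.0, period=2 * jnp.pi):
    # Direct transcription of Equation (1) with c = period / 2.
    c = period / 2.0
    t = angular_distance(theta, theta_p, c)
    return (1 + tau * t / c) * jnp.clip(1 - t / c, 0.0, jnp.inf) ** tau

k_same = polar_k(0.0, 0.0)               # unit correlation at zero distance
k_wrap = polar_k(0.1, 2 * jnp.pi - 0.1)  # high correlation: the boundary wraps
k_far = polar_k(0.0, jnp.pi)             # antipodal points: zero correlation
```

Note how two angles on either side of the boundary are highly correlated, exactly the behaviour a Euclidean kernel fails to capture.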
To constrain $\tau$ to be strictly greater than 4, we use a shifted softplus
bijector that maps the unconstrained parameter onto $(4, \infty)$. This is done
by registering the transform in `DEFAULT_BIJECTION` under the `"polar"` tag and
passing that tag when we define the `PositiveReal` parameter.
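A standalone sketch of this forward/inverse pair may help: the forward map sends all of $\mathbb{R}$ onto $(4, \infty)$, and `inverse` below uses a numerically stable form of the softplus inverse (the `forward`/`inverse` names are illustrative, not GPJax API):

```python
import jax.numpy as jnp
from jax.nn import softplus

def forward(x):
    # Unconstrained → (4, ∞): y = 4 + log(1 + exp(x))
    return 4.0 + softplus(x)

def inverse(y):
    # Stable softplus inverse: log(exp(s) - 1) = s + log(-expm1(-s))
    s = y - 4.0
    return s + jnp.log(-jnp.expm1(-s))

y = forward(1.25)              # always strictly greater than 4
roundtrip = forward(inverse(y))  # recovers y
```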
## Using our polar kernel
We proceed to fit a GP with our custom circular kernel to a random sequence of
points on a circle (see the
Regression notebook
for further details on this process).