In this notebook, we will walk users through the process of creating a new likelihood
in GPJax.
Background
In this section we'll provide a short introduction to likelihoods and why they are
important. Users who are already familiar with likelihoods may skip ahead to the next
section; those who would like more detail than is provided here should see our
introduction to Gaussian processes notebook.
What is a likelihood?
We adopt the notation of our
introduction to Gaussian processes notebook, where we have a
Gaussian process (GP) $f(\cdot) \sim \mathcal{GP}(m(\cdot), k(\cdot, \cdot))$ and a
dataset $\mathbf{y} = \{y_n\}_{n=1}^N$ observed at corresponding inputs
$\mathbf{x} = \{x_n\}_{n=1}^N$. The evaluation of $f$ at $\mathbf{x}$ is denoted by
$\mathbf{f} = \{f(x_n)\}_{n=1}^N$. The likelihood function of the GP is then given
by
$$
\begin{align}
\label{eq:likelihood_fn}
p(\mathbf{y}\mid \mathbf{f}) = \prod_{n=1}^N p(y_n\mid f(x_n))\,.
\end{align}
$$
Conceptually, this conditional distribution describes the probability of the observed
data, conditional on the latent function values.
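To make the factorisation concrete, here is a minimal pure-Python sketch (an illustration, not GPJax's implementation) that evaluates the log of \eqref{eq:likelihood_fn} for a Gaussian observation model with noise variance `obs_var`:

```python
import math


def gaussian_log_likelihood(y, f, obs_var):
    """log p(y | f) = sum_n log N(y_n; f_n, obs_var), using the factorisation above."""
    return sum(
        -0.5 * math.log(2.0 * math.pi * obs_var) - (y_n - f_n) ** 2 / (2.0 * obs_var)
        for y_n, f_n in zip(y, f)
    )
```

Because the likelihood factorises over observations, the joint log-likelihood is simply the sum of the per-datapoint terms.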
Why is the likelihood important?
Choosing the correct likelihood function when building a GP, or any Bayesian model
for that matter, is crucial. The likelihood function encodes our assumptions about
the data and the noise that we expect to observe. For example, if we are modelling
air pollution, then we would not expect to observe negative values of pollution. In
this case, we would choose a likelihood function that is only defined for positive
values. Similarly, if our data is the proportion of people who voted for a particular
political party, then we would expect to observe values between 0 and 1. In this
case, we would choose a likelihood function that is only defined for values between
0 and 1.
Likelihoods in GPJax
In GPJax, all likelihoods are a subclass of the AbstractLikelihood class. This base
abstract class contains the three core methods that all likelihoods must implement:
predict, link_function, and expected_log_likelihood. We will discuss each of
these methods in the forthcoming sections, but first, we will show how to instantiate
a likelihood object. To do this, we'll need a dataset.
```python
import jax

# Enable Float64 for more stable matrix inversions.
from jax import config
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

from examples.utils import use_mpl_style

import gpjax as gpx

config.update("jax_enable_x64", True)

# set the default style for plotting
use_mpl_style()
cols = plt.rcParams["axes.prop_cycle"].by_key()["color"]

key = jr.key(42)

n = 50
x = jnp.sort(jr.uniform(key=key, shape=(n, 1), minval=-3.0, maxval=3.0), axis=0)
xtest = jnp.linspace(-3, 3, 100)[:, None]
f = lambda x: jnp.sin(x)
y = f(x) + 0.1 * jr.normal(key, shape=x.shape)
D = gpx.Dataset(x, y)

fig, ax = plt.subplots()
ax.plot(x, y, "o", label="Observations")
ax.plot(x, f(x), label="Latent function")
ax.legend()
```
In this example, our inputs lie in $[-3, 3]$ and the observations are generated from a
sinusoidal function corrupted by Gaussian noise. As such, our response values $\mathbf{y}$
range between $-1$ and $1$, subject to Gaussian noise. A Gaussian likelihood is
therefore appropriate for this dataset as it allows for negative values.
As we see in \eqref{eq:likelihood_fn}, the likelihood function factorises over the
$N$ observations. As such, we must provide this information to GPJax when
instantiating a likelihood object. We do this by specifying the num_datapoints
argument.
```python
gpx.likelihoods.Gaussian(num_datapoints=D.n)
```

```
Gaussian( # NonNegativeReal: 1 (8 B)
  obs_stddev=NonNegativeReal( # 1 (8 B)
    value=Array(1., dtype=float64, weak_type=True),
    tag='non_negative',
    numpyro_properties={}
  ),
  num_outputs=1,
  num_datapoints=50,
  integrator=<gpjax.integrators.AnalyticalGaussianIntegrator object at 0x7f6500240e10>
)
```
Likelihood parameters
Some likelihoods, such as the Gaussian likelihood, contain parameters that we seek
to infer. In the case of the Gaussian likelihood, we have a single parameter
$\sigma^2$ that determines the observation noise. In GPJax, we can specify the value
of $\sigma$ when instantiating the likelihood object. If we do not specify a
value, then the likelihood will be initialised with a default value; for the
Gaussian likelihood, the default is 1.0. If we instead wanted to
initialise the likelihood standard deviation with a value of 0.5, then we would pass
`obs_stddev=0.5` when constructing the likelihood:
```
Gaussian( # NonNegativeReal: 1 (8 B)
  obs_stddev=NonNegativeReal( # 1 (8 B)
    value=Array(0.5, dtype=float64, weak_type=True),
    tag='non_negative',
    numpyro_properties={}
  ),
  num_outputs=1,
  num_datapoints=50,
  integrator=<gpjax.integrators.AnalyticalGaussianIntegrator object at 0x7f6500240e10>
)
```
Prediction
The predict method of a likelihood object transforms the latent distribution of
the Gaussian process. In the case of a Gaussian likelihood, this simply adds the
observation noise variance to the diagonal of the covariance matrix. For other
likelihoods, this may be a more complex transformation. For example, the Bernoulli
likelihood transforms the latent distribution of the Gaussian process into a
distribution over binary values.
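As a sketch of what this means for the Gaussian case (an illustration, not GPJax's actual code), the predict step leaves the latent mean unchanged and adds the observation noise variance to the diagonal of the covariance matrix:

```python
def add_observation_noise(latent_cov, obs_var):
    """Return latent_cov + obs_var * I: the Gaussian likelihood's effect on the
    latent covariance matrix; the latent mean is left unchanged."""
    n = len(latent_cov)
    return [
        [latent_cov[i][j] + (obs_var if i == j else 0.0) for j in range(n)]
        for i in range(n)
    ]
```

Only the diagonal entries are perturbed: the observation noise is assumed independent across datapoints, so the off-diagonal covariances are untouched.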
We visualise this below for the Gaussian likelihood function. In blue we can see
samples of $\mathbf{f}^{\star}$, whilst in red we see samples of
$\mathbf{y}^{\star}$.
In the above figure, we can see the latent samples being constrained to be either 0 or
1 when a Bernoulli likelihood is specified. This is achieved by the
inverse link function $\eta(\cdot)$ of the likelihood. The link function is a
deterministic function that maps the latent distribution of the Gaussian process to
the support of the likelihood function. For example, the link function of the
Bernoulli likelihood that is used in GPJax is the inverse probit function
$$
\eta(x) = \frac{1}{2}\left(1 + \operatorname{erf}\left(\frac{x}{\sqrt{2}}\right)\right)(1 - 2\epsilon) + \epsilon\,,
$$
where $\frac{1}{2}\left(1 + \operatorname{erf}\left(x/\sqrt{2}\right)\right) = \Phi(x)$, the cumulative distribution function of the standard normal
distribution, and $\epsilon$ is a small jitter that keeps the output strictly inside $(0, 1)$.
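A standalone sketch of this link function using only the standard library (GPJax's own implementation may differ in details; the jitter value below is illustrative):

```python
import math


def inv_probit(x, jitter=1e-3):
    """Inverse probit link: map a real-valued latent into (0, 1).
    The jitter shrinks the output range slightly, for numerical stability."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0))) * (1.0 - 2.0 * jitter) + jitter
```

Note that the jitter cancels at $x = 0$, so $\eta(0) = 0.5$, and the symmetry $\eta(x) + \eta(-x) = 1$ is preserved.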
A table of commonly used link functions and their corresponding likelihood can be
found here.
Expected log likelihood
The final method that is associated with a likelihood function in GPJax is the
expected log-likelihood. This term is evaluated in the
stochastic variational Gaussian process in the ELBO term. For a
variational approximation $q(\mathbf{f}) = \mathcal{N}(\mathbf{f} \mid \mathbf{m}, \mathbf{S})$, the ELBO can be written as
$$
\begin{align}
\label{eq:elbo}
\mathcal{L} = \mathbb{E}_{q(\mathbf{f})}\left[\log p(\mathbf{y} \mid \mathbf{f})\right] - \mathrm{KL}\left(q(\mathbf{f}) \,\|\, p(\mathbf{f})\right)\,.
\end{align}
$$
As both $q(\mathbf{f})$ and $p(\mathbf{f})$ are Gaussian distributions, the Kullback-Leibler term can
be computed analytically. However, the expectation term is not always so easy to
compute. Fortunately, the expectation in \eqref{eq:elbo} can be decomposed as a sum over the
datapoints
$$
\mathbb{E}_{q(\mathbf{f})}\left[\log p(\mathbf{y} \mid \mathbf{f})\right] = \sum_{n=1}^N \mathbb{E}_{q(f_n)}\left[\log p(y_n \mid f_n)\right]\,.
$$
This simplifies computation of the expectation, as it is now a series of $N$
one-dimensional integrals. As such, GPJax by default uses quadrature to compute these
integrals. However, for some likelihoods, such as the Gaussian likelihood, the
expectation can be computed analytically. In these cases, we can supply an object
that inherits from AbstractIntegrator to the likelihood upon instantiation. To see
this, let us consider a Gaussian likelihood where we'll first define a variational
approximation to the posterior.
Now that we have the variational mean and variational (co)variance, we can compute
the expected log-likelihood using the expected_log_likelihood method of the
likelihood object.
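To see why the Gaussian case admits an analytical integrator, here is an illustrative pure-Python comparison (not GPJax's implementation) of Gauss-Hermite quadrature against the closed form $\mathbb{E}_{\mathcal{N}(f \mid m, s^2)}[\log \mathcal{N}(y \mid f, \sigma^2)] = \log \mathcal{N}(y \mid m, \sigma^2) - s^2 / (2\sigma^2)$. Because the log-density is quadratic in $f$, a three-point rule is already exact:

```python
import math

# 3-point Gauss-Hermite nodes and weights (physicists' convention).
_GH_NODES = [-math.sqrt(1.5), 0.0, math.sqrt(1.5)]
_GH_WEIGHTS = [
    math.sqrt(math.pi) / 6.0,
    2.0 * math.sqrt(math.pi) / 3.0,
    math.sqrt(math.pi) / 6.0,
]


def log_normal_pdf(y, mean, var):
    return -0.5 * math.log(2.0 * math.pi * var) - (y - mean) ** 2 / (2.0 * var)


def expected_log_lik_quadrature(y, m, s2, obs_var):
    """E_{N(f | m, s2)}[log N(y | f, obs_var)] via Gauss-Hermite quadrature."""
    total = 0.0
    for node, weight in zip(_GH_NODES, _GH_WEIGHTS):
        f = m + math.sqrt(2.0 * s2) * node  # change of variables f = m + sqrt(2 s2) x
        total += weight * log_normal_pdf(y, f, obs_var)
    return total / math.sqrt(math.pi)


def expected_log_lik_analytic(y, m, s2, obs_var):
    """Closed form, available only because the likelihood is Gaussian."""
    return log_normal_pdf(y, m, obs_var) - s2 / (2.0 * obs_var)
```

For non-Gaussian likelihoods no such closed form exists in general, which is why quadrature is the default and an analytical integrator is a likelihood-specific optimisation.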