In this notebook we demonstrate how to perform inference for Gaussian process models
with non-Gaussian likelihoods via maximum a posteriori (MAP) estimation and the Laplace approximation. We focus on a binary classification task here.
from flax import nnx
import jax

# Enable Float64 for more stable matrix inversions.
from jax import config
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
from jaxtyping import (
    Array,
    Float,
    install_import_hook,
)
import matplotlib.pyplot as plt
import numpyro.distributions as npd
import optax as ox

from examples.utils import use_mpl_style
from gpjax.linalg import (
    PSD,
    lower_cholesky,
    solve,
)

config.update("jax_enable_x64", True)

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
    from gpjax.parameters import Parameter

identity_matrix = jnp.eye

# set the default style for plotting
use_mpl_style()

key = jr.key(42)

cols = plt.rcParams["axes.prop_cycle"].by_key()["color"]
Dataset
With the necessary modules imported, we simulate a dataset
$\mathcal{D} = (\boldsymbol{x}, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{100}$ with inputs
$\boldsymbol{x}$ sampled uniformly on $(-1, 1)$ and corresponding binary outputs $\boldsymbol{y}$.
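The code that generated the data is not shown above, so the snippet below is only a minimal sketch: the thresholded, noisy cosine used to produce the labels is an illustrative assumption rather than the original generating rule, but it yields a dataset of the stated form and stores it as a GPJax Dataset.

n = 100
key, subkey = jr.split(key)
x = jr.uniform(key, shape=(n, 1), minval=-1.0, maxval=1.0)

# Illustrative (assumed) labelling rule: threshold a noisy cosine to obtain labels in {0, 1}.
f_latent = jnp.cos(4.0 * x)
y = 0.5 * (jnp.sign(f_latent + 0.1 * jr.normal(subkey, shape=(n, 1))) + 1.0)

D = gpx.Dataset(X=x, y=y)

plt.scatter(x, y, color=cols[0], alpha=0.7)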
MAP inference
We begin by defining a Gaussian process prior with a radial basis function (RBF)
kernel, chosen for the purpose of exposition. Since our observations are binary, we
choose a Bernoulli likelihood with a probit link function.
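Although the corresponding code cell is not reproduced here, a minimal sketch of these components in GPJax might look as follows; the zero mean function and the default kernel parameters are assumptions made for illustration.

# GP prior with an RBF kernel (assumed zero mean function).
kernel = gpx.kernels.RBF()
meanf = gpx.mean_functions.Zero()
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)

# Bernoulli likelihood, paired with the probit link described above.
likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n)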
We construct the posterior through the product of our prior and likelihood.
posterior = prior * likelihood
print(type(posterior))
<class 'gpjax.gps.NonConjugatePosterior'>
Whilst the latent function is Gaussian, the posterior distribution is non-Gaussian
since our generative model first samples the latent GP and propagates these samples
through the likelihood function's inverse link function. This step prevents us from
being able to analytically integrate the latent function's values out of our
posterior, and we must instead adopt alternative inference techniques. We begin with
maximum a posteriori (MAP) estimation, a fast inference procedure to obtain point
estimates for the latent function and the kernel's hyperparameters by maximising the
log-posterior density.
We can obtain a MAP estimate by optimising the log-posterior density with
Optax's optimisers.
optimiser = ox.adamw(learning_rate=0.01)

opt_posterior, history = gpx.fit(
    model=posterior,
    # we use the negative lpd as we are minimising
    objective=lambda p, d: -gpx.objectives.log_posterior_density(p, d),
    train_data=D,
    optim=optimiser,
    num_iters=1000,
    key=key,
    trainable=Parameter,  # train all parameters (default behavior)
)
From this we can make predictions at novel inputs, as illustrated below.
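The following is a sketch of such predictions. The test grid xtest and the plotting choices are our own, and we assume, as in GPJax's other examples, that predict returns the latent distribution at the test inputs and that calling the likelihood on it yields the predictive distribution.

# Assumed grid of test inputs.
xtest = jnp.linspace(-1.0, 1.0, 500).reshape(-1, 1)

# Latent GP at the test inputs under the MAP estimate, pushed through the Bernoulli likelihood.
map_latent_dist = opt_posterior.predict(xtest, train_data=D)
predictive_dist = opt_posterior.likelihood(map_latent_dist)

predictive_mean = predictive_dist.mean
predictive_std = jnp.sqrt(predictive_dist.variance)

plt.scatter(x, y, color=cols[0], alpha=0.7, label="Observations")
plt.plot(xtest.squeeze(), predictive_mean, color=cols[1], label="Predictive mean")
plt.fill_between(
    xtest.squeeze(),
    predictive_mean - predictive_std,
    predictive_mean + predictive_std,
    alpha=0.2,
    color=cols[1],
)
plt.legend()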
However, as a point estimate, MAP estimation is severely limited for uncertainty
quantification, providing only a single piece of information about the posterior.
Laplace approximation
The Laplace approximation improves uncertainty quantification by incorporating
curvature induced by the log-posterior density's Hessian to construct an
approximate Gaussian distribution centred on the MAP estimate. Writing
$$\tilde{p}(\boldsymbol{f} \mid \mathcal{D}) = p(\boldsymbol{y} \mid \boldsymbol{f})\, p(\boldsymbol{f})$$
as the unnormalised posterior for function values $\boldsymbol{f}$ at the datapoints
$\boldsymbol{x}$, we can expand the log of this about the posterior mode
$\hat{\boldsymbol{f}}$ via a Taylor expansion. This gives:

$$\log\tilde{p}(\boldsymbol{f} \mid \mathcal{D}) = \log\tilde{p}(\hat{\boldsymbol{f}} \mid \mathcal{D}) + \left[\nabla \log\tilde{p}(\boldsymbol{f} \mid \mathcal{D})\big|_{\hat{\boldsymbol{f}}}\right]^{\top}(\boldsymbol{f} - \hat{\boldsymbol{f}}) + \frac{1}{2}(\boldsymbol{f} - \hat{\boldsymbol{f}})^{\top}\left[\nabla^2 \tilde{p}(\boldsymbol{y} \mid \boldsymbol{f})\big|_{\hat{\boldsymbol{f}}}\right](\boldsymbol{f} - \hat{\boldsymbol{f}}) + \mathcal{O}\left(\lVert \boldsymbol{f} - \hat{\boldsymbol{f}} \rVert^3\right).$$

Since the gradient of the log-posterior vanishes at the mode $\hat{\boldsymbol{f}}$, exponentiating the
remaining quadratic term yields an unnormalised density that we identify as a Gaussian distribution,
$$p(\boldsymbol{f} \mid \mathcal{D}) \approx q(\boldsymbol{f}) := \mathcal{N}\!\left(\hat{\boldsymbol{f}},\ \left[-\nabla^2 \tilde{p}(\boldsymbol{y} \mid \boldsymbol{f})\big|_{\hat{\boldsymbol{f}}}\right]^{-1}\right).$$
Since the negative Hessian is positive definite, we can use the Cholesky
decomposition to obtain the covariance matrix of the Laplace approximation at the
datapoints, as computed below.
gram, cross_covariance = (kernel.gram, kernel.cross_covariance)
jitter = 1e-6

# Compute (latent) function value map estimates at training points:
Kxx = opt_posterior.prior.kernel.gram(x)
Kxx += identity_matrix(D.n) * jitter
Kxx = PSD(Kxx)
Lx = lower_cholesky(Kxx)
f_hat = Lx @ opt_posterior.latent[...]

# Negative Hessian, H = -∇²p_tilde(y|f):
graphdef, params, *static_state = nnx.split(opt_posterior, Parameter, ...)

def loss(params, D):
    model = nnx.merge(graphdef, params, *static_state)
    return -gpx.objectives.log_posterior_density(model, D)

jacobian = jax.jacfwd(jax.jacrev(loss))(params, D)
H = jacobian["latent"]["latent"][...][:, 0, :, 0]

L = jnp.linalg.cholesky(H + identity_matrix(D.n) * jitter)

# H⁻¹ = H⁻¹ I = (LLᵀ)⁻¹ I = L⁻ᵀ L⁻¹ I
L_inv = jsp.linalg.solve_triangular(L, identity_matrix(D.n), lower=True)
H_inv = jsp.linalg.solve_triangular(L.T, L_inv, lower=False)
LH = jnp.linalg.cholesky(H_inv)

laplace_approximation = npd.MultivariateNormal(f_hat.squeeze(), scale_tril=LH)
For novel inputs, we must project the above approximating distribution through the
Gaussian conditional distribution $p(f(\cdot) \mid \boldsymbol{f})$,

$$p(f(\cdot) \mid \boldsymbol{f}) = \mathcal{N}\!\left(f(\cdot);\ \mathbf{K}_{(\cdot)x}\mathbf{K}_{xx}^{-1}\boldsymbol{f},\ \mathbf{K}_{(\cdot,\cdot)} - \mathbf{K}_{(\cdot)x}\mathbf{K}_{xx}^{-1}\mathbf{K}_{x(\cdot)}\right),$$

and integrate the latent values out against the Laplace approximation,

$$q_{\text{Laplace}}(f(\cdot)) := \int p(f(\cdot) \mid \boldsymbol{f})\, q(\boldsymbol{f})\, \mathrm{d}\boldsymbol{f}.$$

This is the same approximate distribution as $q_{\text{map}}(f(\cdot))$, but we have perturbed
the covariance by a curvature term of

$$\mathbf{K}_{(\cdot)x}\mathbf{K}_{xx}^{-1}\left[-\nabla^2 \tilde{p}(\boldsymbol{y} \mid \boldsymbol{f})\big|_{\hat{\boldsymbol{f}}}\right]^{-1}\mathbf{K}_{xx}^{-1}\mathbf{K}_{x(\cdot)}.$$
We take the latent distribution computed in the previous section and add this term
to the covariance to construct $q_{\text{Laplace}}(f(\cdot))$.
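A sketch of this construction is given below. It reuses H_inv and jitter from the previous cell, and assumes, as above, that opt_posterior.predict returns a multivariate normal latent distribution exposing mean and covariance_matrix attributes; the helper name construct_laplace is ours.

def construct_laplace(test_inputs: Float[Array, "N D"]) -> npd.MultivariateNormal:
    # Latent distribution at the test inputs under the MAP estimate.
    map_latent_dist = opt_posterior.predict(test_inputs, train_data=D)

    Kxt = opt_posterior.prior.kernel.cross_covariance(x, test_inputs)
    Kxx = opt_posterior.prior.kernel.gram(x)
    Kxx += identity_matrix(D.n) * jitter
    Kxx = PSD(Kxx)

    # Kxx⁻¹ Kxt
    Kxx_inv_Kxt = solve(Kxx, Kxt)

    # Curvature term: Ktx Kxx⁻¹ [-∇²p̃(y|f)|f̂]⁻¹ Kxx⁻¹ Kxt
    laplace_cov_term = jnp.matmul(jnp.matmul(Kxx_inv_Kxt.T, H_inv), Kxx_inv_Kxt)

    mean = map_latent_dist.mean
    covariance = map_latent_dist.covariance_matrix + laplace_cov_term
    L = jnp.linalg.cholesky(covariance)
    return npd.MultivariateNormal(jnp.atleast_1d(mean.squeeze()), scale_tril=L)

laplace_latent_dist = construct_laplace(xtest)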