In this notebook we'll give an implementation of the Wasserstein barycentre of
Gaussian processes. In the work on which this notebook is based, the existence of a
Wasserstein barycentre between a collection of Gaussian processes is proven. When
faced with averaging a set of probability distributions, the Wasserstein
barycentre is an attractive choice as it enables the uncertainty amongst the individual
distributions to be incorporated into the averaged distribution. Compared with a
naive mean-of-means and mean-of-variances approach to averaging probability
distributions, Wasserstein barycentres offer significantly more favourable
uncertainty estimation.
import typing as tp

import jax

# Enable Float64 for more stable matrix inversions.
from jax import config
import jax.numpy as jnp
import jax.random as jr
import jax.scipy.linalg as jsl
from jaxtyping import install_import_hook
import matplotlib.pyplot as plt
import numpyro.distributions as npd

from examples.utils import use_mpl_style

config.update("jax_enable_x64", True)

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
    from gpjax.parameters import Parameter

key = jr.key(123)

# set the default style for plotting
use_mpl_style()
cols = plt.rcParams["axes.prop_cycle"].by_key()["color"]
Background
Wasserstein distance
The 2-Wasserstein distance between two probability measures $\mu$ and $\nu$
quantifies the minimal cost required to transport the unit mass from $\mu$ to $\nu$,
or vice-versa. Typically, computing this metric requires solving a linear program.
However, when $\mu = \mathcal{N}(m_1, S_1)$ and $\nu = \mathcal{N}(m_2, S_2)$ both
belong to the family of multivariate Gaussian distributions, the solution is given
analytically by

$$W_2^2(\mu, \nu) = \lVert m_1 - m_2 \rVert_2^2 + \operatorname{Tr}\left(S_1 + S_2 - 2\left(S_1^{1/2} S_2 S_1^{1/2}\right)^{1/2}\right).$$
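As a quick illustration (not part of the notebook's own code), the closed-form
expression can be evaluated directly in JAX; the helper names `sqrtm` and
`gaussian_w2`, and the toy means and covariances below, are our own.

```python
import jax.numpy as jnp
import jax.scipy.linalg as jsl


def sqrtm(A):
    # Matrix square root; jsl.sqrtm may return a complex array with a
    # negligible imaginary part, so we keep only the real part.
    return jnp.real(jsl.sqrtm(A))


def gaussian_w2(m1, S1, m2, S2):
    # Squared 2-Wasserstein distance between N(m1, S1) and N(m2, S2).
    S1_half = sqrtm(S1)
    cross_term = sqrtm(S1_half @ S2 @ S1_half)
    return jnp.sum((m1 - m2) ** 2) + jnp.trace(S1 + S2 - 2.0 * cross_term)


m1, m2 = jnp.zeros(2), jnp.ones(2)
S1, S2 = jnp.eye(2), 4.0 * jnp.eye(2)
# Here the covariances commute, so the trace term reduces to
# ||S1^{1/2} - S2^{1/2}||_F^2 = 2, and ||m1 - m2||^2 = 2, giving 4.
print(gaussian_w2(m1, S1, m2, S2))  # ~4.0
```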
The Wasserstein barycentre of a collection of measures $\{\mu_i\}_{i=1}^T$ with
weights $\{\alpha_i\}_{i=1}^T$ is the measure minimising the weighted sum
$\sum_{i=1}^T \alpha_i W_2^2(\mu, \mu_i)$. As with the Wasserstein distance,
identifying the barycentre $\bar{\mu}$ is often a computationally demanding
optimisation problem. However, when all the measures admit a multivariate Gaussian
density, the barycentre $\bar{\mu} = \mathcal{N}(\bar{m}, \bar{S})$ has the
analytical solutions

$$\bar{m} = \sum_{i=1}^T \alpha_i m_i\,, \qquad \bar{S} = \sum_{i=1}^T \alpha_i \left(\bar{S}^{1/2} S_i \bar{S}^{1/2}\right)^{1/2}.$$

Identifying $\bar{S}$ is achieved through a fixed-point iterative update.
Barycentre of Gaussian processes
It was shown in the original paper that the barycentre $\bar{f}$ of a collection of
Gaussian processes $\{f_i\}_{i=1}^T$ such that $f_i \sim \mathcal{GP}(m_i, K_i)$
can be found using the same solutions as above. For a full theoretical
understanding, we recommend reading the original paper. However, the central
argument of this result is that one can first show that the barycentre GP
$\bar{f} \sim \mathcal{GP}(\bar{m}, \bar{S})$ is non-degenerate for any finite set
of GPs $\{f_t\}_{t=1}^T$, i.e., $T < \infty$. With this established, one can show
that for an $n$-dimensional finite Gaussian distribution $f_{i,n}$, the Wasserstein
metric between any two Gaussian distributions $f_{i,n}, f_{j,n}$ converges to the
Wasserstein metric between the GPs as $n \to \infty$.
In this notebook, we will demonstrate how this can be achieved in GPJax.
Dataset
We'll simulate five datasets and fit a Gaussian process posterior to each before
identifying the Gaussian process barycentre at a set of test points. Each dataset
will be a sine function with a different vertical shift, periodicity, and quantity
of noise.
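A sketch of how such datasets might be simulated is below; the particular shifts,
periods, and noise scales are illustrative choices of ours, not the notebook's
exact values.

```python
import jax.numpy as jnp
import jax.random as jr

key = jr.key(123)
n_points = 100
x = jnp.linspace(-5.0, 5.0, n_points)

datasets = []
for i in range(5):
    key, subkey = jr.split(key)
    shift = 2.0 * i - 4.0   # vertical shift (illustrative)
    period = 1.0 + 0.2 * i  # periodicity (illustrative)
    noise = 0.1 + 0.05 * i  # observation-noise scale (illustrative)
    y = shift + jnp.sin(2.0 * jnp.pi * x / period) + noise * jr.normal(subkey, (n_points,))
    datasets.append((x, y))
```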
We'll now independently learn Gaussian process posterior distributions for each
dataset. We won't spend any time here discussing how GP hyperparameters are
optimised; see the
Regression notebook
for advice on optimisation and the
Kernels notebook for
advice on selecting an appropriate kernel.
Optimization terminated successfully.
Current function value: -31.740071
Iterations: 10
Function evaluations: 16
Gradient evaluations: 16
Optimization terminated successfully.
Current function value: 30.775673
Iterations: 12
Function evaluations: 19
Gradient evaluations: 19
Optimization terminated successfully.
Current function value: -102.543998
Iterations: 10
Function evaluations: 18
Gradient evaluations: 18
Optimization terminated successfully.
Current function value: -143.031733
Iterations: 12
Function evaluations: 22
Gradient evaluations: 22
Optimization terminated successfully.
Current function value: -271.573651
Iterations: 11
Function evaluations: 15
Gradient evaluations: 15
Computing the barycentre
In GPJax, the predictive distribution of a GP is given by a multivariate Gaussian
distribution object, making it
straightforward to extract the mean vector and covariance matrix of each GP for
learning a barycentre. We implement the fixed-point scheme given above in the
following cell by utilising JAX's vmap operator to broadcast the matrix square root
over the stacked covariance matrices, with the weighted sum computed via tensordot.
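A sketch of such a step function is below. The factory name `make_step` and the toy
covariance stack in the usage are our own, not GPJax's API; `vmap` broadcasts the
matrix square root over the stacked covariances, and `tensordot` performs the
weighted sum.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl


def sqrtm(A):
    # Real part of the matrix square root (jsl.sqrtm can return complex).
    return jnp.real(jsl.sqrtm(A))


def make_step(cov_stack, weights):
    # Pre-compute each S_i^{1/2} once; vmap applies sqrtm across the stack.
    stack_sqrt = jax.vmap(sqrtm)(cov_stack)

    def step(S, _):
        # Inner term (S_i^{1/2} S S_i^{1/2})^{1/2} for all i at once...
        inner = jax.vmap(sqrtm)(stack_sqrt @ S @ stack_sqrt)
        # ...then contract the stacking axis against the weights.
        S_new = jnp.tensordot(weights, inner, axes=1)
        return S_new, S_new

    return step


# Toy usage on two commuting covariances with equal weights.
step = make_step(jnp.stack([jnp.eye(2), 9.0 * jnp.eye(2)]), jnp.array([0.5, 0.5]))
S = jnp.eye(2)
for _ in range(60):
    S, _ = step(S, None)
# Here the fixed point is ((1 + 3) / 2)^2 I = 4 I.
```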
With a function defined for learning a barycentre, we'll now compute it using the
lax.scan operator, which drastically speeds up for-loops in JAX (see the
JAX documentation).
The iterative update will be executed 100 times, with convergence measured by the
difference between successive iterations; we can confirm this by
inspecting the sequence array in the following cell.
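A standalone sketch of that driver, with our own toy covariances standing in for
the GP posteriors, might look like the following; `scan` stacks every iterate into
`sequence` so that successive differences can be inspected.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl
from jax import lax


def sqrtm(A):
    return jnp.real(jsl.sqrtm(A))


cov_stack = jnp.stack([jnp.eye(2), 9.0 * jnp.eye(2)])  # toy covariances
weights = jnp.array([0.5, 0.5])
stack_sqrt = jax.vmap(sqrtm)(cov_stack)


def step(S, _):
    inner = jax.vmap(sqrtm)(stack_sqrt @ S @ stack_sqrt)
    S_new = jnp.tensordot(weights, inner, axes=1)
    return S_new, S_new


# Run 100 fixed-point updates; scan carries S and stacks every iterate.
S_final, sequence = lax.scan(step, jnp.eye(2), jnp.arange(100))

# Convergence check: norms of successive differences should shrink to ~0.
diffs = jnp.linalg.norm(jnp.diff(sequence, axis=0), axis=(1, 2))
```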
With a barycentre learned, we can visualise the result. It looks reasonable, as it
follows the sinusoidal curve of all the inferred GPs, and the uncertainty bands are
sensible.
In the above example, we assigned uniform weights to each of the posteriors within
the barycentre. In practice, we may have prior knowledge of which posterior is most
likely to be the correct one. Regardless of the weights chosen, the barycentre
remains a Gaussian process. We can interpolate between a pair of posterior
distributions $\mu_1$ and $\mu_2$ to visualise the corresponding barycentre
$\bar{\mu}$.
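A minimal sketch of this interpolation on plain Gaussians (our own helper, with toy
means and covariances standing in for the two posteriors) is:

```python
import jax.numpy as jnp
import jax.scipy.linalg as jsl


def sqrtm(A):
    return jnp.real(jsl.sqrtm(A))


def barycentre(means, covs, weights, n_iters=50):
    # Barycentre mean is the weighted average of means; the covariance
    # comes from iterating S <- sum_i w_i (S^{1/2} S_i S^{1/2})^{1/2}.
    m_bar = jnp.tensordot(weights, means, axes=1)
    S = jnp.eye(covs.shape[-1])
    for _ in range(n_iters):
        S_half = sqrtm(S)
        inner = jnp.stack([sqrtm(S_half @ Si @ S_half) for Si in covs])
        S = jnp.tensordot(weights, inner, axes=1)
    return m_bar, S


means = jnp.stack([jnp.zeros(2), jnp.ones(2)])
covs = jnp.stack([jnp.eye(2), 4.0 * jnp.eye(2)])

# Sweep the weight on the first Gaussian from 0 to 1 to interpolate
# between the two distributions through their barycentres.
results = [
    barycentre(means, covs, jnp.array([w, 1.0 - w]))
    for w in jnp.linspace(0.0, 1.0, 5)
]
```

At the endpoints the barycentre recovers the corresponding Gaussian exactly; in
between, both the mean and covariance transition smoothly.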