Gaussian Processes for Vector Fields and Ocean Current Modelling
In this notebook, we use Gaussian processes to learn vector-valued functions. We will recreate the results of Berlinghieri et al. (2023) by applying them to real-world ocean surface velocity data collected via surface drifters.
Surface drifters are measurement devices that record the dynamics and circulation patterns of the world's oceans. Studying and predicting ocean currents is important for climate research and for applications such as forecasting the spread of oil spills, oceanographic surveying of eddies and upwelling, and providing information on the distribution of biomass in ecosystems. We will be using the Gulf Drifters Open dataset, which contains all publicly available surface drifter trajectories from the Gulf of Mexico spanning 28 years.
```python
from dataclasses import (
    dataclass,
    field,
)

from jax import (
    config,
    hessian,
)
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
    Array,
    Float,
    install_import_hook,
)
from matplotlib import rcParams
import matplotlib.pyplot as plt
import numpyro.distributions as npd
import pandas as pd

from examples.utils import use_mpl_style
from gpjax.kernels.computations import DenseKernelComputation

config.update("jax_enable_x64", True)

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
    from gpjax.parameters import Parameter

# set the default style for plotting
use_mpl_style()

key = jr.key(42)

colors = rcParams["axes.prop_cycle"].by_key()["color"]
```
Data loading and preprocessing
The real dataset has been binned into an $N = 34 \times 16$ grid, equally spaced over the longitude-latitude interval $[-90.8, -83.8] \times [24.0, 27.5]$ (in degrees). Each bin has a size of $\approx 0.21 \times 0.21$ degrees and contains the average velocity across all measurements that fall inside it.
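As a quick arithmetic check of the quoted bin size, dividing each interval by its number of bins gives the width of one bin: $7.0/34 \approx 0.206$ in longitude and $3.5/16 \approx 0.219$ in latitude. Both are close to the quoted $\approx 0.21$. This sketch makes three assumptions: the 34 bins run along longitude, the 16 bins run along latitude, and bins are evenly spaced. The helper bin_edges is hypothetical, and placing each bin's average at its centre is only an illustrative choice.

```python
import numpy as np

# Interval and bin counts quoted in the text (assumed: 34 bins in longitude, 16 in latitude)
LON_RANGE = (-90.8, -83.8)
LAT_RANGE = (24.0, 27.5)
N_LON, N_LAT = 34, 16


def bin_edges(lo, hi, n_bins):
    """Hypothetical helper: n_bins + 1 evenly spaced bin edges on [lo, hi]."""
    return np.linspace(lo, hi, n_bins + 1)


lon_edges = bin_edges(*LON_RANGE, N_LON)
lat_edges = bin_edges(*LAT_RANGE, N_LAT)

# Width of one bin in each direction, in degrees
lon_width = lon_edges[1] - lon_edges[0]  # 7.0 / 34, roughly 0.206
lat_width = lat_edges[1] - lat_edges[0]  # 3.5 / 16 = 0.21875

# Bin centres (illustrative assumption for where each bin average is attached)
lon_centres = 0.5 * (lon_edges[:-1] + lon_edges[1:])
lat_centres = 0.5 * (lat_edges[:-1] + lat_edges[1:])

print("bin size (degrees):", lon_width, "x", lat_width)
print("number of grid points N =", lon_centres.size * lat_centres.size)
```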
We will call this binned ocean data the ground truth, and label it with the vector field $\mathbf{F} \equiv \mathbf{F}(\mathbf{x})$, where $\mathbf{x} = (x^{(0)}, x^{(1)})^\intercal$, with a vector basis in the standard Cartesian directions (dimensions will be indicated by superscripts).
We shall label the ground truth $D_0 = \{(\mathbf{x}_{0,i}, \mathbf{y}_{0,i})\}_{i=1}^{N}$, where $\mathbf{y}_{0,i}$ is the 2-dimensional velocity vector at the $i$-th location, $\mathbf{x}_{0,i}$. The training dataset contains simulated measurements from ocean drifters, $D_T = \{(\mathbf{x}_{T,i}, \mathbf{y}_{T,i})\}_{i=1}^{N_T}$, with $N_T = 20$ in this case (the subscripts $0$ and $T$ indicate the ground truth and the simulated measurements, respectively).
```python
# function to place data from csv into correct array shape
def prepare_data(df):
    pos = jnp.array([df["lon"], df["lat"]])
    vel = jnp.array([df["ubar"], df["vbar"]])
    # extract shape stored as 'metadata' in the test data
    try:
        shape = (int(df["shape"][1]), int(df["shape"][0]))  # shape = (34,16)
        return pos, vel, shape
    except KeyError:
        return pos, vel


# loading in data
gulf_data_train = pd.read_csv(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/static/main/data/gulfdata_train.csv"
)
gulf_data_test = pd.read_csv(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/static/main/data/gulfdata_test.csv"
)

pos_test, vel_test, shape = prepare_data(gulf_data_test)
pos_train, vel_train = prepare_data(gulf_data_train)

fig, ax = plt.subplots(1, 1, figsize=(6, 3))
ax.quiver(
    pos_test[0],
    pos_test[1],
    vel_test[0],
    vel_test[1],
    color=colors[0],
    label="Ocean Current",
    angles="xy",
    scale=10,
)
ax.quiver(
    pos_train[0],
    pos_train[1],
    vel_train[0],
    vel_train[1],
    color=colors[1],
    alpha=0.7,
    label="Drifter",
    angles="xy",
    scale=10,
)
ax.set(
    xlabel="Longitude",
    ylabel="Latitude",
)
ax.legend(
    framealpha=0.0,
    ncols=2,
    fontsize="medium",
    bbox_to_anchor=(0.5, -0.3),
    loc="lower center",
)
```
Problem Setting
We aim to obtain estimates for $\mathbf{F}$ at the set of points $\{\mathbf{x}_{0,i}\}_{i=1}^{N}$ using Gaussian processes, and then to compare the latent model with the ground truth $D_0$. Note that the velocities in $D_0$ are never passed to GPJax. Only the locations $\mathbf{x}_{0,i}$ are used, to query the GP, and the velocities serve solely for comparing the two GP models at the end of the notebook.
Since $\mathbf{F}$ is a vector-valued function, we require GPs that can directly learn vector-valued functions¹. To implement this in GPJax, the problem can be recast as learning a scalar-valued function by 'massaging' the data into a $2N \times 2N$ problem, such that each dimension of our GP is associated with a component of $\mathbf{y}_{T,i}$.
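For a particular measurement $\mathbf{y}$ (training or testing) at location $\mathbf{x}$, the components $(y^{(0)}, y^{(1)})$ are described by the latent vector field $\mathbf{F}$, such that

$$\mathbf{y} = \mathbf{F}(\mathbf{x}) = \begin{pmatrix} f^{(0)}(\mathbf{x}) \\ f^{(1)}(\mathbf{x}) \end{pmatrix},$$

where each $f^{(z)}(\mathbf{x})$, $z \in \{0, 1\}$, is a scalar-valued function.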
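Now consider the scalar-valued function $g: \mathbb{R}^2 \times \{0, 1\} \rightarrow \mathbb{R}$, such that

$$g(\mathbf{x}, 0) = f^{(0)}(\mathbf{x}), \quad \text{and} \quad g(\mathbf{x}, 1) = f^{(1)}(\mathbf{x}).$$

We have increased the input dimension by 1, from the 2D $\mathbf{x}$ to the 3D $\mathbf{X} = (\mathbf{x}, 0)$ or $\mathbf{X} = (\mathbf{x}, 1)$.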
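By choosing the value of the third dimension, 0 or 1, we can now incorporate this information into the computation of the kernel. We therefore make new 3D datasets $D_{T,3D} = \{(\mathbf{X}_{T,i}, Y_{T,i})\}_{i=1}^{2N_T}$ and $D_{0,3D} = \{(\mathbf{X}_{0,i}, Y_{0,i})\}_{i=1}^{2N}$ that incorporate this new labelling. For each dataset (indicated by the subscript $D = 0$ or $D = T$),

$$\mathbf{X}_{D,i} = (\mathbf{x}_{D,\lceil i/2 \rceil}, z),$$

and

$$Y_{D,i} = y_{D,\lceil i/2 \rceil}^{(z)},$$

where $z = 0$ if $i$ is odd and $z = 1$ if $i$ is even. Each location therefore appears twice, once for each velocity component.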
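To see the interleaved ordering concretely, the sketch below applies the repeat/tile/flatten recipe used by label_position and stack_velocity (in the next cell) to two made-up locations. It uses NumPy in place of jax.numpy so that it runs without GPJax, and the array values are hypothetical.

```python
import numpy as np

# Two hypothetical locations (columns) and velocities, in the same (2, N) layout as pos/vel
pos = np.array([[-90.0, -89.0],   # x^(0): longitude
                [25.0, 26.0]])    # x^(1): latitude
vel = np.array([[0.1, 0.3],       # y^(0) at each location
                [0.2, 0.4]])      # y^(1) at each location


def label_position(data):
    # Repeat every location twice and append the alternating label z = 0, 1, 0, 1, ...
    n_points = data.shape[1]
    label = np.tile(np.array([0.0, 1.0]), n_points)
    return np.vstack((np.repeat(data, repeats=2, axis=1), label)).T


def stack_velocity(data):
    # Interleave components: y_1^(0), y_1^(1), y_2^(0), y_2^(1), ...
    return data.T.flatten().reshape(-1, 1)


X = label_position(pos)
Y = stack_velocity(vel)
print(X)
print(Y.ravel())
```

Each row of X is one 3D input $(x^{(0)}, x^{(1)}, z)$, and the matching entry of Y is the velocity component selected by $z$.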
```python
# Change vectors x -> X = (x, z), and vectors y -> scalars Y = y^(z), via the artificial z label
def label_position(data):
    # introduce alternating z label
    n_points = len(data[0])
    label = jnp.tile(jnp.array([0.0, 1.0]), n_points)
    return jnp.vstack((jnp.repeat(data, repeats=2, axis=1), label)).T


# change vectors y -> Y by reshaping the velocity measurements
def stack_velocity(data):
    return data.T.flatten().reshape(-1, 1)


def dataset_3d(pos, vel):
    return gpx.Dataset(label_position(pos), stack_velocity(vel))


# label and place the training data into a Dataset object to be used by GPJax
dataset_train = dataset_3d(pos_train, vel_train)

# we also require the testing data to be relabelled for later use, such that we can query the 2Nx2N GP at the test points
dataset_ground_truth = dataset_3d(pos_test, vel_test)
```
Velocity (dimension) decomposition
Having labelled the data, we are now in a position to use GPJax to learn the function $g$, and hence $\mathbf{F}$. A naive approach is to place a GP prior directly on each velocity component independently; we call this the velocity GP. For our prior, we choose a zero mean function for all output dimensions and a piecewise kernel that depends on the $z$ labels of the inputs. For two inputs $\mathbf{X} = (\mathbf{x}, z)$ and $\mathbf{X}' = (\mathbf{x}', z')$,

$$k_\text{vel}(\mathbf{X}, \mathbf{X}') = \begin{cases} k^{(z)}(\mathbf{x}, \mathbf{x}') & \text{if } z = z', \\ 0 & \text{if } z \neq z', \end{cases}$$

where $k^{(z)}(\mathbf{x}, \mathbf{x}')$ are user-chosen kernels for each component. This means there is no correlation between the velocity components $f^{(0)}$ and $f^{(1)}$ for any choice of $\mathbf{X}$ and $\mathbf{X}'$: every Gram matrix entry that pairs a $z = 0$ input with a $z = 1$ input is zero.
To implement this approach in GPJax, we define VelocityKernel in the following cell,
following the steps outlined in the custom kernels
notebook.
This modular implementation takes the user's choice of kernels as its attributes, kernel0 and kernel1. We must additionally pass the argument active_dims = [0,1], an attribute of the base class AbstractKernel, to the chosen kernels. This ensures that they act only on the two spatial coordinates, so that the subsequent likelihood optimisation does not fit hyperparameters over the artificial label dimension.
```python
class VelocityKernel(gpx.kernels.AbstractKernel):
    def __init__(
        self,
        kernel0: gpx.kernels.AbstractKernel | None = None,
        kernel1: gpx.kernels.AbstractKernel | None = None,
    ):
        # build fresh default kernels per instance, rather than sharing one default object
        self.kernel0 = (
            kernel0 if kernel0 is not None else gpx.kernels.RBF(active_dims=[0, 1])
        )
        self.kernel1 = (
            kernel1 if kernel1 is not None else gpx.kernels.RBF(active_dims=[0, 1])
        )
        super().__init__(compute_engine=DenseKernelComputation())

    def __call__(
        self, X: Float[Array, "1 D"], Xp: Float[Array, "1 D"]
    ) -> Float[Array, "1"]:
        # k^(z)(x, x') if x and x' are on the same output, otherwise returns 0
        z = jnp.array(X[2], dtype=int)
        zp = jnp.array(Xp[2], dtype=int)

        # achieve the correct value via 'switches' that are either 1 or 0
        k0_switch = ((z + 1) % 2) * ((zp + 1) % 2)
        k1_switch = z * zp

        return k0_switch * self.kernel0(X, Xp) + k1_switch * self.kernel1(X, Xp)
```
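The switch arithmetic in __call__ stands in for an if statement. Presumably this keeps the kernel traceable by JAX, since a Python if on a value derived from the traced input would fail under jit or vmap; jnp.where would be an equivalent alternative. The plain-Python check below reuses the two switch expressions and tabulates them for every label pair. Exactly one switch is on when the labels agree, and both are off otherwise.

```python
def switches(z, zp):
    # Same expressions as k0_switch and k1_switch in VelocityKernel.__call__
    k0_switch = ((z + 1) % 2) * ((zp + 1) % 2)  # 1 only if z == zp == 0
    k1_switch = z * zp  # 1 only if z == zp == 1
    return k0_switch, k1_switch


table = {(z, zp): switches(z, zp) for z in (0, 1) for zp in (0, 1)}
for (z, zp), (s0, s1) in table.items():
    print(f"z={z}, z'={zp}: k0_switch={s0}, k1_switch={s1}")
```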
GPJax implementation
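Next, we define the model in GPJax. The prior is defined using $k_\text{vel}(\mathbf{X}, \mathbf{X}')$ and a zero mean function. The likelihood is Gaussian, with the observation noise standard deviation set to a small value ($10^{-3}$) rather than exactly zero. The hyperparameters are fitted by maximising the conjugate (Gaussian) marginal log-likelihood (MLL).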
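For reference, with a Gaussian likelihood the marginal log-likelihood has the closed form $\log p(\mathbf{Y}) = -\tfrac{1}{2}\mathbf{Y}^\intercal (K + \sigma^2 I)^{-1}\mathbf{Y} - \tfrac{1}{2}\log\lvert K + \sigma^2 I\rvert - \tfrac{n}{2}\log 2\pi$. Here $K$ is the $n \times n$ Gram matrix and $\sigma$ is the observation noise standard deviation. The NumPy sketch below evaluates this expression via a Cholesky factorisation on hypothetical toy data. It illustrates the formula only and is not GPJax's own conjugate_mll code.

```python
import numpy as np


def gaussian_mll(K, y, obs_stddev):
    """Log marginal likelihood log N(y | 0, K + obs_stddev**2 I), computed via Cholesky."""
    n = y.shape[0]
    L = np.linalg.cholesky(K + obs_stddev**2 * np.eye(n))
    # alpha = (K + sigma^2 I)^{-1} y, from two triangular solves
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    data_fit = -0.5 * y @ alpha
    # -0.5 * log|K + sigma^2 I| = -sum(log(diag(L)))
    complexity = -np.sum(np.log(np.diag(L)))
    return data_fit + complexity - 0.5 * n * np.log(2 * np.pi)


# Hypothetical toy data: unit-lengthscale RBF Gram matrix on five 1D inputs
x = np.linspace(0.0, 2.0, 5)
K = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2)
y = np.sin(x)
mll = gaussian_mll(K, y, obs_stddev=0.1)
print(mll)
```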
```python
def initialise_gp(kernel, mean, dataset):
    prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
    likelihood = gpx.likelihoods.Gaussian(
        num_datapoints=dataset.n, obs_stddev=jnp.array([1.0e-3], dtype=jnp.float64)
    )
    posterior = prior * likelihood
    return posterior


# Define the velocity GP
mean = gpx.mean_functions.Zero()
kernel = VelocityKernel()
velocity_posterior = initialise_gp(kernel, mean, dataset_train)
```
With a model now defined, we can proceed to optimise the hyperparameters of our model on the training data $D_T$. We do this by minimising the negative MLL using BFGS. SciPy reports the final objective value and the iteration counts, which confirm that the optimiser converged. See the introduction to Gaussian Processes notebook for more information on optimising the MLL.
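Conceptually, the fitting step hands the negative MLL to a SciPy optimiser. The sketch below does the same for a hypothetical 1D toy problem with a single RBF lengthscale. The lengthscale is optimised in log space so that it stays positive. Here scipy.optimize.minimize approximates gradients by finite differences, whereas GPJax obtains them through JAX automatic differentiation. The data, noise level and starting point are illustrative assumptions.

```python
import numpy as np
from scipy.optimize import minimize

# Hypothetical toy data
x = np.linspace(0.0, 5.0, 20)
y = np.sin(x)
OBS_STDDEV = 0.1


def neg_mll(log_lengthscale):
    """Negative Gaussian marginal log-likelihood as a function of log(lengthscale)."""
    lengthscale = np.exp(log_lengthscale[0])
    K = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / lengthscale**2)
    L = np.linalg.cholesky(K + OBS_STDDEV**2 * np.eye(x.size))
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    mll = -0.5 * y @ alpha - np.sum(np.log(np.diag(L))) - 0.5 * x.size * np.log(2 * np.pi)
    return -mll


x0 = np.array([np.log(0.3)])
result = minimize(neg_mll, x0, method="BFGS")
print(result.message)
print("fitted lengthscale:", np.exp(result.x[0]))
```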
```python
def optimise_mll(posterior, dataset):
    # negative conjugate MLL on the given data; minimising it maximises the MLL
    objective = lambda p, d: -gpx.objectives.conjugate_mll(p, d)
    # fit all trainable Parameter leaves with SciPy
    opt_posterior, history = gpx.fit_scipy(
        model=posterior,
        objective=objective,
        train_data=dataset,
        trainable=Parameter,
    )
    return opt_posterior


opt_velocity_posterior = optimise_mll(velocity_posterior, dataset_train)
```
Optimization terminated successfully.
Current function value: -26.620707
Iterations: 42
Function evaluations: 70
Gradient evaluations: 70
Comparison
We next obtain the latent distribution of the GP for $g$ at the test inputs and extract its mean, $\mathbf{F}_\text{latent}(\mathbf{x}_{0,i})$, as well as its standard deviation (which we will use at the very end).
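For a conjugate GP, the latent predictive distribution at test inputs has mean $K_*^\intercal (K + \sigma^2 I)^{-1}\mathbf{Y}$ and variance $k_{**} - K_*^\intercal (K + \sigma^2 I)^{-1} K_*$, with no observation noise added. The NumPy sketch below implements these standard equations with a hypothetical unit RBF kernel and toy data. It is not the code path GPJax's predict follows. At a training input the mean reproduces the observation and the standard deviation is close to $\sigma$. Far from the data, the prediction reverts to the prior (mean 0, standard deviation 1).

```python
import numpy as np


def rbf_gram(a, b, lengthscale=1.0):
    # RBF cross-covariance between 1D input arrays a and b
    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / lengthscale**2)


def latent_predict(x_train, y_train, x_test, obs_stddev=1e-3):
    """Posterior mean and standard deviation of the latent function (no observation noise)."""
    K = rbf_gram(x_train, x_train) + obs_stddev**2 * np.eye(x_train.size)
    K_star = rbf_gram(x_train, x_test)  # shape (n_train, n_test)
    k_star_star = np.ones(x_test.size)  # RBF prior variance at each test point
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_train))
    V = np.linalg.solve(L, K_star)
    mean = K_star.T @ alpha
    var = k_star_star - np.sum(V**2, axis=0)
    return mean, np.sqrt(np.maximum(var, 0.0))


x_train = np.array([0.0, 1.0, 2.0])
y_train = np.sin(x_train)
x_test = np.array([1.0, 10.0])
mean, std = latent_predict(x_train, y_train, x_test)
print(mean, std)
```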
```python
def latent_distribution(opt_posterior, pos_3d, dataset_train):
    latent = opt_posterior.predict(pos_3d, train_data=dataset_train)
    latent_mean = latent.mean
    latent_std = latent.stddev()
    return latent_mean, latent_std


# extract latent mean and std of g, redistribute into vectors to model F
velocity_mean, velocity_std = latent_distribution(
    opt_velocity_posterior, dataset_ground_truth.X, dataset_train
)

dataset_latent_velocity = dataset_3d(pos_test, velocity_mean)
```
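We now replot the ground truth (testing data) $D_0$, the predicted latent vector field $\mathbf{F}_\text{latent}(\mathbf{x}_{0,i})$, and a heatmap of the residual magnitude $\lVert\mathbf{R}(\mathbf{x}_{0,i})\rVert$ at each location, where $\mathbf{R}(\mathbf{x}_{0,i}) = \mathbf{y}_{0,i} - \mathbf{F}_\text{latent}(\mathbf{x}_{0,i})$.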
```python
# Residuals between ground truth and estimate
def plot_vector_field(ax, dataset, **kwargs):
    ax.quiver(
        dataset.X[::2][:, 0],
        dataset.X[::2][:, 1],
        dataset.y[::2],
        dataset.y[1::2],
        **kwargs,
    )


def prepare_ax(ax, X, Y, title, **kwargs):
    ax.set(
        xlim=[X.min() - 0.1, X.max() + 0.1],
        ylim=[Y.min() - 0.1, Y.max() + 0.1],  # pad both ends, as for xlim
        aspect="equal",
        title=title,
        ylabel="Latitude",
        **kwargs,
    )


def residuals(dataset_latent, dataset_ground_truth):
    # magnitude of the residual vector at each location (components are interleaved)
    return jnp.sqrt(
        (dataset_latent.y[::2] - dataset_ground_truth.y[::2]) ** 2
        + (dataset_latent.y[1::2] - dataset_ground_truth.y[1::2]) ** 2
    )


def plot_fields(
    dataset_ground_truth, dataset_trajectory, dataset_latent, shape=shape, scale=10
):
    X = dataset_ground_truth.X[:, 0][::2]
    Y = dataset_ground_truth.X[:, 1][::2]
    # make figure
    fig, ax = plt.subplots(1, 3, figsize=(12.0, 3.0), sharey=True)

    # ground truth
    plot_vector_field(
        ax[0],
        dataset_ground_truth,
        color=colors[0],
        label="Ocean Current",
        angles="xy",
        scale=scale,
    )
    plot_vector_field(
        ax[0],
        dataset_trajectory,
        color=colors[1],
        label="Drifter",
        angles="xy",
        scale=scale,
    )
    prepare_ax(ax[0], X, Y, "Ground Truth", xlabel="Longitude")

    # Latent estimate of vector field F
    plot_vector_field(ax[1], dataset_latent, color=colors[0], angles="xy", scale=scale)
    plot_vector_field(
        ax[1], dataset_trajectory, color=colors[1], angles="xy", scale=scale
    )
    prepare_ax(ax[1], X, Y, "GP Estimate", xlabel="Longitude")

    # residuals
    residuals_vel = jnp.flip(
        residuals(dataset_latent, dataset_ground_truth).reshape(shape), axis=0
    )
    im = ax[2].imshow(
        residuals_vel,
        extent=[X.min(), X.max(), Y.min(), Y.max()],
        cmap="jet",
        vmin=0,
        vmax=1.0,
        interpolation="spline36",
    )
    plot_vector_field(
        ax[2], dataset_trajectory, color=colors[1], angles="xy", scale=scale
    )
    prepare_ax(ax[2], X, Y, "Residuals", xlabel="Longitude")
    fig.colorbar(im, fraction=0.027, pad=0.04, orientation="vertical")

    fig.legend(
        framealpha=0.0,
        ncols=2,
        fontsize="medium",
        bbox_to_anchor=(0.5, -0.03),
        loc="lower center",
    )


plot_fields(dataset_ground_truth, dataset_train, dataset_latent_velocity)
```
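Because both datasets store velocities in the interleaved order $(y^{(0)}_1, y^{(1)}_1, y^{(0)}_2, \dots)$, the residual magnitude can be read off by slicing even and odd entries, as the residuals helper above does. The NumPy sketch below repeats that computation on made-up numbers. It also checks the result against the equivalent reshape-to-(N, 2) formulation.

```python
import numpy as np


def residual_norm(y_latent, y_truth):
    # Even entries hold the y^(0) component, odd entries the y^(1) component
    d0 = y_latent[::2] - y_truth[::2]
    d1 = y_latent[1::2] - y_truth[1::2]
    return np.sqrt(d0**2 + d1**2)


# Hypothetical interleaved velocities at two locations
y_truth = np.array([0.0, 0.0, 1.0, 1.0])
y_latent = np.array([3.0, 4.0, 1.0, 1.0])
norms = residual_norm(y_latent, y_truth)
print(norms)
```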
From the latent estimate we can see that the velocity GP struggles to reconstruct features of the ground truth. This is because our kernel places an independent prior on each velocity component, an assumption that is not justified for ocean currents. We therefore need an approach that builds the dependence between components into the prior itself, which is what the Helmholtz decomposition provides.
Helmholtz decomposition
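In 2 dimensions, a twice continuously differentiable and compactly supported vector field $\mathbf{F}: \mathbb{R}^2 \rightarrow \mathbb{R}^2$ can be expressed as the sum of two terms (Berlinghieri et al., 2023). The first is the gradient of a scalar potential $\Phi: \mathbb{R}^2 \rightarrow \mathbb{R}$, called the potential function. The second is the vorticity operator applied to another scalar potential $\Psi: \mathbb{R}^2 \rightarrow \mathbb{R}$, called the stream function. Together,

$$\mathbf{F} = \operatorname{grad}\Phi + \operatorname{rot}\Psi,$$

where

$$\operatorname{grad}\Phi := \begin{bmatrix} \partial\Phi / \partial x^{(0)} \\ \partial\Phi / \partial x^{(1)} \end{bmatrix} \quad \text{and} \quad \operatorname{rot}\Psi := \begin{bmatrix} \partial\Psi / \partial x^{(1)} \\ -\partial\Psi / \partial x^{(0)} \end{bmatrix}.$$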
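As a numerical illustration of the decomposition, take a hypothetical potential $\Phi(\mathbf{x}) = (x^{(0)})^2 x^{(1)}$ and stream function $\Psi(\mathbf{x}) = \sin x^{(0)} \cos x^{(1)}$. The sketch below assembles $\mathbf{F} = \operatorname{grad}\Phi + \operatorname{rot}\Psi$. Using central finite differences, it then checks that the gradient part has zero curl and the rot part has zero divergence. Both functions and the evaluation point are arbitrary choices made for the example.

```python
import numpy as np


def grad_phi(x):
    # grad of the hypothetical potential Phi(x) = x0**2 * x1
    return np.array([2.0 * x[0] * x[1], x[0] ** 2])


def rot_psi(x):
    # rot of the hypothetical stream function Psi(x) = sin(x0) * cos(x1):
    # rot Psi = (dPsi/dx1, -dPsi/dx0)
    dpsi_dx0 = np.cos(x[0]) * np.cos(x[1])
    dpsi_dx1 = -np.sin(x[0]) * np.sin(x[1])
    return np.array([dpsi_dx1, -dpsi_dx0])


def F(x):
    # Helmholtz decomposition: F = grad Phi + rot Psi
    return grad_phi(x) + rot_psi(x)


def divergence(field, x, h=1e-5):
    # dF0/dx0 + dF1/dx1 by central differences
    e0, e1 = np.array([h, 0.0]), np.array([0.0, h])
    return ((field(x + e0)[0] - field(x - e0)[0]) + (field(x + e1)[1] - field(x - e1)[1])) / (2 * h)


def curl(field, x, h=1e-5):
    # Scalar 2D curl dF1/dx0 - dF0/dx1 by central differences
    e0, e1 = np.array([h, 0.0]), np.array([0.0, h])
    return ((field(x + e0)[1] - field(x - e0)[1]) - (field(x + e1)[0] - field(x - e1)[0])) / (2 * h)


x = np.array([0.7, -0.4])
print("curl(grad Phi):", curl(grad_phi, x))
print("div(rot Psi):", divergence(rot_psi, x))
```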
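The 2-dimensional decomposition motivates a different approach: placing priors on $\Psi$ and $\Phi$, which allows us to make assumptions directly about fundamental properties of $\mathbf{F}$. If we choose independent GP priors $\Phi \sim \mathcal{GP}(0, k_\Phi)$ and $\Psi \sim \mathcal{GP}(0, k_\Psi)$, then $\mathbf{F} \sim \mathcal{GP}(0, k_\text{Helm})$, since applying linear operators (here, differentiation) to a GP yields a GP. Because $\Phi$ and $\Psi$ are independent, their contributions to the covariance add. In terms of the 3D inputs $\mathbf{X} = (\mathbf{x}, z)$ and $\mathbf{X}' = (\mathbf{x}', z')$,

$$\boxed{k_\text{Helm}(\mathbf{X}, \mathbf{X}') = \frac{\partial^2 k_\Phi(\mathbf{x}, \mathbf{x}')}{\partial x^{(z)} \, \partial (x')^{(z')}} + (-1)^{z + z'} \frac{\partial^2 k_\Psi(\mathbf{x}, \mathbf{x}')}{\partial x^{(1-z)} \, \partial (x')^{(1-z')}}},$$

where $x^{(z)}$ and $(x')^{(z')}$ are the $z$ and $z'$ components of $\mathbf{X}$ and $\mathbf{X}'$, respectively.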
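For squared-exponential (RBF) choices of $k_\Phi$ and $k_\Psi$, the mixed second derivatives have a closed form. With $\mathbf{r} = \mathbf{x} - \mathbf{x}'$ and $k(\mathbf{r}) = \sigma^2 \exp(-\lVert\mathbf{r}\rVert^2 / 2\ell^2)$, they are $\partial^2 k / \partial x^{(a)} \partial (x')^{(b)} = k(\mathbf{r})\,(\delta_{ab}/\ell^2 - r^{(a)} r^{(b)}/\ell^4)$. The NumPy sketch below plugs this into $k_\text{Helm}$ with hypothetical lengthscales and unit variances. It checks that the resulting Gram matrix is symmetric and positive semi-definite. It also checks that, unlike the velocity kernel, it couples the two velocity components.

```python
import numpy as np


def rbf_mixed_second_derivative(x, xp, a, b, lengthscale=1.0, variance=1.0):
    """d^2 k / (dx^(a) dx'^(b)) for the RBF kernel k(x, x') = variance * exp(-|x - x'|^2 / (2 l^2))."""
    r = x - xp
    k = variance * np.exp(-0.5 * np.dot(r, r) / lengthscale**2)
    delta = 1.0 if a == b else 0.0
    return k * (delta / lengthscale**2 - r[a] * r[b] / lengthscale**4)


def k_helm(X, Xp, ls_phi=1.0, ls_psi=0.7):
    # X = (x^(0), x^(1), z): potential term uses components (z, z'),
    # stream term uses (1 - z, 1 - z') with sign (-1)^(z + z')
    x, z = X[:2], int(X[2])
    xp, zp = Xp[:2], int(Xp[2])
    potential = rbf_mixed_second_derivative(x, xp, z, zp, lengthscale=ls_phi)
    stream = rbf_mixed_second_derivative(x, xp, 1 - z, 1 - zp, lengthscale=ls_psi)
    return potential + (-1) ** (z + zp) * stream


# Hypothetical locations, relabelled as (x, 0), (x, 1), ...
locs = np.array([[0.0, 0.0], [0.4, 0.2], [-0.3, 0.5], [1.0, -0.6]])
X = np.array([[*x, z] for x in locs for z in (0.0, 1.0)])
K = np.array([[k_helm(p, q) for q in X] for p in X])

cross = X[:, 2][:, None] != X[:, 2][None, :]
print("largest cross-component covariance:", np.abs(K[cross]).max())
```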
We compute the second derivatives using jax.hessian. In the following implementation, for a kernel $k(\mathbf{x}, \mathbf{x}')$, this computes the Hessian matrix with respect to the components of $\mathbf{x}$,

$$\frac{\partial^2 k(\mathbf{x}, \mathbf{x}')}{\partial x^{(z)} \, \partial x^{(z')}}.$$
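Note that this differentiates twice with respect to $\mathbf{x}$. The expression for $k_\text{Helm}$, by contrast, differentiates once with respect to $x^{(z)}$ and once with respect to $(x')^{(z')}$. This is not an issue if we choose stationary kernels, $k(\mathbf{x}, \mathbf{x}') = k(\mathbf{x} - \mathbf{x}')$, because the partial derivatives with respect to the components then have the following exchange symmetry:

$$\frac{\partial}{\partial x^{(z)}} = -\frac{\partial}{\partial (x')^{(z)}},$$

for either $z$. Consequently $\partial^2 k / \partial x^{(z)} \partial (x')^{(z')} = -\partial^2 k / \partial x^{(z)} \partial x^{(z')}$, which is why the implementation below negates the Hessian.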
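The exchange symmetry can be checked numerically. The NumPy sketch below takes a stationary RBF kernel with a hypothetical lengthscale. Using central finite differences, it compares the mixed derivative $\partial^2 k/\partial x^{(z)} \partial (x')^{(z')}$ with the negated Hessian entry $-\partial^2 k/\partial x^{(z)} \partial x^{(z')}$ for all four index pairs. The latter is the quantity the implementation obtains from jax.hessian.

```python
import numpy as np


def k(x, xp, lengthscale=0.8):
    # Stationary RBF kernel: depends only on x - x'
    r = x - xp
    return np.exp(-0.5 * np.dot(r, r) / lengthscale**2)


def second_derivative(f, x, xp, a, b, wrt_both_x, h=1e-4):
    """Central-difference estimate of d^2 f / (dx^(a) dv^(b)), where v is x (wrt_both_x=True) or x'."""
    ea, eb = np.eye(2)[a] * h, np.eye(2)[b] * h

    def shifted(sa, sb):
        if wrt_both_x:
            return f(x + sa * ea + sb * eb, xp)
        return f(x + sa * ea, xp + sb * eb)

    return (shifted(1, 1) - shifted(1, -1) - shifted(-1, 1) + shifted(-1, -1)) / (4 * h**2)


x, xp = np.array([0.3, -0.2]), np.array([-0.1, 0.4])
for a in (0, 1):
    for b in (0, 1):
        mixed = second_derivative(k, x, xp, a, b, wrt_both_x=False)
        hess = second_derivative(k, x, xp, a, b, wrt_both_x=True)
        print(f"(z, z') = ({a}, {b}): mixed = {mixed:.6f}, -hessian = {-hess:.6f}")
```

For a non-stationary kernel the two columns would generally differ, which is why the sign trick is restricted to stationary kernels.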
```python
@dataclass
class HelmholtzKernel(gpx.kernels.stationary.StationaryKernel):
    # initialise Phi and Psi kernels as any stationary kernel in GPJax
    potential_kernel: gpx.kernels.stationary.StationaryKernel = field(
        default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1])
    )
    stream_kernel: gpx.kernels.stationary.StationaryKernel = field(
        default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1])
    )
    compute_engine = DenseKernelComputation()

    def __call__(
        self, X: Float[Array, "1 D"], Xp: Float[Array, "1 D"]
    ) -> Float[Array, "1"]:
        # obtain indices for k_helm, implement in the correct sign between the derivatives
        z = jnp.array(X[2], dtype=int)
        zp = jnp.array(Xp[2], dtype=int)
        sign = (-1) ** (z + zp)

        # convert to array to correctly index, -ve sign due to exchange symmetry (only true for stationary kernels)
        potential_dvtve = -jnp.array(
            hessian(self.potential_kernel)(X, Xp), dtype=jnp.float64
        )[z][zp]
        stream_dvtve = -jnp.array(
            hessian(self.stream_kernel)(X, Xp), dtype=jnp.float64
        )[1 - z][1 - zp]

        return potential_dvtve + sign * stream_dvtve
```
GPJax implementation
We repeat the same steps as with the velocity GP model, replacing VelocityKernel
with HelmholtzKernel.
```python
# Redefine Gaussian process with Helmholtz kernel
kernel = HelmholtzKernel()
helmholtz_posterior = initialise_gp(kernel, mean, dataset_train)

# Optimise hyperparameters using BFGS
opt_helmholtz_posterior = optimise_mll(helmholtz_posterior, dataset_train)
```
Optimization terminated successfully.
Current function value: -28.611975
Iterations: 35
Function evaluations: 58
Gradient evaluations: 58
Comparison
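We again plot the ground truth (testing data) $D_0$, the predicted latent vector field $\mathbf{F}_\text{latent}(\mathbf{x}_{0,i})$, and a heatmap of the residual magnitude $\lVert\mathbf{R}(\mathbf{x}_{0,i})\rVert$ at each location, where $\mathbf{R}(\mathbf{x}_{0,i}) = \mathbf{y}_{0,i} - \mathbf{F}_\text{latent}(\mathbf{x}_{0,i})$.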
```python
# obtain latent distribution, extract x and y values over g
helmholtz_mean, helmholtz_std = latent_distribution(
    opt_helmholtz_posterior, dataset_ground_truth.X, dataset_train
)
dataset_latent_helmholtz = dataset_3d(pos_test, helmholtz_mean)

plot_fields(dataset_ground_truth, dataset_train, dataset_latent_helmholtz)
```
Visually, the Helmholtz model performs better than the velocity model, preserving the local structure of $\mathbf{F}$. Since we placed priors on $\Phi$ and $\Psi$, the construction of $\mathbf{F}$ allows for correlations between the two velocity components. These appear as non-zero cross-component entries in the Gram matrix populated by $k_\text{Helm}(\mathbf{X}, \mathbf{X}')$.
Negative log predictive densities
Lastly, we directly compare the velocity and Helmholtz models by computing the negative log predictive densities (NLPD) for each model. This quantitative metric measures how much predictive probability a model assigns to the held-out ground truth, given the training data:

$$\text{NLPD} = -\sum_{i=1}^{2N} \log p(Y_i = Y_{0,i} \mid \mathbf{X}_i),$$

where each $p(Y_i \mid \mathbf{X}_i)$ is the marginal Gaussian predictive distribution over $Y_i$ at each test location. $Y_{0,i}$ is the $i$-th component of the (massaged) test data that we reserved at the beginning of the notebook in $D_0$. Lower values are better. A low NLPD means the model places high predictive density on the ground truth, which requires accurate means together with appropriately sized uncertainties.
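Since each marginal is Gaussian, each term of the sum is $-\log \mathcal{N}(Y_{0,i} \mid \mu_i, s_i^2) = \tfrac12 \log(2\pi s_i^2) + (Y_{0,i} - \mu_i)^2 / (2 s_i^2)$, where $\mu_i$ and $s_i$ are the predictive mean and standard deviation. The NumPy sketch below evaluates this sum on made-up numbers. Accurate means with small but adequate standard deviations give a low, even negative, NLPD, since densities can exceed 1. Overconfident standard deviations are heavily penalised.

```python
import numpy as np


def gaussian_nlpd(y_true, mean, std):
    # Sum of -log N(y_true | mean, std^2) over all (interleaved) velocity components
    return np.sum(0.5 * np.log(2 * np.pi * std**2) + (y_true - mean) ** 2 / (2 * std**2))


# Hypothetical ground truth and predictive means
y_true = np.array([0.10, -0.20, 0.05, 0.30])
mean = np.array([0.12, -0.18, 0.00, 0.25])

calibrated = gaussian_nlpd(y_true, mean, std=np.full(4, 0.05))
overconfident = gaussian_nlpd(y_true, mean, std=np.full(4, 0.001))
print(calibrated, overconfident)
```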
```python
# ensure testing data alternates between x0 and x1 components
def nlpd(mean, std, vel_test):
    vel_query = jnp.column_stack((vel_test[0], vel_test[1])).flatten()
    normal = npd.Normal(loc=mean, scale=std)
    return -jnp.sum(normal.log_prob(vel_query))


# compute nlpd for velocity and helmholtz
nlpd_vel = nlpd(velocity_mean, velocity_std, vel_test)
nlpd_helm = nlpd(helmholtz_mean, helmholtz_std, vel_test)

print("NLPD for Velocity: %.2f\nNLPD for Helmholtz: %.2f" % (nlpd_vel, nlpd_helm))
```
NLPD for Velocity: 730.13
NLPD for Helmholtz: -280.59
The Helmholtz model outperforms the velocity model, as indicated by the lower NLPD score.
Footnote
¹ Kernels for vector-valued functions have been studied in the literature; see Alvarez et al. (2012).