In this notebook we demonstrate the Orthogonal Additive Kernel (OAK) of
Lu, Boukouvalas & Hensman (2022).
OAK provides an interpretable additive Gaussian process model that decomposes
the target function into main effects and interaction terms, whilst remaining
a valid positive-definite kernel. The key ingredients are:
A per-dimension constrained SE kernel that is orthogonal to the
constant function under the input density.
Newton-Girard recursion to efficiently combine these constrained
kernels into elementary symmetric polynomials up to a chosen interaction
order.
Analytic Sobol indices that quantify the relative importance of each
interaction order, enabling practitioners to understand which features
and feature interactions drive the model's predictions.
We illustrate the full workflow on the UCI Auto MPG dataset.
# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)

from examples.utils import use_mpl_style
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
    from gpjax.kernels.additive import (
        OrthogonalAdditiveKernel,
        predict_first_order,
        rank_first_order,
        sobol_indices,
    )
    from gpjax.parameters import Parameter

key = jr.key(123)

use_mpl_style()
colours = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
Mathematical background
Additive GP decomposition
A standard GP with a single kernel $k(\mathbf{x}, \mathbf{x}')$ treats all
input dimensions jointly. An additive GP instead decomposes the latent
function as

$$f(\mathbf{x}) = f_0 + \sum_{d=1}^{D} f_d(x_d) + \sum_{d < d'} f_{dd'}(x_d, x_{d'}) + \dots,$$

where $f_0$ is a constant offset, $f_d$ are first-order (main) effects,
$f_{dd'}$ are second-order interactions, and so on. Truncating at a maximum
interaction order $\tilde{D} \le D$ yields a model that scales gracefully
whilst retaining interpretability.
The identifiability problem
A naive additive decomposition is unidentifiable: one can freely shift
mass between the constant term and a main effect, or between a main effect
and an interaction. Lu et al. resolve this by requiring each component to
be orthogonal to all lower-order components under the input density
$p(\mathbf{x})$. In particular, the first-order components satisfy

$$\int f_d(x_d)\, p(x_d)\, \mathrm{d}x_d = 0 \qquad \forall d.$$
Constrained SE kernel
Assuming a standard normal input density $p(x_d) = \mathcal{N}(0, 1)$, the
orthogonality constraint can be enforced analytically. The constrained
SE kernel is

$$\tilde{k}(x, y) = k(x, y) - \frac{\int k(x, s)\, p(s)\, \mathrm{d}s \,\int k(y, s')\, p(s')\, \mathrm{d}s'}{\iint k(s, s')\, p(s)\, p(s')\, \mathrm{d}s\, \mathrm{d}s'},$$

where $k(x, y) = \sigma^2 \exp\!\left(-\frac{(x - y)^2}{2\ell^2}\right)$ is
the standard SE kernel with lengthscale $\ell$ and variance $\sigma^2$.
The subtracted projection term removes the component of $k$ that lies along
the constant function under the $\mathcal{N}(0, 1)$ measure.
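As a numerical sanity check of this construction (a standalone sketch, not GPJax's analytic implementation; the hyperparameter values are arbitrary and the quadrature reuses jnp and np from the setup cell above), we can build the constrained kernel with Gauss-Hermite quadrature and verify that it integrates to zero against $\mathcal{N}(0, 1)$:

import jax

lengthscale, variance = 0.8, 1.3  # arbitrary illustrative hyperparameters

def se_kernel(x, y):
    return variance * jnp.exp(-0.5 * (x - y) ** 2 / lengthscale**2)

# Probabilists' Gauss-Hermite nodes/weights for integrals against N(0, 1).
nodes, weights = np.polynomial.hermite_e.hermegauss(60)
weights = jnp.asarray(weights / weights.sum())  # normalise to a probability measure
nodes = jnp.asarray(nodes)

# Integrals of the base kernel against the density.
k_node_int = se_kernel(nodes[:, None], nodes[None, :]) @ weights  # int k(s_i, s') p(s') ds'
k_int_int = weights @ k_node_int  # double integral of k against p(s) p(s')

def constrained_se(x, y):
    kx = se_kernel(x, nodes) @ weights  # int k(x, s) p(s) ds
    ky = se_kernel(y, nodes) @ weights
    return se_kernel(x, y) - kx * ky / k_int_int

# Orthogonality: the integral of k_tilde(x0, s) against p(s) should vanish for any x0.
x0 = 0.7
residual = weights @ jax.vmap(lambda s: constrained_se(x0, s))(nodes)
print(residual)  # approximately 0, up to quadrature error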
Newton-Girard recursion
The additive kernel across all interaction orders up to $\tilde{D}$ is

$$k_{\mathrm{OAK}}(\mathbf{x}, \mathbf{x}') = \sigma_0^2 + \sum_{\ell=1}^{\tilde{D}} \sigma_\ell^2\, e_\ell(z_1, \ldots, z_D), \qquad z_d = \tilde{k}_d(x_d, x_d'),$$

where $e_\ell$ denotes the $\ell$-th elementary symmetric polynomial
and $\sigma_\ell^2$ are learnable order variances. Computing
$e_\ell$ directly via the combinatorial definition would be prohibitively
expensive; instead GPJax uses the Newton-Girard identities, which
express $e_\ell$ recursively in terms of the power sums
$s_k = \sum_{d=1}^{D} z_d^k$:

$$e_\ell = \frac{1}{\ell} \sum_{k=1}^{\ell} (-1)^{k-1}\, e_{\ell-k}\, s_k, \qquad e_0 = 1.$$
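For concreteness, here is a minimal sketch of that recursion on a vector of per-dimension values; GPJax applies the same identities elementwise to the per-dimension Gram matrices, and the function and argument names below are illustrative only.

def elementary_symmetric_polynomials(z: jnp.ndarray, max_order: int) -> list:
    """Return [e_0, e_1, ..., e_max_order] for per-dimension values z_1, ..., z_D."""
    power_sums = [jnp.sum(z**k) for k in range(1, max_order + 1)]  # s_k
    e = [jnp.asarray(1.0)]  # e_0 = 1
    for order in range(1, max_order + 1):
        # Newton-Girard: order * e_order = sum_k (-1)^(k-1) e_{order-k} s_k
        e_order = sum(
            (-1) ** (k - 1) * e[order - k] * power_sums[k - 1]
            for k in range(1, order + 1)
        ) / order
        e.append(e_order)
    return e

# Example: for z = (2, 3, 5), e_2 should equal 2*3 + 2*5 + 3*5 = 31.
print(elementary_symmetric_polynomials(jnp.array([2.0, 3.0, 5.0]), max_order=3))
# [1.0, 10.0, 31.0, 30.0]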
The same machinery yields the analytic Sobol indices, which are computed
from $\alpha = (K + \sigma_n^2 I)^{-1}\mathbf{y}$ and $E_d$, the matrix-level
elementary symmetric polynomial of the per-dimension integral matrices (see
Appendix G.1 of the paper).
Dataset
We use the UCI Auto MPG
dataset, which contains fuel consumption data for 392 cars (after dropping
rows with missing values), each described by 7 numeric features (cylinders,
displacement, horsepower, weight, acceleration, model year, and origin).
Because the OAK kernel's constrained form assumes a standard normal input
density ($\mu = 0$, $\sigma^2 = 1$), we fit a per-feature normalising flow
that maps each marginal to an approximately standard normal distribution.
Targets are z-score standardised. This transformation of the input data
is crucial for the OAK model to work correctly, as the orthogonality
constraint is defined with respect to the input density.
from ucimlrepo import fetch_ucirepo

auto_mpg = fetch_ucirepo(id=9)
X_raw = auto_mpg.data.features
y_raw = auto_mpg.data.targets

# Drop rows with missing values
complete_rows = ~(X_raw.isna().any(axis=1) | y_raw.isna().any(axis=1))
X_all = X_raw[complete_rows].values.astype(np.float64)
y_all = y_raw[complete_rows].values.astype(np.float64)

feature_names = list(X_raw.columns)
num_features = X_all.shape[1]

print(f"Dataset: {X_all.shape[0]} observations, {num_features} features")
print(f"Features: {feature_names}")
The constrained SE kernel assumes $p(x_d) = \mathcal{N}(0, 1)$. Simple
z-scoring removes the first two moments but cannot correct skewness or
heavy tails. We therefore fit a lightweight per-feature normalising flow
(Shift → Log → Standardise → SinhArcsinh) that maps each marginal to an
approximately standard normal distribution.
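The flow-fitting helper fit_all_normalising_flows used below is not shown in this section. Purely as a hypothetical stand-in (not the Shift → Log → Standardise → SinhArcsinh flow described above), one could Gaussianise each marginal with an empirical-CDF/probit transform that exposes the same forward/inverse interface; this sketch reuses jnp from the setup cell.

from jax.scipy.special import ndtr, ndtri  # standard normal CDF and its inverse

class EmpiricalGaussianiser:
    """Hypothetical per-feature transform: interpolated empirical CDF, then probit."""

    def __init__(self, x_train: jnp.ndarray):
        self.sorted_x = jnp.sort(x_train)
        n = self.sorted_x.shape[0]
        self.quantiles = (jnp.arange(1, n + 1) - 0.5) / n

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Map to (0, 1) via the empirical CDF, then to an approximate N(0, 1).
        u = jnp.interp(x, self.sorted_x, self.quantiles)
        return ndtri(jnp.clip(u, 1e-4, 1.0 - 1e-4))

    def inv(self, z: jnp.ndarray) -> jnp.ndarray:
        # Approximate inverse, useful when plotting on the original feature scale.
        return jnp.interp(ndtr(z), self.quantiles, self.sorted_x)

def fit_all_gaussianisers(X: jnp.ndarray) -> list:
    # One transform per feature column, mirroring the interface used below.
    return [EmpiricalGaussianiser(X[:, d]) for d in range(X.shape[1])]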
y_mean, y_std = y_all.mean(axis=0), y_all.std(axis=0)
y_standardised = (y_all - y_mean) / y_std
num_observations = y_standardised.shape[0]

key, split_key = jr.split(key)
permutation = jr.permutation(split_key, num_observations)
num_train = int(0.8 * num_observations)
train_idx = permutation[:num_train]
test_idx = permutation[num_train:]

y_train = jnp.array(y_standardised[train_idx])
y_test = jnp.array(y_standardised[test_idx])
X_train_original = X_all[train_idx]
X_test_original = X_all[test_idx]

flows = fit_all_normalising_flows(jnp.asarray(X_train_original))


def apply_flows(X_original: np.ndarray) -> jnp.ndarray:
    """Transform each feature column through its fitted normalising flow."""
    return jnp.column_stack(
        [flows[d](jnp.asarray(X_original[:, d])) for d in range(num_features)]
    )


X_train = apply_flows(X_train_original)
X_test = apply_flows(X_test_original)

train_data = gpx.Dataset(X=X_train, y=y_train)
test_data = gpx.Dataset(X=X_test, y=y_test)
Fitting an OAK GP
We create D independent RBF base kernels, one per input dimension, each
operating on a single dimension via active_dims=[i]. These are wrapped
inside OrthogonalAdditiveKernel with max_order=D (i.e. we allow all
interaction orders). The kernel is then used in a standard conjugate GP
workflow: define a prior and Gaussian likelihood, form the posterior, and
optimise hyperparameters by maximising the marginal log-likelihood.
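The cell producing the output below follows this recipe. Here is a minimal sketch assuming GPJax's standard conjugate workflow (gpx.gps.Prior, gpx.likelihoods.Gaussian, gpx.fit_scipy); the kernels= argument name of OrthogonalAdditiveKernel is an assumption, whereas max_order is described above.

# Sketch only: constructor argument names of OrthogonalAdditiveKernel other
# than max_order are assumptions.
base_kernels = [gpx.kernels.RBF(active_dims=[d]) for d in range(num_features)]
oak_kernel = OrthogonalAdditiveKernel(kernels=base_kernels, max_order=num_features)

prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=oak_kernel)
likelihood = gpx.likelihoods.Gaussian(num_datapoints=train_data.n)
posterior = prior * likelihood

# Maximise the conjugate marginal log-likelihood (minimise its negative).
opt_posterior, history = gpx.fit_scipy(
    model=posterior,
    objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
    train_data=train_data,
)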
Optimization terminated successfully.
Current function value: 107.174483
Iterations: 56
Function evaluations: 60
Gradient evaluations: 60
Sobol indices
We now compute the analytic Sobol indices for each interaction order.
These indicate what fraction of the posterior variance is explained by
first-order (main) effects, second-order interactions, and so on.
noise_variance = float(jnp.square(opt_posterior.likelihood.obs_stddev[...]))
fitted_kernel = opt_posterior.prior.kernel

sobol_values = sobol_indices(fitted_kernel, X_train, y_train, noise_variance)

fig, ax = plt.subplots(figsize=(7, 3))
orders = jnp.arange(1, len(sobol_values) + 1)
ax.bar(orders, sobol_values, color=colours[1])
ax.set_xlabel("Interaction order")
ax.set_ylabel("Sobol index")
ax.set_title("Sobol indices by interaction order")
ax.set_xticks(np.arange(1, len(sobol_values) + 1))
Typically the first-order (main) effects dominate, with higher-order
interactions contributing progressively less. This validates the additive
modelling assumption for this dataset.
Decomposed additive components
One of the key advantages of the OAK model is the ability to visualise
each feature's individual contribution to the prediction. We extract the
top 3 first-order main effects and plot the posterior mean and a
$\pm 2\sigma$ credible band for each, alongside a histogram of the
training inputs.
For each feature $d$, we evaluate the constrained kernel
$\tilde{k}_d(x_\star, X_{\mathrm{train}, d})$ between a 1-D grid and the
training points, then form the conditional mean and variance in the usual
GP way.
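As a sketch of that conditioning step (the library routine predict_first_order may additionally scale by the learned first-order variance $\sigma_1^2$), with $K$ the full OAK Gram matrix on the training inputs:

$$\mu_d(x_\star) = \tilde{k}_d(x_\star, X_{\mathrm{train}, d})\,\big(K + \sigma_n^2 I\big)^{-1} \mathbf{y}, \qquad
v_d(x_\star) = \tilde{k}_d(x_\star, x_\star) - \tilde{k}_d(x_\star, X_{\mathrm{train}, d})\,\big(K + \sigma_n^2 I\big)^{-1} \tilde{k}_d(X_{\mathrm{train}, d}, x_\star).$$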
num_top_features = 3
num_grid_points = 300

feature_scores = rank_first_order(fitted_kernel, X_train, y_train, noise_variance)
top_feature_indices = jnp.argsort(-feature_scores)[:num_top_features]

fig, axes = plt.subplots(nrows=1, ncols=num_top_features, figsize=(12, 3))
for plot_idx, ax in enumerate(axes.flat):
    feature_dim = int(top_feature_indices[plot_idx])
    feature_name = feature_names[feature_dim]

    grid_low = float(X_train[:, feature_dim].min())
    grid_high = float(X_train[:, feature_dim].max())
    grid = jnp.linspace(grid_low, grid_high, num_grid_points)

    effect_mean, effect_variance = predict_first_order(
        fitted_kernel, X_train, y_train, noise_variance, feature_dim, grid
    )
    effect_std = jnp.sqrt(effect_variance)

    grid_original_scale = flows[feature_dim].inv(grid)

    ax.plot(
        grid_original_scale,
        effect_mean,
        color=colours[1],
        linewidth=2,
        label="Posterior mean",
    )
    ax.fill_between(
        grid_original_scale,
        effect_mean - 2 * effect_std,
        effect_mean + 2 * effect_std,
        alpha=0.2,
        color=colours[1],
        label=r"$\pm 2\sigma$",
    )

    histogram_ax = ax.twinx()
    histogram_ax.hist(
        X_train_original[:, feature_dim],
        bins=20,
        alpha=0.15,
        color=colours[0],
        density=True,
    )
    histogram_ax.set_yticks([])

    ax.set_xlabel(feature_name)
    ax.set_ylabel("Effect")
    ax.set_title(f"{feature_name} (dim {feature_dim})")
    ax.legend(loc="best", fontsize=8)

fig.suptitle(f"Top {num_top_features} first-order main effects", fontsize=14, y=1.05)
Each panel shows how the OAK model attributes predictive variation to
individual features. Features with large, clearly non-zero effects are
those that the model identifies as important for predicting fuel
consumption. The uncertainty bands widen in regions where training data
are sparse, reflecting the GP's epistemic uncertainty.