Fit
Fit wrappers (fit_scipy, fit_lbfgs, fit) for state-space GPs.
These thin wrappers validate their inputs, optionally sort the time-series data into time order, and
delegate to the existing gpjax.fit.fit_* optimisers, using the negative of the
marginal log-likelihood computed by state_space_mll as the objective.
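The validation and sorting step can be pictured with the NumPy sketch below. It is not the library's code: `prepare_time_series` is a hypothetical helper, and the checks it performs (a single time column in X, one row of y per row of X) are assumptions. It sorts only when the times are out of order, using a stable argsort so that tied timestamps keep their original relative order.

```python
import numpy as np


def prepare_time_series(X, y):
    """Hypothetical sketch: validate (N, 1) time-series data and sort it by time."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    # Assumed shape contract: one time column, one target column, same N.
    if X.ndim != 2 or X.shape[1] != 1:
        raise ValueError(f"expected X of shape (N, 1), got {X.shape}")
    if y.shape != (X.shape[0], 1):
        raise ValueError(f"expected y of shape ({X.shape[0]}, 1), got {y.shape}")
    t = X[:, 0]
    # A temporal scan needs non-decreasing times; skip the copy if already sorted.
    if np.all(np.diff(t) >= 0):
        return X, y
    # Stable sort keeps duplicate timestamps in their original order.
    order = np.argsort(t, kind="stable")
    return X[order], y[order]
```

If an observation mask is indexed per observation, it would need the same permutation to stay aligned with y.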
fit_scipy
fit_scipy(
    *,
    model,
    train_data: Dataset,
    observation_mask=None,
    max_iters: int = 500,
    verbose: bool = True,
    safe: bool = True,
)
Fit a state-space posterior with SciPy's L-BFGS-B.
Thin wrapper around gpx.fit_scipy. Validates the data, sorts it into time order
if necessary, and minimises the negative of state_space_mll.
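To make the objective concrete, here is a NumPy sketch of the quantity a state-space MLL computes, assuming a Matérn-3/2 kernel and a Gaussian likelihood as in the examples. It is not the implementation of state_space_mll, and both function names are hypothetical. The Kalman filter sums one-step predictive log densities in O(N) time and, for time-ordered inputs, agrees with the dense O(N³) Cholesky computation.

```python
import numpy as np


def dense_log_marginal_likelihood(t, y, lengthscale, variance, obs_stddev):
    """Reference O(N^3) GP log marginal likelihood with a Matern-3/2 kernel."""
    lam = np.sqrt(3.0) / lengthscale
    r = np.abs(t[:, None] - t[None, :])
    K = variance * (1.0 + lam * r) * np.exp(-lam * r)
    K = K + obs_stddev**2 * np.eye(t.shape[0])
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    return (-0.5 * y @ alpha - np.sum(np.log(np.diag(L)))
            - 0.5 * t.shape[0] * np.log(2.0 * np.pi))


def kalman_log_marginal_likelihood(t, y, lengthscale, variance, obs_stddev):
    """Hypothetical O(N) state-space MLL; assumes t is sorted ascending."""
    lam = np.sqrt(3.0) / lengthscale
    P_inf = np.diag([variance, lam**2 * variance])  # stationary state covariance
    H = np.array([1.0, 0.0])                        # observe f, not f'
    m, P = np.zeros(2), P_inf.copy()
    log_lik, prev_t = 0.0, t[0]
    for tk, yk in zip(t, y):
        dt = tk - prev_t
        # Closed-form expm(F * dt) for the Matern-3/2 SDE.
        A = np.exp(-lam * dt) * np.array([[1.0 + lam * dt, dt],
                                          [-lam**2 * dt, 1.0 - lam * dt]])
        Q = P_inf - A @ P_inf @ A.T                 # process noise over dt
        m, P = A @ m, A @ P @ A.T + Q               # predict
        S = H @ P @ H + obs_stddev**2               # innovation variance
        v = yk - H @ m                              # innovation
        log_lik += -0.5 * (np.log(2.0 * np.pi * S) + v**2 / S)
        gain = P @ H / S
        m, P = m + gain * v, P - np.outer(gain, gain) * S  # update
        prev_t = tk
    return log_lik
```

The closed-form transition relies on (F + λI)² = 0 for the Matérn-3/2 drift matrix F, so expm(FΔ) = e^{-λΔ}(I + (F + λI)Δ).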
Example:
>>> import jax.numpy as jnp
>>> import gpjax as gpx
>>> from gpjax.state_space import StateSpacePrior, fit_scipy
>>> X = jnp.linspace(0.0, 5.0, 20).reshape(-1, 1)
>>> y = jnp.sin(X)
>>> prior = StateSpacePrior(
... mean_function=gpx.mean_functions.Zero(),
... kernel=gpx.kernels.Matern32(lengthscale=1.0, variance=1.0),
... )
>>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=20, obs_stddev=0.1)
>>> posterior = prior * likelihood
>>> fitted, history = fit_scipy(
... model=posterior,
... train_data=gpx.Dataset(X=X, y=y),
... max_iters=2,
... verbose=False,
... )
>>> bool(jnp.all(jnp.isfinite(history)))
True
fit_lbfgs
fit_lbfgs(
    *,
    model,
    train_data: Dataset,
    observation_mask=None,
    max_iters: int = 500,
    safe: bool = True,
)
Fit a state-space posterior with Optax's L-BFGS (while_loop driver).
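Thin wrapper around gpx.fit_lbfgs. Like fit_scipy, it validates the data and
sorts it into time order if necessary before minimising the negative of
state_space_mll.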
Example:
>>> import jax.numpy as jnp
>>> import gpjax as gpx
>>> from gpjax.state_space import StateSpacePrior, fit_lbfgs
>>> X = jnp.linspace(0.0, 5.0, 20).reshape(-1, 1)
>>> y = jnp.sin(X)
>>> prior = StateSpacePrior(
... mean_function=gpx.mean_functions.Zero(),
... kernel=gpx.kernels.Matern32(lengthscale=1.0, variance=1.0),
... )
>>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=20, obs_stddev=0.1)
>>> posterior = prior * likelihood
>>> fitted, history = fit_lbfgs(
... model=posterior,
... train_data=gpx.Dataset(X=X, y=y),
... max_iters=2,
... )
>>> fitted is not None
True
fit
fit(
    *,
    model,
    train_data: Dataset,
    optim,
    observation_mask=None,
    key=None,
    num_iters: int = 100,
    batch_size: int = -1,
    log_rate: int = 10,
    verbose: bool = True,
    unroll: int = 1,
    safe: bool = True,
)
Fit a state-space posterior with Optax (gradient-descent style).
Thin wrapper around gpx.fit. Rejects any batch_size other than -1, because the
state-space MLL is intrinsically full-batch: the temporal scan cannot be
minibatched without breaking the Markov chain.
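The full-batch restriction can be seen in a small NumPy sketch. Each term of the state-space log-likelihood is a one-step predictive density log p(y_k | y_1, ..., y_{k-1}), so it depends on every earlier observation through the filtered state. Filtering only a subsample conditions each term on a sparser history and gives different values from the full filter at the same time points. The sketch swaps in a Matérn-1/2 (Ornstein-Uhlenbeck) kernel because its state is scalar; `check_batch_size` and `ou_predictive_terms` are hypothetical names, and raising ValueError is an assumption about how the rejection is reported.

```python
import numpy as np


def check_batch_size(batch_size):
    # Hypothetical guard mirroring the documented rule; the error type is assumed.
    if batch_size != -1:
        raise ValueError("state-space MLL is full-batch; use batch_size=-1")


def ou_predictive_terms(t, y, lengthscale=1.0, variance=1.0, obs_stddev=0.1):
    """Per-step log p(y_k | y_1..y_{k-1}) from a scalar Kalman filter.

    Uses a Matern-1/2 (Ornstein-Uhlenbeck) kernel; assumes t is sorted.
    """
    m, P = 0.0, variance
    prev_t = t[0]
    terms = []
    for tk, yk in zip(t, y):
        a = np.exp(-(tk - prev_t) / lengthscale)        # transition over dt
        m, P = a * m, a * a * P + variance * (1.0 - a * a)  # predict
        S = P + obs_stddev**2                           # innovation variance
        v = yk - m                                      # innovation
        terms.append(-0.5 * (np.log(2.0 * np.pi * S) + v * v / S))
        gain = P / S
        m, P = m + gain * v, P - gain * P               # update
        prev_t = tk
    return np.array(terms)
```

Summing the full filter's terms gives the log marginal likelihood; a minibatch of terms taken from a subsampled filter is not a piece of that sum, since only the first term (conditioned on nothing) coincides.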
Example:
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> import optax as ox
>>> import gpjax as gpx
>>> from gpjax.state_space import StateSpacePrior, fit
>>> X = jnp.linspace(0.0, 5.0, 20).reshape(-1, 1)
>>> y = jnp.sin(X)
>>> prior = StateSpacePrior(
... mean_function=gpx.mean_functions.Zero(),
... kernel=gpx.kernels.Matern32(lengthscale=1.0, variance=1.0),
... )
>>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=20, obs_stddev=0.1)
>>> posterior = prior * likelihood
>>> fitted, history = fit(
... model=posterior,
... train_data=gpx.Dataset(X=X, y=y),
... optim=ox.adam(1e-2),
... num_iters=2,
... key=jr.key(0),
... verbose=False,
... )
>>> bool(jnp.all(jnp.isfinite(history)))
True