Fit

Fit wrappers (fit_scipy, fit_lbfgs, fit) for state-space GPs.

These thin wrappers validate inputs, optionally sort time-series data, and delegate to the existing gpjax.fit.fit_* optimisers with a state_space_mll-derived negative-log-likelihood objective.
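The validate-and-sort step described above might look roughly like the following. This is a minimal sketch, not the library's implementation: the helper name `sort_by_time`, its checks, and its error messages are hypothetical, and NumPy stands in for the JAX arrays the wrappers actually receive.

```python
import numpy as np


def sort_by_time(X, y, observation_mask=None):
    """Hypothetical helper: validate shapes and order a series by time."""
    X = np.asarray(X)
    y = np.asarray(y)
    # Assumption: inputs are a single time column of shape (N, 1).
    if X.ndim != 2 or X.shape[1] != 1:
        raise ValueError("expected X with shape (N, 1)")
    if y.shape[0] != X.shape[0]:
        raise ValueError("X and y must have the same number of rows")
    # A stable sort keeps the original order of tied time stamps.
    order = np.argsort(X[:, 0], kind="stable")
    mask = None
    if observation_mask is not None:
        # The mask must be permuted with the data it describes.
        mask = np.asarray(observation_mask)[order]
    return X[order], y[order], mask


X = np.array([[2.0], [0.0], [1.0]])
y = np.array([[20.0], [0.0], [10.0]])
X_sorted, y_sorted, _ = sort_by_time(X, y)
```

Sorting matters because a state-space formulation processes observations sequentially in time, so the targets and any observation mask have to be reordered together with the inputs.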

fit_scipy

fit_scipy(
    *,
    model,
    train_data: Dataset,
    observation_mask=None,
    max_iters: int = 500,
    verbose: bool = True,
    safe: bool = True,
)

Fit a state-space posterior with SciPy's L-BFGS-B.

Thin wrapper around gpx.fit_scipy. Validates data, sorts if necessary, and uses state_space_mll as the objective.

Example:

    >>> import jax.numpy as jnp
    >>> import gpjax as gpx
    >>> from gpjax.state_space import StateSpacePrior, fit_scipy
    >>> X = jnp.linspace(0.0, 5.0, 20).reshape(-1, 1)
    >>> y = jnp.sin(X)
    >>> prior = StateSpacePrior(
    ...     mean_function=gpx.mean_functions.Zero(),
    ...     kernel=gpx.kernels.Matern32(lengthscale=1.0, variance=1.0),
    ... )
    >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=20, obs_stddev=0.1)
    >>> posterior = prior * likelihood
    >>> fitted, history = fit_scipy(
    ...     model=posterior,
    ...     train_data=gpx.Dataset(X=X, y=y),
    ...     max_iters=2,
    ...     verbose=False,
    ... )
    >>> bool(jnp.all(jnp.isfinite(history)))
    True

fit_lbfgs

fit_lbfgs(
    *,
    model,
    train_data: Dataset,
    observation_mask=None,
    max_iters: int = 500,
    safe: bool = True,
)

Fit a state-space posterior with Optax's L-BFGS (while_loop driver).

Thin wrapper around gpx.fit_lbfgs.

Example:

    >>> import jax.numpy as jnp
    >>> import gpjax as gpx
    >>> from gpjax.state_space import StateSpacePrior, fit_lbfgs
    >>> X = jnp.linspace(0.0, 5.0, 20).reshape(-1, 1)
    >>> y = jnp.sin(X)
    >>> prior = StateSpacePrior(
    ...     mean_function=gpx.mean_functions.Zero(),
    ...     kernel=gpx.kernels.Matern32(lengthscale=1.0, variance=1.0),
    ... )
    >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=20, obs_stddev=0.1)
    >>> posterior = prior * likelihood
    >>> fitted, history = fit_lbfgs(
    ...     model=posterior,
    ...     train_data=gpx.Dataset(X=X, y=y),
    ...     max_iters=2,
    ... )
    >>> fitted is not None
    True

fit

fit(
    *,
    model,
    train_data: Dataset,
    optim,
    observation_mask=None,
    key=None,
    num_iters: int = 100,
    batch_size: int = -1,
    log_rate: int = 10,
    verbose: bool = True,
    unroll: int = 1,
    safe: bool = True,
)

Fit a state-space posterior with Optax (gradient-descent style).

Thin wrapper around gpx.fit. Rejects any batch_size other than -1 because the state-space marginal log-likelihood (MLL) is intrinsically full-batch: the temporal scan cannot be minibatched without breaking the Markov chain.
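The full-batch restriction could be enforced with a guard along these lines. This is only an illustrative sketch: the function name `check_full_batch`, the choice of `ValueError`, and the message text are assumptions, since the source does not say how the rejection is reported.

```python
def check_full_batch(batch_size: int) -> None:
    """Hypothetical guard: accept only the full-batch sentinel -1."""
    if batch_size != -1:
        # Assumption: the real wrapper may raise a different exception type.
        raise ValueError(
            "state-space MLL is full-batch; "
            f"expected batch_size=-1, got {batch_size}"
        )


check_full_batch(-1)  # the default passes silently

try:
    check_full_batch(32)  # a minibatch size is rejected
except ValueError as err:
    message = str(err)
```

Failing early, before any optimiser step runs, avoids silently evaluating the likelihood on a subsample whose time steps are no longer consecutive.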

Example:

    >>> import jax.numpy as jnp
    >>> import jax.random as jr
    >>> import optax as ox
    >>> import gpjax as gpx
    >>> from gpjax.state_space import StateSpacePrior, fit
    >>> X = jnp.linspace(0.0, 5.0, 20).reshape(-1, 1)
    >>> y = jnp.sin(X)
    >>> prior = StateSpacePrior(
    ...     mean_function=gpx.mean_functions.Zero(),
    ...     kernel=gpx.kernels.Matern32(lengthscale=1.0, variance=1.0),
    ... )
    >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=20, obs_stddev=0.1)
    >>> posterior = prior * likelihood
    >>> fitted, history = fit(
    ...     model=posterior,
    ...     train_data=gpx.Dataset(X=X, y=y),
    ...     optim=ox.adam(1e-2),
    ...     num_iters=2,
    ...     key=jr.key(0),
    ...     verbose=False,
    ... )
    >>> bool(jnp.all(jnp.isfinite(history)))
    True