Fit

fit

fit(
    *,
    model: Model,
    objective: Objective,
    train_data: Dataset,
    optim: GradientTransformation,
    params_bijection: dict[Parameter, Transform]
    | None = DEFAULT_BIJECTION,
    trainable: Filter = Parameter,
    key: KeyArray = jr.key(42),
    num_iters: int = 100,
    batch_size: int = -1,
    log_rate: int = 10,
    verbose: bool = True,
    unroll: int = 1,
    safe: bool = True,
) -> tuple[Model, jax.Array]

Train a Module model with respect to a supplied objective function. Optimisers used here should originate from Optax.

Example:

    >>> import jax.numpy as jnp
    >>> import jax.random as jr
    >>> import optax as ox
    >>> import gpjax as gpx
    >>> from flax import nnx
    >>> from gpjax.parameters import PositiveReal
    >>>
    >>> # (1) Create a dataset:
    >>> X = jnp.linspace(0.0, 10.0, 100)[:, None]
    >>> y = 2.0 * X + 1.0 + 10 * jr.normal(jr.key(0), X.shape)
    >>> D = gpx.Dataset(X, y)
    >>> # (2) Define your model:
    >>> class LinearModel(nnx.Module):
    >>>     def __init__(self, weight: float, bias: float):
    >>>         self.weight = PositiveReal(weight)
    >>>         self.bias = bias
    >>>
    >>>     def __call__(self, x):
    >>>         return self.weight[...] * x + self.bias
    >>>
    >>> model = LinearModel(weight=1.0, bias=1.0)
    >>>
    >>> # (3) Define your loss function:
    >>> def mse(model, data):
    >>>     pred = model(data.X)
    >>>     return jnp.mean((pred - data.y) ** 2)
    >>>
    >>> # (4) Train!
    >>> trained_model, history = gpx.fit(
    >>>     model=model, objective=mse, train_data=D, optim=ox.sgd(0.001), num_iters=1000
    >>> )
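
For mini-batch training, a random batch of batch_size points is drawn from train_data at each step (see get_batch below). A hedged sketch reusing the model, loss, and dataset from the example above; the optimiser and settings here are illustrative, not recommendations:

    >>> trained_model, history = gpx.fit(
    >>>     model=model,
    >>>     objective=mse,
    >>>     train_data=D,
    >>>     optim=ox.adam(0.01),
    >>>     num_iters=1000,
    >>>     batch_size=32,    # 32 points sampled with replacement per step
    >>>     key=jr.key(123),  # controls the batch selection
    >>>     log_rate=100,     # report the objective every 100 iterations
    >>> )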

Parameters:

  • model (Model) –

    The model Module to be optimised.

  • objective (Objective) –

    The objective function that we are optimising with respect to.

  • train_data (Dataset) –

    The training data to be used for the optimisation.

  • optim (GradientTransformation) –

    The Optax optimiser that is to be used for learning a parameter set.

  • params_bijection (dict[Parameter, Transform] | None, default: DEFAULT_BIJECTION ) –

    Bijection used to transform parameters to unconstrained space during optimisation. Defaults to DEFAULT_BIJECTION.

  • trainable (Filter, default: Parameter ) –

    Filter to determine which parameters are trainable. Defaults to Parameter (all Parameter instances).

  • num_iters (int, default: 100 ) –

    The number of optimisation steps to run. Defaults to 100.

  • batch_size (int, default: -1 ) –

    The size of the mini-batch to use. Defaults to -1 (i.e. full batch).

  • key (KeyArray, default: key(42) ) –

    The random key to use for the optimisation batch selection. Defaults to jr.key(42).

  • log_rate (int, default: 10 ) –

    How frequently the objective function's value should be printed. Defaults to 10.

  • verbose (bool, default: True ) –

    Whether to print the training loading bar. Defaults to True.

  • unroll (int, default: 1 ) –

    The number of unrolled steps to use for the optimisation. Defaults to 1.

  • safe (bool, default: True ) –

    Whether to validate inputs before optimisation. Defaults to True.

Returns:

  • tuple[Model, Array] –

    A tuple comprising the optimised model and training history.

fit_scipy

fit_scipy(
    *,
    model: Model,
    objective: Objective,
    train_data: Dataset,
    trainable: Filter = Parameter,
    max_iters: int = 500,
    verbose: bool = True,
    safe: bool = True,
) -> tuple[Model, Array]

Train a Module model with respect to a supplied Objective function using SciPy's L-BFGS-B optimiser.

Parameters are transformed to unconstrained space, flattened into a single vector, and passed to scipy.optimize.minimize. Gradients are computed via JAX's value_and_grad.
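
The flatten-and-minimise pattern can be pictured as follows. This is a minimal, hedged sketch on a toy parameter pytree and least-squares loss, not GPJax's internals; the bijection step is omitted for brevity:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.flatten_util import ravel_pytree
    from scipy.optimize import minimize

    # Toy data and a least-squares objective over a parameter pytree.
    X = jnp.linspace(0.0, 1.0, 50)[:, None]
    y = 3.0 * X + 0.5
    params = {"weight": jnp.array(1.0), "bias": jnp.array(0.0)}

    def loss(p):
        return jnp.mean((p["weight"] * X + p["bias"] - y) ** 2)

    flat0, unravel = ravel_pytree(params)  # pytree -> one flat vector
    value_and_grad = jax.jit(jax.value_and_grad(lambda v: loss(unravel(v))))

    def scipy_objective(v):
        val, grad = value_and_grad(jnp.asarray(v))
        # SciPy expects plain float / float64 NumPy outputs.
        return float(val), np.asarray(grad, dtype=np.float64)

    result = minimize(scipy_objective, np.asarray(flat0), jac=True, method="L-BFGS-B")
    trained_params = unravel(jnp.asarray(result.x))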

Parameters:

  • model (Model) –

    The model to be optimised.

  • objective (Objective) –

    The objective function to minimise with respect to the model parameters.

  • train_data (Dataset) –

    The training data used to evaluate the objective.

  • trainable (Filter, default: Parameter ) –

    Filter selecting which parameters to optimise. Defaults to all Parameter instances.

  • max_iters (int, default: 500 ) –

    Maximum number of L-BFGS-B iterations. Defaults to 500.

  • verbose (bool, default: True ) –

    Whether to print optimisation progress. Defaults to True.

  • safe (bool, default: True ) –

    Whether to validate inputs before optimisation. Defaults to True.

Returns:

  • tuple[Model, Array] –

    A tuple of the optimised model and an array of objective values recorded at each iteration.

Example:

    >>> import gpjax as gpx
    >>> import jax.numpy as jnp
    >>>
    >>> xtrain = jnp.linspace(0, 1).reshape(-1, 1)
    >>> ytrain = jnp.sin(xtrain)
    >>> D = gpx.Dataset(X=xtrain, y=ytrain)
    >>>
    >>> meanf = gpx.mean_functions.Constant()
    >>> kernel = gpx.kernels.RBF()
    >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
    >>> prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
    >>> posterior = prior * likelihood
    >>>
    >>> nmll = lambda p, d: -gpx.objectives.conjugate_mll(p, d)
    >>> trained_model, history = gpx.fit_scipy(
    >>>     model=posterior, objective=nmll, train_data=D
    >>> )

fit_lbfgs

fit_lbfgs(
    *,
    model: Model,
    objective: Objective,
    train_data: Dataset,
    params_bijection: dict[Parameter, Transform]
    | None = DEFAULT_BIJECTION,
    trainable: Filter = Parameter,
    max_iters: int = 100,
    safe: bool = True,
    max_linesearch_steps: int = 32,
    gtol: float = 1e-05,
) -> tuple[Model, jax.Array]

Train a Module model with respect to a supplied Objective function.

Uses Optax's L-BFGS implementation with a jax.lax.while_loop.
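
The while-loop pattern with a gradient-norm stopping rule (the gtol parameter below) can be sketched on a toy quadratic. This is a hedged illustration of the Optax L-BFGS idiom, not GPJax's implementation:

    import jax
    import jax.numpy as jnp
    import optax
    from optax import tree_utils as otu

    def loss(x):
        return jnp.sum((x - 2.0) ** 2)

    opt = optax.lbfgs()
    value_and_grad = optax.value_and_grad_from_state(loss)

    def step(carry):
        params, state = carry
        value, grad = value_and_grad(params, state=state)
        updates, state = opt.update(
            grad, state, params, value=value, grad=grad, value_fn=loss
        )
        return optax.apply_updates(params, updates), state

    def not_converged(carry):
        _, state = carry
        count = otu.tree_get(state, "count")
        grad = otu.tree_get(state, "grad")
        # Run at least one step; stop at max_iters or once ||grad|| < gtol.
        return (count == 0) | ((count < 100) & (otu.tree_l2_norm(grad) >= 1e-5))

    params = jnp.zeros(3)
    params, _ = jax.lax.while_loop(not_converged, step, (params, opt.init(params)))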

Parameters:

  • model (Model) –

    The model to be optimised.

  • objective (Objective) –

    The objective function to minimise.

  • train_data (Dataset) –

    The training data used to evaluate the objective.

  • params_bijection (dict[Parameter, Transform] | None, default: DEFAULT_BIJECTION ) –

    Bijection used to transform parameters to unconstrained space. Defaults to DEFAULT_BIJECTION.

  • trainable (Filter, default: Parameter ) –

    Filter selecting which parameters to optimise. Defaults to all Parameter instances.

  • max_iters (int, default: 100 ) –

    Maximum number of L-BFGS iterations. Defaults to 100.

  • safe (bool, default: True ) –

    Whether to validate inputs before optimisation. Defaults to True.

  • max_linesearch_steps (int, default: 32 ) –

    Maximum number of line-search steps per iteration. Defaults to 32.

  • gtol (float, default: 1e-05 ) –

    Terminate if the L2 norm of the gradient falls below this threshold. Defaults to 1e-5.

Returns:

  • tuple[Model, Array] –

    A tuple of the optimised model and the final loss value.

Example:

    >>> import gpjax as gpx
    >>> import jax.numpy as jnp
    >>>
    >>> xtrain = jnp.linspace(0, 1).reshape(-1, 1)
    >>> ytrain = jnp.sin(xtrain)
    >>> D = gpx.Dataset(X=xtrain, y=ytrain)
    >>>
    >>> meanf = gpx.mean_functions.Constant()
    >>> kernel = gpx.kernels.RBF()
    >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
    >>> prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
    >>> posterior = prior * likelihood
    >>>
    >>> nmll = lambda p, d: -gpx.objectives.conjugate_mll(p, d)
    >>> trained_model, final_loss = gpx.fit_lbfgs(
    >>>     model=posterior, objective=nmll, train_data=D
    >>> )

get_batch

get_batch(
    train_data: Dataset, batch_size: int, key: KeyArray
) -> Dataset

Batch the data into mini-batches. Sampling is done with replacement.
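
Conceptually, sampling with replacement amounts to drawing batch_size uniform random indices. A minimal sketch of the idea, not the function's exact implementation:

    import jax.random as jr

    def sample_with_replacement(X, y, batch_size, key):
        # Uniform indices over the dataset; repeats are allowed.
        idx = jr.randint(key, (batch_size,), minval=0, maxval=X.shape[0])
        return X[idx], y[idx]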

Parameters:

  • train_data (Dataset) –

    The training dataset.

  • batch_size (int) –

    The batch size.

  • key (KeyArray) –

    The random key to use for the batch selection.

Example:

    >>> import jax.numpy as jnp
    >>> import jax.random as jr
    >>> import gpjax as gpx
    >>> from gpjax.fit import get_batch
    >>>
    >>> X = jnp.linspace(0, 1, 100).reshape(-1, 1)
    >>> y = jnp.sin(X)
    >>> D = gpx.Dataset(X=X, y=y)
    >>>
    >>> batch = get_batch(D, batch_size=16, key=jr.key(0))

Returns:

  • Dataset –

    The batched dataset.