Fit
fit
fit(
*,
model: Model,
objective: Objective,
train_data: Dataset,
optim: GradientTransformation,
params_bijection: dict[Parameter, Transform]
| None = DEFAULT_BIJECTION,
trainable: Filter = Parameter,
key: KeyArray = jr.key(42),
num_iters: int = 100,
batch_size: int = -1,
log_rate: int = 10,
verbose: bool = True,
unroll: int = 1,
safe: bool = True,
) -> tuple[Model, jax.Array]
Train a Module model with respect to a supplied objective function. Optimisers used here should originate from Optax.
Example:
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> import optax as ox
>>> import flax.nnx as nnx
>>> import gpjax as gpx
>>> from gpjax.parameters import PositiveReal
>>>
>>> # (1) Create a dataset:
>>> X = jnp.linspace(0.0, 10.0, 100)[:, None]
>>> y = 2.0 * X + 1.0 + 10 * jr.normal(jr.key(0), X.shape)
>>> D = gpx.Dataset(X, y)
>>>
>>> # (2) Define your model:
>>> class LinearModel(nnx.Module):
...     def __init__(self, weight: float, bias: float):
...         self.weight = PositiveReal(weight)
...         self.bias = bias
...
...     def __call__(self, x):
...         return self.weight[...] * x + self.bias
>>>
>>> model = LinearModel(weight=1.0, bias=1.0)
>>>
>>> # (3) Define your loss function:
>>> def mse(model, data):
...     pred = model(data.X)
...     return jnp.mean((pred - data.y) ** 2)
>>>
>>> # (4) Train!
>>> trained_model, history = gpx.fit(
...     model=model, objective=mse, train_data=D, optim=ox.sgd(0.001), num_iters=1000
... )
Parameters:
- model (Model) – The model Module to be optimised.
- objective (Objective) – The objective function that we are optimising with respect to.
- train_data (Dataset) – The training data to be used for the optimisation.
- optim (GradientTransformation) – The Optax optimiser to be used for learning a parameter set.
- params_bijection (dict[Parameter, Transform] | None, default: DEFAULT_BIJECTION) – Bijection used to transform parameters to unconstrained space before optimisation.
- trainable (Filter, default: Parameter) – Filter to determine which parameters are trainable. Defaults to all Parameter instances.
- key (KeyArray, default: jr.key(42)) – The random key to use for mini-batch selection.
- num_iters (int, default: 100) – The number of optimisation steps to run.
- batch_size (int, default: -1) – The size of the mini-batch to use. The default of -1 means full batch.
- log_rate (int, default: 10) – How frequently the objective function's value should be printed.
- verbose (bool, default: True) – Whether to print the training progress bar.
- unroll (int, default: 1) – The number of unrolled steps to use for the optimisation.
- safe (bool, default: True) – Whether to validate inputs before optimisation.
Returns:
- tuple[Model, Array] – A tuple comprising the optimised model and the training history.
fit_scipy
fit_scipy(
*,
model: Model,
objective: Objective,
train_data: Dataset,
trainable: Filter = Parameter,
max_iters: int = 500,
verbose: bool = True,
safe: bool = True,
) -> tuple[Model, Array]
Train a Module model with respect to a supplied Objective function using SciPy's L-BFGS-B optimiser.
Parameters are transformed to unconstrained space, flattened into a
single vector, and passed to scipy.optimize.minimize. Gradients
are computed via JAX's value_and_grad.
Parameters
model : Module
The model to be optimised.
objective : Objective
The objective function to minimise with respect to the model
parameters.
train_data : Dataset
The training data used to evaluate the objective.
trainable : nnx.filterlib.Filter
Filter selecting which parameters to optimise. Defaults to all
Parameter instances.
max_iters : int
Maximum number of L-BFGS-B iterations. Defaults to 500.
verbose : bool
Whether to print optimisation progress. Defaults to True.
safe : bool
Whether to validate inputs before optimisation. Defaults to True.
Returns
tuple[Module, Array]
    A tuple of the optimised model and an array of objective
    values recorded at each iteration.
Example
import gpjax as gpx
import jax.numpy as jnp

xtrain = jnp.linspace(0, 1).reshape(-1, 1)
ytrain = jnp.sin(xtrain)
D = gpx.Dataset(X=xtrain, y=ytrain)

meanf = gpx.mean_functions.Constant()
kernel = gpx.kernels.RBF()
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
posterior = prior * likelihood

nmll = lambda p, d: -gpx.objectives.conjugate_mll(p, d)
trained_model, history = gpx.fit_scipy(
    model=posterior, objective=nmll, train_data=D
)
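The flatten-and-minimise pattern described above (parameters ravelled into one vector, gradients supplied via `value_and_grad`) can be sketched in isolation. The toy quadratic objective below is a stand-in for a GPJax negative marginal log-likelihood, and the glue code is an illustration of the pattern rather than the library's actual internals.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.flatten_util import ravel_pytree
from scipy.optimize import minimize

# Toy parameter pytree; the objective below has its minimum at a=0, b=1.
params = {"a": jnp.array(3.0), "b": jnp.array(-2.0)}

def objective(p):
    return p["a"] ** 2 + (p["b"] - 1.0) ** 2

# Flatten the pytree into a single vector, as fit_scipy is
# described as doing before handing over to SciPy.
flat, unravel = ravel_pytree(params)
val_and_grad = jax.jit(jax.value_and_grad(lambda v: objective(unravel(v))))

def scipy_fun(v):
    # SciPy expects float64 NumPy values; JAX returns device arrays.
    val, grad = val_and_grad(jnp.asarray(v))
    return float(val), np.asarray(grad, dtype=np.float64)

res = minimize(
    scipy_fun,
    np.asarray(flat, dtype=np.float64),
    jac=True,
    method="L-BFGS-B",
    options={"maxiter": 500},
)
opt = unravel(jnp.asarray(res.x))  # back to the original pytree structure
```

Setting `jac=True` tells SciPy that `scipy_fun` returns both the value and the gradient, so each L-BFGS-B iteration costs one JAX evaluation.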
fit_lbfgs
fit_lbfgs(
*,
model: Model,
objective: Objective,
train_data: Dataset,
params_bijection: dict[Parameter, Transform]
| None = DEFAULT_BIJECTION,
trainable: Filter = Parameter,
max_iters: int = 100,
safe: bool = True,
max_linesearch_steps: int = 32,
gtol: float = 1e-05,
) -> tuple[Model, jax.Array]
Train a Module model with respect to a supplied Objective function.
Uses Optax's L-BFGS implementation with a jax.lax.while_loop.
Parameters
model : Module
The model to be optimised.
objective : Objective
The objective function to minimise.
train_data : Dataset
The training data used to evaluate the objective.
params_bijection : dict[Parameter, Transform] | None
Bijection used to transform parameters to unconstrained space.
Defaults to DEFAULT_BIJECTION.
trainable : nnx.filterlib.Filter
Filter selecting which parameters to optimise. Defaults to all
Parameter instances.
max_iters : int
Maximum number of L-BFGS iterations. Defaults to 100.
safe : bool
Whether to validate inputs before optimisation. Defaults to True.
max_linesearch_steps : int
Maximum number of line-search steps per iteration. Defaults to 32.
gtol : float
Terminate if the L2 norm of the gradient falls below this
threshold. Defaults to 1e-5.
Returns
tuple[Module, Array]
    A tuple of the optimised model and the final loss value.
Example
import gpjax as gpx
import jax.numpy as jnp

xtrain = jnp.linspace(0, 1).reshape(-1, 1)
ytrain = jnp.sin(xtrain)
D = gpx.Dataset(X=xtrain, y=ytrain)

meanf = gpx.mean_functions.Constant()
kernel = gpx.kernels.RBF()
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
posterior = prior * likelihood

nmll = lambda p, d: -gpx.objectives.conjugate_mll(p, d)
trained_model, final_loss = gpx.fit_lbfgs(
    model=posterior, objective=nmll, train_data=D
)
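The termination behaviour described above (a `jax.lax.while_loop` that exits once the gradient norm drops below `gtol` or the iteration budget `max_iters` is exhausted) can be sketched with plain gradient descent standing in for the Optax L-BFGS update; the loss and step rule here are illustrative only.

```python
import jax
import jax.numpy as jnp

def loss(p):
    # Toy convex objective with its minimum at p = 2.
    return jnp.sum((p - 2.0) ** 2)

grad_fn = jax.grad(loss)
gtol, max_iters, lr = 1e-5, 100, 0.1

def cond(carry):
    # Continue while under the iteration budget and the gradient
    # norm is still above the tolerance.
    p, i = carry
    return (i < max_iters) & (jnp.linalg.norm(grad_fn(p)) > gtol)

def body(carry):
    # Plain gradient step standing in for the L-BFGS update.
    p, i = carry
    return p - lr * grad_fn(p), i + 1

p_final, iters = jax.lax.while_loop(cond, body, (jnp.zeros(3), 0))
```

Because both exit conditions live in `cond`, the loop stops at whichever of `gtol` or `max_iters` is hit first, which mirrors how the two stopping parameters of `fit_lbfgs` interact.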
get_batch
Batch the data into mini-batches. Sampling is done with replacement.
Parameters:
- train_data (Dataset) – The training dataset.
- batch_size (int) – The batch size.
- key (KeyArray) – The random key to use for the batch selection.
Example
import gpjax as gpx
import jax.numpy as jnp
import jax.random as jr
from gpjax.fit import get_batch

X = jnp.linspace(0, 1, 100).reshape(-1, 1)
y = jnp.sin(X)
D = gpx.Dataset(X=X, y=y)

batch = get_batch(D, batch_size=16, key=jr.key(0))
Returns:
- Dataset – The batched dataset.
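The with-replacement sampling that get_batch is documented to perform can be sketched directly with `jax.random.choice`; this is a plain-JAX stand-in operating on raw arrays, not the gpjax internals.

```python
import jax.numpy as jnp
import jax.random as jr

X = jnp.linspace(0, 1, 100).reshape(-1, 1)
y = jnp.sin(X)

# Draw 16 row indices uniformly with replacement, then gather
# the corresponding rows of X and y.
idx = jr.choice(jr.key(0), X.shape[0], shape=(16,), replace=True)
Xb, yb = X[idx], y[idx]
```

Because sampling is with replacement, the same row can appear more than once in a batch, and batches drawn with different keys are independent of one another.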