Fit
fit
fit(
*,
model: Model,
objective: Objective,
train_data: Dataset,
optim: GradientTransformation,
key: KeyArray = jr.key(42),
num_iters: int = 100,
batch_size: int = -1,
log_rate: int = 10,
verbose: bool = True,
unroll: int = 1,
safe: bool = True,
) -> tuple[Model, jax.Array]
Train a Module model with respect to a supplied objective function. Optimisers used here should originate from Optax.
Example:
>>> import jax
>>> jax.config.update("jax_enable_x64", True)
>>> import jax.numpy as jnp
>>> import optax as ox
>>> import gpjax as gpx
>>>
>>> xtrain = jnp.linspace(0, 1, 50).reshape(-1, 1)
>>> ytrain = jnp.sin(xtrain)
>>> D = gpx.Dataset(X=xtrain, y=ytrain)
>>>
>>> meanf = gpx.mean_functions.Constant()
>>> kernel = gpx.kernels.RBF()
>>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
>>> prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
>>> posterior = prior * likelihood
>>>
>>> nmll = lambda p, d: -gpx.objectives.conjugate_mll(p, d)
>>> trained_model, history = gpx.fit(
... model=posterior, objective=nmll, train_data=D,
... optim=ox.adam(0.01), num_iters=100, verbose=False,
... )
Parameters:
- model (Model) – The model Module to be optimised.
- objective (Objective) – The objective function that we are optimising with respect to.
- train_data (Dataset) – The training data to be used for the optimisation.
- optim (GradientTransformation) – The Optax optimiser to be used for learning a parameter set.
- key (KeyArray, default: jr.key(42)) – The random key to use for mini-batch selection. Defaults to jr.key(42).
- num_iters (int, default: 100) – The number of optimisation steps to run. Defaults to 100.
- batch_size (int, default: -1) – The size of the mini-batch to use. Defaults to -1 (i.e. full batch).
- log_rate (int, default: 10) – How frequently the objective function's value should be printed. Defaults to 10.
- verbose (bool, default: True) – Whether to print the training progress bar. Defaults to True.
- unroll (int, default: 1) – The number of unrolled steps to use for the optimisation. Defaults to 1.
- safe (bool, default: True) – Whether to validate inputs before optimisation. Defaults to True.
Returns:
- tuple[Model, Array] – A tuple comprising the optimised model and the training history.
fit_scipy
fit_scipy(
*,
model: Model,
objective: Objective,
train_data: Dataset,
max_iters: int = 500,
verbose: bool = True,
safe: bool = True,
) -> tuple[Model, Array]
Train a Module model with respect to a supplied Objective function using SciPy's L-BFGS-B optimiser.
Parameters are transformed to unconstrained space, flattened into a
single vector, and passed to scipy.optimize.minimize. Gradients
are computed via JAX's value_and_grad.
Parameters
- model (Module) – The model to be optimised.
- objective (Objective) – The objective function to minimise with respect to the model parameters.
- train_data (Dataset) – The training data used to evaluate the objective.
- max_iters (int, default: 500) – Maximum number of L-BFGS-B iterations. Defaults to 500.
- verbose (bool, default: True) – Whether to print optimisation progress. Defaults to True.
- safe (bool, default: True) – Whether to validate inputs before optimisation. Defaults to True.
Returns
- tuple[Module, Array] – A tuple of the optimised model and an array of objective values recorded at each iteration.
Example
>>> import jax
>>> jax.config.update("jax_enable_x64", True)
>>> import gpjax as gpx
>>> import jax.numpy as jnp
>>>
>>> xtrain = jnp.linspace(0, 1).reshape(-1, 1)
>>> ytrain = jnp.sin(xtrain)
>>> D = gpx.Dataset(X=xtrain, y=ytrain)
>>>
>>> meanf = gpx.mean_functions.Constant()
>>> kernel = gpx.kernels.RBF()
>>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
>>> prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
>>> posterior = prior * likelihood
>>>
>>> nmll = lambda p, d: -gpx.objectives.conjugate_mll(p, d)
>>> trained_model, history = gpx.fit_scipy(
...     model=posterior, objective=nmll, train_data=D,
... )
fit_lbfgs
fit_lbfgs(
*,
model: Model,
objective: Objective,
train_data: Dataset,
max_iters: int = 100,
safe: bool = True,
max_linesearch_steps: int = 32,
gtol: float = 1e-05,
) -> tuple[Model, jax.Array]
Train a Module model with respect to a supplied Objective function.
Uses Optax's L-BFGS implementation with a jax.lax.while_loop.
Parameters
- model (Module) – The model to be optimised.
- objective (Objective) – The objective function to minimise.
- train_data (Dataset) – The training data used to evaluate the objective.
- max_iters (int, default: 100) – Maximum number of L-BFGS iterations. Defaults to 100.
- safe (bool, default: True) – Whether to validate inputs before optimisation. Defaults to True.
- max_linesearch_steps (int, default: 32) – Maximum number of line-search steps per iteration. Defaults to 32.
- gtol (float, default: 1e-05) – Terminate if the L2 norm of the gradient falls below this threshold. Defaults to 1e-5.
Returns
- tuple[Module, Array] – A tuple of the optimised model and the final loss value.
Example
>>> import jax
>>> jax.config.update("jax_enable_x64", True)
>>> import gpjax as gpx
>>> import jax.numpy as jnp
>>>
>>> xtrain = jnp.linspace(0, 1, 20).reshape(-1, 1)
>>> ytrain = jnp.sin(xtrain)
>>> D = gpx.Dataset(X=xtrain, y=ytrain)
>>>
>>> meanf = gpx.mean_functions.Constant()
>>> kernel = gpx.kernels.RBF()
>>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
>>> prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
>>> posterior = prior * likelihood
>>>
>>> nmll = lambda p, d: -gpx.objectives.conjugate_mll(p, d)
>>> trained_model, final_loss = gpx.fit_lbfgs(
...     model=posterior, objective=nmll, train_data=D,
... )
get_batch
Batch the data into mini-batches. Sampling is done with replacement.
Parameters:
- train_data (Dataset) – The training dataset.
- batch_size (int) – The batch size.
- key (KeyArray) – The random key to use for the batch selection.
Example
>>> import gpjax as gpx
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>>
>>> X = jnp.linspace(0, 1, 100).reshape(-1, 1)
>>> y = jnp.sin(X)
>>> D = gpx.Dataset(X=X, y=y)
>>>
>>> from gpjax.fit import get_batch
>>> batch = get_batch(D, batch_size=16, key=jr.key(0))
Returns
- Dataset – The batched dataset.
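The with-replacement sampling that get_batch performs can be reproduced with jax.random.choice alone. A minimal, gpjax-free sketch that keeps inputs and targets aligned by indexing both with the same draw:

```python
import jax.numpy as jnp
import jax.random as jr

X = jnp.linspace(0, 1, 100).reshape(-1, 1)
y = jnp.sin(X)

# Draw 16 row indices uniformly with replacement (duplicates may occur),
# then index inputs and targets with the same indices.
idx = jr.choice(jr.key(0), X.shape[0], (16,), replace=True)
Xb, yb = X[idx], y[idx]
```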