Train a Module model with respect to a supplied Objective function
using SciPy's L-BFGS-B optimiser.
Parameters are transformed to unconstrained space, flattened into a
single vector, and passed to scipy.optimize.minimize. Gradients
are computed via JAX's value_and_grad.
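The transform-flatten-minimise pattern can be sketched as follows. This is a minimal stand-in, not the actual implementation: the toy model is a single positive parameter kept positive via a log transform, the objective is a placeholder quadratic, and the gradient is hand-derived via the chain rule where the real code would use jax.value_and_grad.

```python
import numpy as np
from scipy.optimize import minimize

# Toy "model": one positive parameter, mapped to unconstrained space
# via a log transform (constrained theta = exp(u)).
def to_constrained(u):
    return np.exp(u)

# Placeholder objective on the constrained parameter; in the real code
# this would be the supplied Objective evaluated on the model.
def objective(theta):
    return (theta - 2.0) ** 2

# Value and gradient in unconstrained space. jax.value_and_grad would
# produce this automatically; here the chain rule is applied by hand:
# d/du f(exp(u)) = f'(theta) * theta.
def value_and_grad(u_flat):
    theta = to_constrained(u_flat[0])
    value = objective(theta)
    grad = np.array([2.0 * (theta - 2.0) * theta])
    return value, grad

# Record the objective once per L-BFGS-B iteration via the callback.
history = []
def record(u_flat):
    history.append(value_and_grad(u_flat)[0])

result = minimize(
    value_and_grad,
    x0=np.zeros(1),          # flattened unconstrained parameter vector
    jac=True,                # fun returns (value, gradient)
    method="L-BFGS-B",
    callback=record,
    options={"maxiter": 500},
)
theta_opt = to_constrained(result.x[0])  # map back to constrained space
```

The callback re-evaluates the objective purely to log it; the real implementation would record values without the duplicated work.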
Parameters
----------
model : Module
The model to be optimised.
objective : Objective
The objective function to minimise with respect to the model
parameters.
train_data : Dataset
The training data used to evaluate the objective.
trainable : nnx.filterlib.Filter
Filter selecting which parameters to optimise. Defaults to all
Parameter instances.
max_iters : int
Maximum number of L-BFGS-B iterations. Defaults to 500.
verbose : bool
Whether to print optimisation progress. Defaults to True.
safe : bool
Whether to validate inputs before optimisation. Defaults to True.
Returns
-------
tuple[Module, Array]
A tuple of the optimised model and an array of objective values
recorded at each iteration.
Train a Module model with respect to a supplied Objective function
using Optax's L-BFGS optimiser. The optimisation loop runs inside a
jax.lax.while_loop, terminating once max_iters is reached or the
gradient norm falls below gtol.
Parameters
----------
model : Module
The model to be optimised.
objective : Objective
The objective function to minimise.
train_data : Dataset
The training data used to evaluate the objective.
params_bijection : dict[Parameter, Transform] | None
Bijection used to transform parameters to unconstrained space.
Defaults to DEFAULT_BIJECTION.
trainable : nnx.filterlib.Filter
Filter selecting which parameters to optimise. Defaults to all
Parameter instances.
max_iters : int
Maximum number of L-BFGS iterations. Defaults to 100.
safe : bool
Whether to validate inputs before optimisation. Defaults to True.
max_linesearch_steps : int
Maximum number of line-search steps per iteration. Defaults to 32.
gtol : float
Terminate if the L2 norm of the gradient falls below this
threshold. Defaults to 1e-5.
Returns
-------
tuple[Module, Array]
A tuple of the optimised model and the final loss value.