Fit
fit
fit(*, model, objective, train_data, optim, params_bijection=DEFAULT_BIJECTION, key=jr.PRNGKey(42), num_iters=100, batch_size=-1, log_rate=10, verbose=True, unroll=1, safe=True)
Train a Module model with respect to a supplied objective function. Optimisers used here should originate from Optax.
Example:
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> import optax as ox
>>> import gpjax as gpx
>>> from flax import nnx
>>> from gpjax.parameters import PositiveReal, Static
>>>
>>> # (1) Create a dataset:
>>> X = jnp.linspace(0.0, 10.0, 100)[:, None]
>>> y = 2.0 * X + 1.0 + 10 * jr.normal(jr.PRNGKey(0), X.shape)
>>> D = gpx.Dataset(X, y)
>>> # (2) Define your model:
>>> class LinearModel(nnx.Module):
...     def __init__(self, weight: float, bias: float):
...         self.weight = PositiveReal(weight)
...         self.bias = Static(bias)
...
...     def __call__(self, x):
...         return self.weight.value * x + self.bias.value
>>>
>>> model = LinearModel(weight=1.0, bias=1.0)
>>>
>>> # (3) Define your loss function:
>>> def mse(model, data):
...     pred = model(data.X)
...     return jnp.mean((pred - data.y) ** 2)
>>>
>>> # (4) Train!
>>> trained_model, history = gpx.fit(
...     model=model, objective=mse, train_data=D, optim=ox.sgd(0.001), num_iters=1000
... )
Parameters:
- model (Model): The model Module to be optimised.
- objective (Objective): The objective function to minimise; it is called as objective(model, data), as in the example above.
- train_data (Dataset): The training data to be used for the optimisation.
- optim (GradientTransformation): The Optax optimiser used to update the model's parameters.
- num_iters (int, default: 100): The number of optimisation steps to run.
- batch_size (int, default: -1): The size of the mini-batch to use; -1 means full-batch training.
- key (KeyArray, default: jr.PRNGKey(42)): The random key used to select mini-batches.
- log_rate (int, default: 10): How frequently, in optimisation steps, the objective function's value should be printed.
- verbose (bool, default: True): Whether to print the training progress bar.
- unroll (int, default: 1): The unroll factor for the underlying optimisation loop, i.e. how many optimisation steps are unrolled per loop iteration.
Returns:
- tuple[Model, Array]: A tuple comprising the optimised model and the training history.
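To make num_iters, batch_size, key, unroll and the returned history concrete, here is a minimal training loop in plain JAX. It assumes a scan-based design with unroll forwarded to jax.lax.scan; sketch_fit, its learning_rate argument, and the use of plain gradient descent instead of an Optax optimiser are hypothetical choices for illustration, not GPJax's implementation.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def sketch_fit(loss_fn, params, X, y, *, key, num_iters=100, batch_size=-1,
               learning_rate=1e-3, unroll=1):
    """Hypothetical sketch of a scan-based training loop (not GPJax's code).

    Plain gradient descent stands in for an Optax optimiser.
    """
    n = X.shape[0]

    def step(params, step_key):
        if batch_size == -1:
            Xb, yb = X, y  # full batch: use every data point
        else:
            # Mini-batch: sample row indices with replacement.
            idx = jr.choice(step_key, n, (batch_size,), replace=True)
            Xb, yb = X[idx], y[idx]
        loss, grads = jax.value_and_grad(loss_fn)(params, Xb, yb)
        params = jax.tree_util.tree_map(
            lambda p, g: p - learning_rate * g, params, grads
        )
        return params, loss  # (new carry, value recorded in the history)

    # One key per step; scan stacks the per-step losses into the history array.
    step_keys = jr.split(key, num_iters)
    params, history = jax.lax.scan(step, params, step_keys, unroll=unroll)
    return params, history


def mse(params, X, y):
    pred = params["weight"] * X + params["bias"]
    return jnp.mean((pred - y) ** 2)


X = jnp.linspace(0.0, 1.0, 50)
y = 2.0 * X + 1.0
init = {"weight": jnp.zeros(()), "bias": jnp.zeros(())}
params, history = sketch_fit(
    mse, init, X, y, key=jr.PRNGKey(42), num_iters=1000, learning_rate=0.5
)
```

Running the loop inside jax.lax.scan keeps the whole optimisation in one compiled computation, and the per-step losses that scan stacks up play the role of the training history.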
fit_scipy
Train a Module model with respect to a supplied objective function using SciPy's optimisation routines rather than an Optax optimiser.
Parameters:
- model (Model): The model Module to be optimised.
- objective (Objective): The objective function to minimise.
- train_data (Dataset): The training data to be used for the optimisation.
- max_iters (int, default: 500): The maximum number of optimisation steps to run.
- verbose (bool, default: True): Whether to print information about the optimisation.
Returns:
- tuple[Model, Array]: A tuple comprising the optimised model and the training history.
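A SciPy-based fit typically hands a flat parameter vector and its gradient to scipy.optimize.minimize. The sketch below assumes that is roughly what fit_scipy does; sketch_fit_scipy and its history-recording callback are hypothetical. It uses jax.flatten_util.ravel_pytree to move between a parameter pytree and the flat vector SciPy expects, and passes jac=True so a single JAX call supplies both the objective value and its gradient.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.flatten_util import ravel_pytree
from scipy.optimize import minimize


def sketch_fit_scipy(loss_fn, params, X, y, *, max_iters=500, verbose=False):
    """Hypothetical sketch: minimise loss_fn over a parameter pytree with SciPy."""
    # Flatten the pytree into one vector; unravel maps a vector back to a pytree.
    flat0, unravel = ravel_pytree(params)
    value_and_grad = jax.jit(
        jax.value_and_grad(lambda flat: loss_fn(unravel(flat), X, y))
    )

    def scipy_objective(x):
        # SciPy works with float64 NumPy arrays; JAX here works in float32.
        value, grad = value_and_grad(jnp.asarray(x, dtype=flat0.dtype))
        return float(value), np.asarray(grad, dtype=np.float64)

    history = []

    def record(xk):
        # Called by SciPy once per iteration with the current parameter vector.
        history.append(scipy_objective(xk)[0])

    result = minimize(
        scipy_objective,
        np.asarray(flat0, dtype=np.float64),
        jac=True,  # scipy_objective returns (value, gradient)
        callback=record,
        options={"maxiter": max_iters, "disp": verbose},
    )
    return unravel(jnp.asarray(result.x, dtype=flat0.dtype)), np.asarray(history)


def mse(params, X, y):
    pred = params["weight"] * X + params["bias"]
    return jnp.mean((pred - y) ** 2)


X = jnp.linspace(0.0, 1.0, 50)
y = 2.0 * X + 1.0
params, history = sketch_fit_scipy(
    mse, {"weight": jnp.zeros(()), "bias": jnp.zeros(())}, X, y
)
```

With no bounds or constraints, minimize defaults to BFGS; in this sketch max_iters maps onto SciPy's maxiter option and verbose onto disp.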
get_batch
Sample a single mini-batch of batch_size points from the training data. Sampling is done with replacement.
Parameters:
- train_data (Dataset): The training dataset.
- batch_size (int): The batch size.
- key (KeyArray): The random key to use for the batch selection.
Returns:
- Dataset: The sampled mini-batch.
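The sampling that get_batch describes can be sketched on raw arrays with jax.random.choice. sketch_get_batch below is a hypothetical stand-in that returns an (X, y) tuple rather than a Dataset.

```python
import jax.numpy as jnp
import jax.random as jr


def sketch_get_batch(X, y, batch_size, key):
    """Hypothetical array-level stand-in for get_batch (returns a tuple, not a Dataset)."""
    n = X.shape[0]
    # Draw batch_size row indices uniformly at random, with replacement,
    # so the same row can appear more than once in a batch.
    indices = jr.choice(key, n, (batch_size,), replace=True)
    # Index X and y with the same indices so inputs and targets stay paired.
    return X[indices], y[indices]


X = jnp.arange(10.0)[:, None]  # 10 inputs, shape (10, 1)
y = 2.0 * X                    # matching targets
Xb, yb = sketch_get_batch(X, y, batch_size=32, key=jr.PRNGKey(0))
print(Xb.shape, yb.shape)  # (32, 1) (32, 1)
```

Because indices are drawn with replacement, batch_size may exceed the number of data points and a batch can contain duplicate rows; reusing the same key reproduces the same batch.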