Fit
gpjax.fit
ModuleModel = TypeVar('ModuleModel', bound=Module)  (module attribute)
__all__ = ['fit', 'get_batch']  (module attribute)
fit(*, model: ModuleModel, objective: Union[AbstractObjective, Callable[[ModuleModel, Dataset], ScalarFloat]], train_data: Dataset, optim: ox.GradientTransformation, key: KeyArray, num_iters: Optional[int] = 100, batch_size: Optional[int] = -1, log_rate: Optional[int] = 10, verbose: Optional[bool] = True, unroll: Optional[int] = 1, safe: Optional[bool] = True) -> Tuple[ModuleModel, Array]
Train a Module model with respect to a supplied Objective function. Optimisers used here should originate from Optax.
Example:
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> import optax as ox
>>> import gpjax as gpx
>>>
>>> # (1) Create a dataset:
>>> X = jnp.linspace(0.0, 10.0, 100)[:, None]
>>> y = 2.0 * X + 1.0 + 10 * jr.normal(jr.PRNGKey(0), X.shape)
>>> D = gpx.Dataset(X, y)
>>>
>>> # (2) Define your model:
>>> class LinearModel(gpx.base.Module):
...     weight: float = gpx.base.param_field()
...     bias: float = gpx.base.param_field()
...
...     def __call__(self, x):
...         return self.weight * x + self.bias
...
>>> model = LinearModel(weight=1.0, bias=1.0)
>>>
>>> # (3) Define your loss function:
>>> class MeanSquareError(gpx.objectives.AbstractObjective):
...     def evaluate(self, model: LinearModel, train_data: gpx.Dataset) -> float:
...         return jnp.mean((train_data.y - model(train_data.X)) ** 2)
...
>>> loss = MeanSquareError()
>>>
>>> # (4) Train!
>>> trained_model, history = gpx.fit(
...     model=model,
...     objective=loss,
...     train_data=D,
...     optim=ox.sgd(0.001),
...     num_iters=1000,
...     key=jr.PRNGKey(42),
... )
Parameters:

Name | Type | Description | Default |
---|---|---|---|
model | Module | The model Module to be optimised. | required |
objective | Objective | The objective function that we are optimising with respect to. | required |
train_data | Dataset | The training data to be used for the optimisation. | required |
optim | GradientTransformation | The Optax optimiser to be used for learning a parameter set. | required |
key | KeyArray | The random key to use for the optimisation batch selection. | required |
num_iters | Optional[int] | The number of optimisation steps to run. Defaults to 100. | 100 |
batch_size | Optional[int] | The size of the mini-batch to use. Defaults to -1 (i.e. full batch). | -1 |
log_rate | Optional[int] | How frequently the objective function's value should be printed. Defaults to 10. | 10 |
verbose | Optional[bool] | Whether to print the training progress bar. Defaults to True. | True |
unroll | int | The number of unrolled steps to use for the optimisation. Defaults to 1. | 1 |
Returns:

Tuple[Module, Array]: A Tuple comprising the optimised model and training history, respectively.
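When batch_size is positive, each optimisation step scores the objective on a random mini-batch (drawn with get_batch, documented below) rather than on the full dataset. A minimal sketch, reusing model, loss, and D from the example above; the Adam optimiser and step count here are illustrative choices, not defaults:

>>> # Mini-batch variant of the example above: each of the 1000 steps
>>> # evaluates the objective on 32 points sampled with replacement.
>>> trained_model, history = gpx.fit(
...     model=model,
...     objective=loss,
...     train_data=D,
...     optim=ox.adam(learning_rate=0.01),
...     num_iters=1000,
...     batch_size=32,
...     key=jr.PRNGKey(123),
... )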
fit_scipy(*, model: ModuleModel, objective: Union[AbstractObjective, Callable[[ModuleModel, Dataset], ScalarFloat]], train_data: Dataset, max_iters: Optional[int] = 500, verbose: Optional[bool] = True, safe: Optional[bool] = True) -> Tuple[ModuleModel, Array]
Train a Module model with respect to a supplied Objective function using SciPy's minimisation routines rather than an Optax optimiser.
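Example (a sketch reusing model, loss, and D from the fit example above):

>>> trained_model, history = gpx.fit_scipy(
...     model=model,
...     objective=loss,
...     train_data=D,
...     max_iters=500,
... )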
Parameters:

Name | Type | Description | Default |
---|---|---|---|
model | Module | The model Module to be optimised. | required |
objective | Objective | The objective function that we are optimising with respect to. | required |
train_data | Dataset | The training data to be used for the optimisation. | required |
max_iters | Optional[int] | The maximum number of optimisation steps to run. Defaults to 500. | 500 |
verbose | Optional[bool] | Whether to print information about the optimisation. Defaults to True. | True |
Returns:

Tuple[Module, Array]: A Tuple comprising the optimised model and training history, respectively.
get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset
Batch the data into mini-batches. Sampling is done with replacement.
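Example (a sketch; D is the gpx.Dataset built in the fit example above):

>>> import jax.random as jr
>>> from gpjax.fit import get_batch
>>>
>>> # Draw a mini-batch of 32 points, sampled with replacement.
>>> batch = get_batch(D, batch_size=32, key=jr.PRNGKey(0))
>>> batch.X.shape
(32, 1)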
Parameters:

Name | Type | Description | Default |
---|---|---|---|
train_data | Dataset | The training dataset. | required |
batch_size | int | The batch size. | required |
key | KeyArray | The random key to use for the batch selection. | required |

Returns:

Dataset: The batched dataset.