
Fit

gpjax.fit

ModuleModel = TypeVar('ModuleModel', bound=Module) module-attribute
__all__ = ['fit', 'get_batch'] module-attribute
fit(*, model: ModuleModel, objective: Union[AbstractObjective, Callable[[ModuleModel, Dataset], ScalarFloat]], train_data: Dataset, optim: ox.GradientTransformation, key: KeyArray, num_iters: Optional[int] = 100, batch_size: Optional[int] = -1, log_rate: Optional[int] = 10, verbose: Optional[bool] = True, unroll: Optional[int] = 1, safe: Optional[bool] = True) -> Tuple[ModuleModel, Array]

Train a Module model with respect to a supplied Objective function. Optimisers used here should originate from Optax.

Example:

    >>> import jax.numpy as jnp
    >>> import jax.random as jr
    >>> import optax as ox
    >>> import gpjax as gpx
    >>>
    >>> # (1) Create a dataset:
    >>> X = jnp.linspace(0.0, 10.0, 100)[:, None]
    >>> y = 2.0 * X + 1.0 + 10 * jr.normal(jr.key(0), X.shape)
    >>> D = gpx.Dataset(X, y)
    >>>
    >>> # (2) Define your model:
    >>> class LinearModel(gpx.base.Module):
    ...     weight: float = gpx.base.param_field()
    ...     bias: float = gpx.base.param_field()
    ...
    ...     def __call__(self, x):
    ...         return self.weight * x + self.bias
    >>>
    >>> model = LinearModel(weight=1.0, bias=1.0)
    >>>
    >>> # (3) Define your loss function:
    >>> class MeanSquareError(gpx.objectives.AbstractObjective):
    ...     def evaluate(self, model: LinearModel, train_data: gpx.Dataset) -> float:
    ...         return jnp.mean((train_data.y - model(train_data.X)) ** 2)
    >>>
    >>> loss = MeanSquareError()
    >>>
    >>> # (4) Train!
    >>> trained_model, history = gpx.fit(
    ...     model=model,
    ...     objective=loss,
    ...     train_data=D,
    ...     optim=ox.sgd(0.001),
    ...     num_iters=1000,
    ...     key=jr.key(42),
    ... )

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Module | The model Module to be optimised. | required |
| objective | Objective | The objective function that we are optimising with respect to. | required |
| train_data | Dataset | The training data to be used for the optimisation. | required |
| optim | GradientTransformation | The Optax optimiser to be used for learning a parameter set. | required |
| key | KeyArray | The random key to use for mini-batch selection. | required |
| num_iters | Optional[int] | The number of optimisation steps to run. | 100 |
| batch_size | Optional[int] | The size of the mini-batch to use; -1 means full batch. | -1 |
| log_rate | Optional[int] | How frequently the objective function's value should be printed. | 10 |
| verbose | Optional[bool] | Whether to print the training progress bar. | True |
| unroll | Optional[int] | The number of unrolled steps to use for the optimisation. | 1 |
Returns:

Tuple[Module, Array]: A tuple comprising the optimised model and the training history, respectively.
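When a positive batch_size is supplied, mini-batches are drawn at each step using the supplied key. A minimal sketch, reusing the model, loss, and dataset D from the example above (the Adam optimiser and batch size of 32 are illustrative choices, not defaults):

    >>> trained_model, history = gpx.fit(
    ...     model=model,
    ...     objective=loss,
    ...     train_data=D,
    ...     optim=ox.adam(0.01),
    ...     num_iters=1000,
    ...     batch_size=32,
    ...     key=jr.key(42),
    ... )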
fit_scipy(*, model: ModuleModel, objective: Union[AbstractObjective, Callable[[ModuleModel, Dataset], ScalarFloat]], train_data: Dataset, max_iters: Optional[int] = 500, verbose: Optional[bool] = True, safe: Optional[bool] = True) -> Tuple[ModuleModel, Array]

Train a Module model with respect to a supplied Objective function using SciPy's minimisation routines, rather than an Optax optimiser.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Module | The model Module to be optimised. | required |
| objective | Objective | The objective function that we are optimising with respect to. | required |
| train_data | Dataset | The training data to be used for the optimisation. | required |
| max_iters | Optional[int] | The maximum number of optimisation steps to run. | 500 |
| verbose | Optional[bool] | Whether to print information about the optimisation. | True |
Returns:

Tuple[Module, Array]: A tuple comprising the optimised model and the training history, respectively.
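Unlike fit, fit_scipy takes no Optax optimiser, key, or batching arguments. A minimal sketch, reusing the model, loss, and dataset D from the fit example above:

    >>> trained_model, history = gpx.fit_scipy(
    ...     model=model,
    ...     objective=loss,
    ...     train_data=D,
    ...     max_iters=500,
    ... )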
get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset

Batch the data into mini-batches. Sampling is done with replacement.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| train_data | Dataset | The training dataset. | required |
| batch_size | int | The batch size. | required |
| key | KeyArray | The random key to use for the batch selection. | required |
Returns:

Dataset: The batched dataset.
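A minimal sketch, reusing the dataset D from the fit example above (the batch size and key value are illustrative, and this assumes Dataset exposes the number of points as n):

    >>> import jax.random as jr
    >>> from gpjax.fit import get_batch
    >>>
    >>> batch = get_batch(train_data=D, batch_size=32, key=jr.key(123))
    >>> batch.n
    32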