Fit

fit

fit(*, model, objective, train_data, optim, params_bijection=DEFAULT_BIJECTION, key=jr.PRNGKey(42), num_iters=100, batch_size=-1, log_rate=10, verbose=True, unroll=1, safe=True)

Train a Module model with respect to a supplied objective function. Optimisers used here should originate from Optax.

Example:

    >>> import jax.numpy as jnp
    >>> import jax.random as jr
    >>> import optax as ox
    >>> import gpjax as gpx
    >>> from flax import nnx
    >>> from gpjax.parameters import PositiveReal, Static
    >>>
    >>> # (1) Create a dataset:
    >>> X = jnp.linspace(0.0, 10.0, 100)[:, None]
    >>> y = 2.0 * X + 1.0 + 10 * jr.normal(jr.PRNGKey(0), X.shape)
    >>> D = gpx.Dataset(X, y)
    >>> # (2) Define your model:
    >>> class LinearModel(nnx.Module):
    ...     def __init__(self, weight: float, bias: float):
    ...         self.weight = PositiveReal(weight)
    ...         self.bias = Static(bias)
    ...
    ...     def __call__(self, x):
    ...         return self.weight.value * x + self.bias.value
    ...
    >>> model = LinearModel(weight=1.0, bias=1.0)
    >>>
    >>> # (3) Define your loss function:
    >>> def mse(model, data):
    ...     pred = model(data.X)
    ...     return jnp.mean((pred - data.y) ** 2)
    ...
    >>> # (4) Train!
    >>> trained_model, history = gpx.fit(
    ...     model=model, objective=mse, train_data=D, optim=ox.sgd(0.001), num_iters=1000
    ... )

Parameters:

  • model (Model) –

    The model Module to be optimised.

  • objective (Objective) –

    The objective function that we are optimising with respect to.

  • train_data (Dataset) –

    The training data to be used for the optimisation.

  • optim (GradientTransformation) –

    The Optax optimiser that is to be used for learning a parameter set.

  • num_iters (int, default: 100 ) –

    The number of optimisation steps to run. Defaults to 100.

  • batch_size (int, default: -1 ) –

    The size of the mini-batch to use. Defaults to -1 (i.e. full batch).

  • key (KeyArray, default: PRNGKey(42) ) –

    The random key to use for the optimisation batch selection. Defaults to jr.PRNGKey(42).

  • log_rate (int, default: 10 ) –

    How frequently the objective function's value should be printed. Defaults to 10.

  • verbose (bool, default: True ) –

    Whether to display the training progress bar. Defaults to True.

  • unroll (int, default: 1 ) –

    The number of unrolled steps to use for the optimisation. Defaults to 1.

Returns:

  • tuple[Model, Array] –

    A tuple comprising the optimised model and training history.

fit_scipy

fit_scipy(*, model, objective, train_data, max_iters=500, verbose=True, safe=True)

Train a Module model with respect to a supplied objective function. Optimisation is carried out with a SciPy optimiser rather than an Optax one, so no optim argument is taken.

Parameters:

  • model (Model) –

    The model Module to be optimised.

  • objective (Objective) –

    The objective function that we are optimising with respect to.

  • train_data (Dataset) –

    The training data to be used for the optimisation.

  • max_iters (int, default: 500 ) –

    The maximum number of optimisation steps to run. Defaults to 500.

  • verbose (bool, default: True ) –

    Whether to print the information about the optimisation. Defaults to True.

Returns:

  • tuple[Model, Array] –

    A tuple comprising the optimised model and training history.
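As a rough illustration of how a JAX objective can be handed to a SciPy optimiser, the sketch below flattens the parameters, wraps the objective so that it returns a value and a gradient, and records a history through the callback. It is written under several assumptions: `fit_scipy_sketch`, the dictionary of parameters and the choice of `method="L-BFGS-B"` are illustrative and not taken from the documentation above.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.optimize
from jax.flatten_util import ravel_pytree


def fit_scipy_sketch(params, objective, data, max_iters=500):
    """Hypothetical helper: minimise objective(params, data) with SciPy."""
    # SciPy works on flat vectors, so flatten the parameter pytree.
    x0, unravel = ravel_pytree(params)
    value_and_grad = jax.jit(
        jax.value_and_grad(lambda x: objective(unravel(x), data))
    )

    def wrapper(x):
        # Convert to JAX's dtype on the way in and to float64 on the way out.
        value, grad = value_and_grad(jnp.asarray(x, dtype=x0.dtype))
        return float(value), np.asarray(grad, dtype=np.float64)

    x_init = np.asarray(x0, dtype=np.float64)
    history = [wrapper(x_init)[0]]
    result = scipy.optimize.minimize(
        wrapper,
        x_init,
        jac=True,  # wrapper returns (value, gradient)
        method="L-BFGS-B",  # assumed method, chosen for illustration
        callback=lambda xk: history.append(wrapper(xk)[0]),
        options={"maxiter": max_iters},
    )
    return unravel(jnp.asarray(result.x, dtype=x0.dtype)), jnp.asarray(history)


# Noise-free toy data generated from y = 2x + 1.
X = jnp.linspace(0.0, 1.0, 50)[:, None]
y = 2.0 * X + 1.0


def mse(params, data):
    X, y = data
    return jnp.mean((params["weight"] * X + params["bias"] - y) ** 2)


params = {"weight": jnp.zeros(()), "bias": jnp.zeros(())}
trained, history = fit_scipy_sketch(params, mse, (X, y), max_iters=100)
```

Unlike a fixed-length gradient-descent loop, the optimiser may stop before max_iters is reached, so the length of `history` in this sketch depends on when convergence is declared.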

get_batch

get_batch(train_data, batch_size, key)

Sample a mini-batch of batch_size points from the training data. Sampling is done with replacement.

Parameters:

  • train_data (Dataset) –

    The training dataset.

  • batch_size (int) –

    The batch size.

  • key (KeyArray) –

    The random key to use for the batch selection.

Returns:

  • Dataset –

    The batched dataset.
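Sampling with replacement can be sketched in plain JAX as follows. The helper `get_batch_sketch` is hypothetical and operates on raw arrays rather than a Dataset; it only illustrates the idea of drawing row indices with `jax.random.choice` and `replace=True`.

```python
import jax.numpy as jnp
import jax.random as jr


def get_batch_sketch(X, y, batch_size, key):
    """Hypothetical helper: draw batch_size rows of (X, y) with replacement."""
    n = X.shape[0]
    # Indices are drawn independently, so the same row may appear twice.
    indices = jr.choice(key, n, shape=(batch_size,), replace=True)
    return X[indices], y[indices]


X = jnp.arange(10.0)[:, None]
y = 2.0 * X

X_batch, y_batch = get_batch_sketch(X, y, batch_size=4, key=jr.PRNGKey(0))

# Because sampling is with replacement, a batch may exceed the dataset size.
X_big, y_big = get_batch_sketch(X, y, batch_size=25, key=jr.PRNGKey(1))
```

Inputs and targets are indexed with the same indices, so each sampled row keeps its matching target.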