
Parameters

Parameter

Parameter(value, tag, **kwargs)

Bases: Variable[T]

Parameter base class.

All trainable parameters in GPJax should inherit from this class.

PositiveReal

PositiveReal(value, tag='positive', **kwargs)

Bases: Parameter[T]

Parameter that is strictly positive.
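In the transform example on this page, the 'positive' tag is mapped to tfb.Softplus(), which sends every real number to a strictly positive one. Below is a minimal pure-JAX sketch of that forward map and its inverse. It assumes softplus is the bijector in use, and to_positive and from_positive are hypothetical helpers, not GPJax functions.

```python
import jax.numpy as jnp
from jax.nn import softplus


def to_positive(x):
    """Hypothetical helper: forward softplus, log(1 + exp(x)), always > 0."""
    return softplus(x)


def from_positive(y):
    """Hypothetical helper: inverse softplus for y > 0.

    log(expm1(y)) undoes log1p(exp(x)), mapping back to the real line.
    """
    return jnp.log(jnp.expm1(y))


unconstrained = jnp.array([-3.0, 0.0, 1.0])
positive = to_positive(unconstrained)   # every entry is strictly positive
recovered = from_positive(positive)     # approximately equal to `unconstrained`
print(positive)
```

Optimizing the unconstrained value and mapping it forward keeps the parameter positive without any explicit constraint in the optimizer.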

Real

Real(value, tag='real', **kwargs)

Bases: Parameter[T]

Parameter that can take any real value.

SigmoidBounded

SigmoidBounded(value, tag='sigmoid', **kwargs)

Bases: Parameter[T]

Parameter that is bounded between 0 and 1.
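The 'sigmoid' tag suggests a logistic-sigmoid bijector, which maps the whole real line onto the open interval (0, 1). That pairing is an assumption here, because this page does not state which bijector the tag uses. The minimal pure-JAX sketch below shows the forward map and its inverse (the logit). The helper names are hypothetical.

```python
import jax.numpy as jnp
from jax.nn import sigmoid


def to_unit_interval(x):
    """Hypothetical helper: logistic sigmoid 1 / (1 + exp(-x)), lands in (0, 1)."""
    return sigmoid(x)


def from_unit_interval(y):
    """Hypothetical helper: logit, log(y) - log(1 - y), for 0 < y < 1."""
    return jnp.log(y) - jnp.log1p(-y)


unconstrained = jnp.array([-4.0, 0.0, 2.5])
bounded = to_unit_interval(unconstrained)   # all entries strictly in (0, 1)
recovered = from_unit_interval(bounded)     # approximately equal to `unconstrained`
print(bounded)
```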

Static

Static(value, tag='static', **kwargs)

Bases: Variable[T]

Static parameter that is not trainable.
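On this page, Static derives from Variable rather than from Parameter, the base class for trainable parameters, so its value is carried along but not optimized. The plain-JAX sketch below illustrates that effect: it differentiates a toy loss with respect to the trainable values only, and the static values enter as a fixed input. This sketches the idea, not GPJax's mechanism. The loss and the dictionary names are hypothetical.

```python
import jax
import jax.numpy as jnp


def loss(trainable, static):
    """Hypothetical toy squared-error loss; only `trainable` is optimized."""
    pred = trainable["scale"] * static["x"]
    return jnp.sum((pred - static["y"]) ** 2)


trainable = {"scale": jnp.array(1.0)}
static = {"x": jnp.array([1.0, 2.0]), "y": jnp.array([2.0, 4.0])}

# argnums=0: differentiate w.r.t. `trainable` only; `static` is a fixed input.
grads = jax.grad(loss, argnums=0)(trainable, static)

# One gradient-descent step updates the trainable leaves; `static` is untouched.
lr = 0.1
trainable = jax.tree_util.tree_map(lambda p, g: p - lr * g, trainable, grads)
```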

LowerTriangular

LowerTriangular(value, tag='lower_triangular', **kwargs)

Bases: Parameter[T]

Parameter that is a lower-triangular matrix.
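A common reason to parameterize a lower-triangular matrix L is to use it as a matrix square root, such as a Cholesky factor, from which a symmetric positive semi-definite matrix is formed as L @ L.T. The pure-JAX sketch below shows that construction and checks that a Cholesky decomposition recovers L. It assumes this use case and does not show how GPJax stores the matrix or which bijector the 'lower_triangular' tag maps to. The raw values are made up for illustration.

```python
import jax.numpy as jnp

# Made-up raw values; jnp.tril zeroes everything above the diagonal.
raw = jnp.array([[2.0, 9.0, 9.0],
                 [0.5, 1.5, 9.0],
                 [-1.0, 0.3, 1.0]])
L = jnp.tril(raw)

# L @ L.T is symmetric positive semi-definite for any real L,
# and positive definite when no diagonal entry of L is zero.
cov = L @ L.T

# Cholesky recovers L here because L's diagonal is strictly positive.
L_recovered = jnp.linalg.cholesky(cov)
```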

transform

transform(params, params_bijection, inverse=False)

Transforms each parameter using the bijector that params_bijection assigns to the parameter's tag.

Example:

    >>> from gpjax.parameters import PositiveReal, transform
    >>> import jax.numpy as jnp
    >>> import tensorflow_probability.substrates.jax.bijectors as tfb
    >>> from flax import nnx
    >>> params = nnx.State(
    ...     {
    ...         "a": PositiveReal(jnp.array([1.0])),
    ...         "b": PositiveReal(jnp.array([2.0])),
    ...     }
    ... )
    >>> # Map the 'positive' tag used by PositiveReal to a softplus bijector
    >>> params_bijection = {'positive': tfb.Softplus()}
    >>> transformed_params = transform(params, params_bijection)
    >>> print(transformed_params["a"].value)  # softplus(1.0) = log(1 + e)
    [1.3132617]

Parameters:

  • params (State) –

    An nnx.State object containing the parameters to be transformed.

  • params_bijection (Dict[str, Bijector]) –

    A dictionary mapping parameter tags (for example 'positive') to bijectors.

  • inverse (bool, default: False) –

    If True, apply each bijector's inverse transformation instead of its forward transformation.

Returns:

  • State –

    A new nnx.State object containing the transformed parameters.
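The tag-based dispatch that transform performs can be sketched without GPJax, flax, or TensorFlow Probability. In this illustration, which is assumed rather than taken from GPJax's source, each parameter is a (tag, value) pair and each bijector is a (forward, inverse) pair of functions. Tags without an entry are left unchanged. transform_sketch and BIJECTORS are hypothetical names.

```python
import jax.numpy as jnp
from jax.nn import softplus

# Hypothetical bijector table: tag -> (forward, inverse).
BIJECTORS = {
    "positive": (softplus, lambda y: jnp.log(jnp.expm1(y))),
}


def transform_sketch(params, bijectors, inverse=False):
    """Hypothetical stand-in for transform: returns a new dict, input untouched.

    params maps name -> (tag, value). Tags missing from `bijectors` are
    passed through unchanged (a choice made for this sketch).
    """
    out = {}
    for name, (tag, value) in params.items():
        if tag in bijectors:
            forward, backward = bijectors[tag]
            value = backward(value) if inverse else forward(value)
        out[name] = (tag, value)
    return out


params = {
    "a": ("positive", jnp.array([1.0])),
    "b": ("positive", jnp.array([2.0])),
    "c": ("real", jnp.array([-0.5])),
}
constrained = transform_sketch(params, BIJECTORS)
round_trip = transform_sketch(constrained, BIJECTORS, inverse=True)
print(constrained["a"][1])  # softplus(1.0) = log(1 + e)
```

GPJax's transform instead takes an nnx.State and a dictionary of TensorFlow Probability bijectors, and it returns a new nnx.State.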