Parameters
Parameter
Bases: Variable[T]
Parameter base class.
All trainable parameters in GPJax should inherit from this class.
PositiveReal
SigmoidBounded
Static
Bases: Variable[T]
Static parameter that is not trainable.
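The "Bases:" lines above show that Parameter and Static are siblings: both derive from Variable, but Static does not derive from Parameter. The practical consequence is that filtering on Parameter selects exactly the trainable leaves and skips static ones. The sketch below illustrates that idea with hypothetical plain-Python stand-ins (Variable, Parameter, PositiveReal, Static and trainable_leaves here are not the GPJax or Flax classes). It only mirrors the inheritance shown above and makes no claim about how GPJax filters internally.

```python
# Hypothetical stand-ins mirroring the "Bases:" lines above; these are NOT
# the real GPJax/Flax classes, just enough structure to show the idea.
from typing import Dict, Generic, TypeVar

T = TypeVar("T")


class Variable(Generic[T]):
    """Minimal container holding a single value."""

    def __init__(self, value: T) -> None:
        self.value = value


class Parameter(Variable[T]):
    """Base for trainable parameters (stand-in for gpjax Parameter)."""


class PositiveReal(Parameter[T]):
    """A trainable parameter subclass (stand-in)."""


class Static(Variable[T]):
    """Non-trainable value: derives from Variable, not from Parameter."""


def trainable_leaves(state: Dict[str, Variable]) -> Dict[str, Variable]:
    # Keep only entries whose class inherits from Parameter; Static is
    # excluded because it is a sibling of Parameter, not a subclass.
    return {k: v for k, v in state.items() if isinstance(v, Parameter)}


# Hypothetical names for illustration only.
state = {
    "lengthscale": PositiveReal(1.0),
    "jitter": Static(1e-6),
}
print(sorted(trainable_leaves(state)))  # ['lengthscale']
```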
LowerTriangular
transform
Transforms each parameter in a state using the bijector registered for its parameter type.
Example:
>>> from gpjax.parameters import PositiveReal, transform
>>> import jax.numpy as jnp
>>> import tensorflow_probability.substrates.jax.bijectors as tfb
>>> from flax import nnx
>>> params = nnx.State(
...     {
...         "a": PositiveReal(jnp.array([1.0])),
...         "b": PositiveReal(jnp.array([2.0])),
...     }
... )
>>> params_bijection = {'positive': tfb.Softplus()}
>>> transformed_params = transform(params, params_bijection)
>>> print(transformed_params["a"].value)
[1.3132617]
Parameters:
- params (State): An nnx.State object containing the parameters to be transformed.
- params_bijection (Dict[str, Bijector]): A dictionary mapping parameter types (for example, 'positive') to bijectors.
- inverse (bool, default: False): Whether to apply the inverse transformation.
Returns:
- State: A new nnx.State object containing the transformed parameters.
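The behaviour described above (a bijector looked up by parameter type, applied in the forward direction or, with inverse=True, in reverse) can be illustrated without JAX. The following is a minimal plain-Python sketch under stated assumptions: parameters are modelled as hypothetical (tag, value) pairs, a bijector is a (forward, inverse) pair of functions, only softplus is provided, and parameters whose tag has no registered bijector are assumed to pass through unchanged. transform_sketch, softplus and softplus_inverse are illustrative names, not GPJax functions.

```python
import math
from typing import Callable, Dict, Tuple

# A "bijector" here is just a (forward, inverse) pair of scalar functions.
Bijector = Tuple[Callable[[float], float], Callable[[float], float]]


def softplus(x: float) -> float:
    # log(1 + exp(x)), written in a numerically stable form
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def softplus_inverse(y: float) -> float:
    # Inverse of softplus for y > 0: log(exp(y) - 1) = y + log(1 - exp(-y))
    return y + math.log(-math.expm1(-y))


def transform_sketch(
    params: Dict[str, Tuple[str, float]],
    params_bijection: Dict[str, Bijector],
    inverse: bool = False,
) -> Dict[str, Tuple[str, float]]:
    """Hypothetical stand-in for transform: params maps name -> (tag, value)."""
    out = {}
    for name, (tag, value) in params.items():
        if tag in params_bijection:
            forward, backward = params_bijection[tag]
            value = backward(value) if inverse else forward(value)
        # Assumption: tags without a registered bijector are left untouched.
        out[name] = (tag, value)
    return out


params = {"a": ("positive", 1.0), "b": ("positive", 2.0), "c": ("real", -3.0)}
bijection = {"positive": (softplus, softplus_inverse)}

constrained = transform_sketch(params, bijection)
print(round(constrained["a"][1], 7))  # 1.3132617

# inverse=True undoes the forward map, recovering the original values.
recovered = transform_sketch(constrained, bijection, inverse=True)
print(all(math.isclose(recovered[k][1], params[k][1]) for k in params))  # True
```

The forward value for "a" agrees with the softplus output in the example above, and the round trip shows why the inverse flag exists: the same bijector can map values in either direction.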