Parameters
FillTriangularTransform
Bases: Transform
Transform that maps a vector of length n(n+1)/2 to an n x n lower-triangular matrix. The entries of the input vector are assumed to fill the lower triangle row by row, in the order (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), ..., (n-1, n-1).
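The row-by-row ordering above is exactly the order in which `jnp.tril_indices` enumerates the lower triangle, so the packing can be sketched in a few lines. The helpers `fill_lower_triangular` and `extract_lower_triangular` are hypothetical and only illustrate the mapping; they are not GPJax's implementation and handle a single 1-D vector only (no batching).

```python
import math

import jax.numpy as jnp


def fill_lower_triangular(vec: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: pack a length n(n+1)/2 vector into an n x n lower triangle."""
    m = vec.shape[0]
    # Solve n(n+1)/2 = m for the matrix size n.
    n = (math.isqrt(8 * m + 1) - 1) // 2
    if n * (n + 1) // 2 != m:
        raise ValueError(f"length {m} is not of the form n(n+1)/2")
    # jnp.tril_indices(n) yields (row, col) pairs row by row:
    # (0,0), (1,0), (1,1), (2,0), ..., matching the documented ordering.
    rows, cols = jnp.tril_indices(n)
    return jnp.zeros((n, n), dtype=vec.dtype).at[rows, cols].set(vec)


def extract_lower_triangular(mat: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical inverse: read the lower triangle back out in the same order."""
    rows, cols = jnp.tril_indices(mat.shape[0])
    return mat[rows, cols]
```

For example, `fill_lower_triangular(jnp.arange(1.0, 7.0))` gives `[[1, 0, 0], [2, 3, 0], [4, 5, 6]]`, and `extract_lower_triangular` recovers the original vector.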
__call__
Parameter
NonNegativeReal
Bases: Parameter[T]
Parameter that is non-negative.
PositiveReal
Bases: Parameter[T]
Parameter that is strictly positive.
Real
Bases: Parameter[T]
Parameter that can take any real value.
SigmoidBounded
Bases: Parameter[T]
Parameter that is bounded between 0 and 1.
LowerTriangular
Bases: Parameter[T]
Parameter that is a lower triangular matrix.
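The Parameter subclasses above differ only in the set of values they admit. As a rough sketch (not GPJax's code), the hypothetical `CONSTRAINTS` table below spells out each set as a check, and the hypothetical `TO_CONSTRAINED` table pairs each scalar kind with a common bijection from the unconstrained reals into that set. The specific choices (identity, softplus, sigmoid) are assumptions for illustration; the library's default bijectors may differ.

```python
import jax
import jax.numpy as jnp

# Hypothetical membership checks, one per Parameter subclass.
CONSTRAINTS = {
    "Real": lambda v: jnp.all(jnp.isfinite(v)),
    "NonNegativeReal": lambda v: jnp.all(v >= 0),
    "PositiveReal": lambda v: jnp.all(v > 0),
    "SigmoidBounded": lambda v: jnp.all((v >= 0) & (v <= 1)),
    # Lower triangular: everything strictly above the diagonal is zero.
    "LowerTriangular": lambda v: jnp.all(jnp.triu(v, k=1) == 0),
}

# Assumed bijections from unconstrained reals into each scalar domain.
TO_CONSTRAINED = {
    "Real": lambda x: x,                  # identity: any real value
    "NonNegativeReal": jax.nn.softplus,   # output is > 0, hence also >= 0
    "PositiveReal": jax.nn.softplus,      # (-inf, inf) -> (0, inf)
    "SigmoidBounded": jax.nn.sigmoid,     # (-inf, inf) -> (0, 1)
}


def satisfies(kind: str, value: jnp.ndarray) -> bool:
    """Return True if `value` lies in the set admitted by Parameter subclass `kind`."""
    return bool(CONSTRAINTS[kind](value))
```

Optimising in the unconstrained space and mapping through such a bijection keeps every iterate inside the admissible set, which is why constrained parameters are paired with bijectors at all.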
CoregionalizationMatrix
Bases: Module
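Parameterises a positive semi-definite (PSD) output-correlation matrix B = WW^T + diag(kappa), where W is a P x rank low-rank factor and kappa is a length-P vector of per-output diagonal terms.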
Parameters:
- num_outputs (int): Number of output dimensions (P).
- rank (int): Rank of the low-rank factor W. Controls expressiveness: a higher rank allows richer correlations between outputs.
- key (Array): JAX PRNG key used to initialise W.
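The construction B = WW^T + diag(kappa) can be sketched directly with `jax.numpy`. The function `coregionalization_matrix` is hypothetical; it assumes W has shape (P, rank) and keeps kappa positive by passing an unconstrained value through softplus. Initialising W from a standard normal and kappa at softplus(0) = log 2 are also assumptions, not GPJax's documented initialisation.

```python
import jax
import jax.numpy as jnp


def coregionalization_matrix(key: jax.Array, num_outputs: int, rank: int) -> jax.Array:
    """Hypothetical sketch of B = W W^T + diag(kappa)."""
    # Assumed shape: W is (P, rank), so W @ W.T is (P, P) with rank at most `rank`.
    W = jax.random.normal(key, (num_outputs, rank))
    # Assumed positivity constraint: softplus of an unconstrained vector
    # (zero-initialised here, so every kappa_p equals log 2).
    kappa = jax.nn.softplus(jnp.zeros(num_outputs))
    return W @ W.T + jnp.diag(kappa)


B = coregionalization_matrix(jax.random.PRNGKey(0), num_outputs=4, rank=2)
```

When rank < P, WW^T on its own is singular; the strictly positive diagonal term kappa lifts every eigenvalue by at least min(kappa), so B is positive definite and can be Cholesky-factorised.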
transform
transform(
params: State,
params_bijection: dict[str, Transform],
inverse: bool = False,
) -> nnx.State
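Transforms parameters using bijectors: each parameter is passed through the bijector that params_bijection associates with its parameter type. With inverse=True, the inverse of each bijector is applied instead.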
Example
>>> from gpjax.parameters import PositiveReal, transform
>>> import jax.numpy as jnp
>>> import numpyro.distributions.transforms as npt
>>> from flax import nnx
>>> params = nnx.State(
...     {
...         "a": PositiveReal(jnp.array([1.0])),
...         "b": PositiveReal(jnp.array([2.0])),
...     }
... )
>>> params_bijection = {'positive': npt.SoftplusTransform()}
>>> transformed_params = transform(params, params_bijection)
>>> print(transformed_params["a"][...])
[1.3132617]
Parameters:
- params (State): An nnx.State object containing the parameters to be transformed.
- params_bijection (dict[str, Transform]): A dictionary mapping parameter types to bijectors.
- inverse (bool, default: False): Whether to apply the inverse transformation.
Returns:
- State (State): A new nnx.State object containing the transformed parameters.
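The dispatch performed by transform can be sketched without Flax or NumPyro. Here a plain dict of hypothetical `(tag, value)` pairs stands in for the nnx.State of Parameter objects, and the hypothetical `BIJECTION` table pairs each tag with a forward map and its inverse; `transform_sketch` is illustrative only and is not GPJax's implementation. Its forward direction reproduces the documented example, where softplus(1.0) prints as 1.3132617.

```python
import jax
import jax.numpy as jnp


def inverse_softplus(y: jnp.ndarray) -> jnp.ndarray:
    # Solves softplus(x) = y for x; valid for y > 0.
    return jnp.log(jnp.expm1(y))


# Hypothetical bijection table: tag -> (forward, inverse).
BIJECTION = {
    "positive": (jax.nn.softplus, inverse_softplus),
    "real": (lambda x: x, lambda y: y),
}


def transform_sketch(params: dict, bijection: dict, inverse: bool = False) -> dict:
    """Hypothetical stand-in: params maps name -> (tag, value)."""
    out = {}
    for name, (tag, value) in params.items():
        forward, backward = bijection[tag]
        # Apply the inverse bijector when inverse=True, else the forward one.
        out[name] = (tag, backward(value) if inverse else forward(value))
    return out


params = {"a": ("positive", jnp.array([1.0])), "b": ("real", jnp.array([-2.0]))}
constrained = transform_sketch(params, BIJECTION)
roundtrip = transform_sketch(constrained, BIJECTION, inverse=True)
```

Applying the forward and then the inverse transformation returns the original values, which is the property that lets a training loop move between the two representations.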