Parameters
FillTriangularTransform
Bases: Transform
Transform that maps a vector of length n(n+1)/2 to an n × n lower-triangular matrix. The vector entries are assumed to fill the lower triangle row by row: (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), ..., (n-1, n-1).
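The row-by-row ordering above is the order in which jnp.tril_indices enumerates the lower triangle. The following is a minimal sketch of the mapping and its inverse in plain JAX. It is not GPJax's implementation, and the helper names fill_lower_triangular and extract_lower_triangular are hypothetical.

```python
import math

import jax.numpy as jnp


def fill_lower_triangular(vec: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: length-n(n+1)/2 vector -> n x n lower-triangular matrix."""
    m = vec.shape[-1]
    # Solve m = n(n+1)/2 for n using exact integer arithmetic.
    n = (math.isqrt(8 * m + 1) - 1) // 2
    if n * (n + 1) // 2 != m:
        raise ValueError(f"length {m} is not of the form n(n+1)/2")
    # Row-major lower-triangle indices: (0,0), (1,0), (1,1), (2,0), ...
    rows, cols = jnp.tril_indices(n)
    return jnp.zeros((n, n), dtype=vec.dtype).at[rows, cols].set(vec)


def extract_lower_triangular(mat: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical inverse: read the lower triangle back out in the same order."""
    rows, cols = jnp.tril_indices(mat.shape[-1])
    return mat[rows, cols]


vec = jnp.arange(1.0, 7.0)  # 6 entries, so n = 3
L = fill_lower_triangular(vec)  # rows: [1, 0, 0], [2, 3, 0], [4, 5, 6]
```

Because the same index arrays drive both directions, extract_lower_triangular(fill_lower_triangular(v)) returns v unchanged.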
__call__
Parameter
NonNegativeReal
Bases: Parameter[T]
Parameter that is non-negative.
PositiveReal
Bases: Parameter[T]
Parameter that is strictly positive.
Real
Bases: Parameter[T]
Parameter that can take any real value.
SigmoidBounded
Bases: Parameter[T]
Parameter that is bounded between 0 and 1.
LowerTriangular
Bases: Parameter[T]
Parameter that is a lower triangular matrix.
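These constrained types are usually optimised in an unconstrained space and mapped into their constraint set by a bijector. Below is a minimal sketch with plain JAX functions, assuming softplus for positive values, a sigmoid for the unit interval, and the identity for Real. The table and its keys are hypothetical (only the "positive" tag appears in the transform example on this page), and GPJax itself uses numpyro Transform objects rather than function pairs.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logit


def softplus_inv(y: jnp.ndarray) -> jnp.ndarray:
    # Inverse of softplus: log(exp(y) - 1), written stably as y + log(1 - exp(-y)).
    return y + jnp.log(-jnp.expm1(-y))


# Hypothetical tag -> (forward, inverse) table. forward maps any real number
# into the constraint set; inverse maps a constrained value back.
BIJECTORS = {
    "real": (lambda x: x, lambda y: y),  # Real: no constraint
    "positive": (jax.nn.softplus, softplus_inv),  # PositiveReal: (0, inf)
    "sigmoid": (jax.nn.sigmoid, logit),  # SigmoidBounded: (0, 1)
}

x = jnp.array([-3.0, 0.0, 2.5])  # unconstrained values an optimiser might propose
for tag, (forward, inverse) in BIJECTORS.items():
    y = forward(x)
    print(tag, y, bool(jnp.allclose(inverse(y), x, atol=1e-4)))
```

Softplus output is strictly positive, so it would also keep NonNegativeReal values valid, although it never reaches exactly 0.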
CoregionalizationMatrix
Bases: Module
Parameterises a positive semi-definite (PSD) P × P output-correlation matrix B = W W^T + diag(kappa), where W is the P × rank low-rank factor.
Parameters:
- num_outputs (int) – Number of output dimensions (P).
- rank (int) – Rank of the low-rank factor W, i.e. its number of columns. A higher rank lets W W^T express richer correlations between outputs.
- key (Array) – JAX PRNG key used to initialise W.
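A minimal sketch of the construction in plain JAX. It assumes W is drawn from a standard normal and that kappa is kept positive through a softplus of an unconstrained vector; both are illustrative choices, not necessarily what CoregionalizationMatrix does internally, and coregionalization_matrix is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def coregionalization_matrix(W: jnp.ndarray, raw_kappa: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: B = W W^T + diag(kappa) with kappa = softplus(raw_kappa)."""
    kappa = jax.nn.softplus(raw_kappa)  # assumption: softplus keeps kappa > 0
    return W @ W.T + jnp.diag(kappa)


num_outputs, rank = 3, 1  # P = 3 outputs, rank-1 factor
key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (num_outputs, rank))  # assumed initialisation
raw_kappa = jnp.zeros(num_outputs)  # softplus(0) = log 2 on the diagonal

B = coregionalization_matrix(W, raw_kappa)
eigenvalues = jnp.linalg.eigvalsh(B)  # B is symmetric, so eigvalsh applies
```

Since W W^T has rank at most rank, it is singular whenever rank < P; the diag(kappa) term with kappa > 0 is what makes B strictly positive definite in that case.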
transform
transform(
params: State,
params_bijection: dict[str, Transform],
inverse: bool = False,
) -> nnx.State
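Transforms each parameter in params with the bijector that params_bijection assigns to that parameter's type, applying the inverse bijector instead when inverse=True.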
Example
>>> from gpjax.parameters import PositiveReal, transform
>>> import jax.numpy as jnp
>>> import numpyro.distributions.transforms as npt
>>> from flax import nnx
>>> params = nnx.State(
...     {
...         "a": PositiveReal(jnp.array([1.0])),
...         "b": PositiveReal(jnp.array([2.0])),
...     }
... )
>>> params_bijection = {'positive': npt.SoftplusTransform()}
>>> transformed_params = transform(params, params_bijection)
>>> print(transformed_params["a"][...])
[1.3132617]
Parameters:
- params (State) – An nnx.State object containing the parameters to be transformed.
- params_bijection (dict[str, Transform]) – A dictionary mapping parameter types to bijectors.
- inverse (bool, default: False) – Whether to apply the inverse transformation.
Returns:
- State (State) – A new nnx.State object containing the transformed parameters.
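The lookup-and-apply logic of transform can be sketched without Flax, assuming parameters are plain records carrying a type tag and bijectors are (forward, inverse) function pairs. Param, transform_sketch and softplus_inv are hypothetical names; the documented function works on nnx.State and numpyro Transform objects instead.

```python
from dataclasses import dataclass, replace
from typing import Callable

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class Param:
    """Hypothetical stand-in for a typed parameter: a type tag plus a value."""

    tag: str
    value: jnp.ndarray


# Hypothetical bijection table: tag -> (forward, inverse) callables.
Bijection = dict[str, tuple[Callable, Callable]]


def softplus_inv(y: jnp.ndarray) -> jnp.ndarray:
    # Inverse of softplus: log(exp(y) - 1), written stably.
    return y + jnp.log(-jnp.expm1(-y))


def transform_sketch(
    params: dict[str, Param], bijection: Bijection, inverse: bool = False
) -> dict[str, Param]:
    """Return new Params whose values are mapped by the bijector for their tag."""
    out = {}
    for name, p in params.items():
        forward, backward = bijection[p.tag]  # KeyError if the tag is unmapped
        fn = backward if inverse else forward
        out[name] = replace(p, value=fn(p.value))  # inputs are left untouched
    return out


bijection = {"positive": (jax.nn.softplus, softplus_inv)}
params = {
    "a": Param("positive", jnp.array([1.0])),
    "b": Param("positive", jnp.array([2.0])),
}
constrained = transform_sketch(params, bijection)
recovered = transform_sketch(constrained, bijection, inverse=True)
```

On the values from the example, the forward pass gives softplus(1.0) = log(1 + e) ≈ 1.3132617, matching the printed output, and passing inverse=True maps the results back to the original inputs.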