
Parameters

FillTriangularTransform

Bases: Transform

Transform that maps a vector of length n(n+1)/2 to an n x n lower-triangular matrix. The entries are assumed to be ordered row by row: (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), ..., (n-1, n-1).
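Because L = n(n+1)/2, the matrix size can be recovered from the vector length with the quadratic formula, n = (sqrt(8L + 1) - 1) / 2. The NumPy sketch below is an illustration, not GPJax code (`infer_n` is a hypothetical helper). It computes n this way and shows that `numpy.tril_indices` enumerates positions in the same row-by-row order described above.

```python
import math

import numpy as np


def infer_n(length: int) -> int:
    """Hypothetical helper: solve n(n+1)/2 == length for a non-negative integer n."""
    # Quadratic formula with exact integer square root: n = (sqrt(8L + 1) - 1) / 2
    n = (math.isqrt(8 * length + 1) - 1) // 2
    if n * (n + 1) // 2 != length:
        raise ValueError(f"{length} is not of the form n(n+1)/2")
    return n


n = infer_n(6)
rows, cols = np.tril_indices(n)  # lower-triangle positions, row by row
order = list(zip(rows.tolist(), cols.tolist()))
print(n, order)  # 3 [(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]
```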

__call__

__call__(x)

Forward transformation.

Parameters

x : array_like, shape (..., L)

    Input vector with L = n(n+1)/2 for some integer n.

Returns

y : array_like, shape (..., n, n)

    Lower-triangular matrix (zeros above the diagonal), filled in row-major order, i.e. (0,0), (1,0), (1,1), ...
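The following is a minimal NumPy sketch of the forward map, assuming the ordering above. `fill_lower_triangular` and `extract_lower_triangular` are hypothetical names, not GPJax functions. Advanced indexing on the last two axes handles any number of leading batch dimensions, which matches the `(..., L) -> (..., n, n)` shapes.

```python
import math

import numpy as np


def fill_lower_triangular(x: np.ndarray) -> np.ndarray:
    """Hypothetical sketch: map (..., L) -> (..., n, n), filling (0,0), (1,0), (1,1), ..."""
    x = np.asarray(x)
    length = x.shape[-1]
    n = (math.isqrt(8 * length + 1) - 1) // 2
    if n * (n + 1) // 2 != length:
        raise ValueError(f"last axis has length {length}, which is not n(n+1)/2")
    rows, cols = np.tril_indices(n)
    y = np.zeros(x.shape[:-1] + (n, n), dtype=x.dtype)
    # Scatter the vector into the lower triangle of every batch element at once.
    y[..., rows, cols] = x
    return y


def extract_lower_triangular(y: np.ndarray) -> np.ndarray:
    """Inverse direction: read the lower triangle back out in the same order."""
    n = y.shape[-1]
    rows, cols = np.tril_indices(n)
    return y[..., rows, cols]


y = fill_lower_triangular(np.arange(1.0, 7.0))
# y is [[1, 0, 0], [2, 3, 0], [4, 5, 6]]
batch = fill_lower_triangular(np.arange(12.0).reshape(2, 6))
print(batch.shape)  # (2, 3, 3)
```

JAX arrays are immutable, so a `jax.numpy` version would replace the in-place assignment with `y.at[..., rows, cols].set(x)`.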

Parameter

Parameter(
    value: T,
    tag: ParameterTag,
    prior: Distribution | None = None,
    **kwargs,
)

Bases: Variable[T]

Parameter base class.

All trainable parameters in GPJax should inherit from this class.
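Conceptually, a parameter pairs a value with a constraint tag and, optionally, a prior, and each subclass fixes a default tag. The toy classes below are hypothetical plain-Python stand-ins written only to illustrate that pattern. They do not reproduce GPJax's `Parameter`, which derives from `Variable` and exposes `tag` as a property.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToyParameter:
    """Hypothetical stand-in: a value paired with a constraint tag and an optional prior."""

    value: Any
    tag: str
    prior: Optional[Any] = None


class ToyPositiveReal(ToyParameter):
    """Hypothetical subclass that fixes the default tag, mirroring PositiveReal's signature."""

    def __init__(self, value: Any, tag: str = "positive", **kwargs: Any) -> None:
        super().__init__(value=value, tag=tag, **kwargs)


p = ToyPositiveReal(1.0)
print(p.tag)  # positive
```

Downstream code can then select a bijector purely from the tag, without knowing the concrete subclass.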

tag property

tag: ParameterTag

Return the parameter's constraint tag.

NonNegativeReal

NonNegativeReal(
    value: T, tag: ParameterTag = "non_negative", **kwargs
)

Bases: Parameter[T]

Parameter that is non-negative.
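The constraints of `NonNegativeReal` and `PositiveReal` differ only in whether zero is admissible. The predicates below are hypothetical helpers for illustration. This page does not say that GPJax exposes such checks.

```python
import numpy as np


def is_non_negative(x) -> bool:
    # Zero is admissible for a non-negative parameter.
    return bool(np.all(np.asarray(x) >= 0))


def is_positive(x) -> bool:
    # Zero is excluded for a strictly positive parameter.
    return bool(np.all(np.asarray(x) > 0))


print(is_non_negative(0.0), is_positive(0.0))  # True False
```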

tag property

tag: ParameterTag

Return the parameter's constraint tag.

PositiveReal

PositiveReal(
    value: T, tag: ParameterTag = "positive", **kwargs
)

Bases: Parameter[T]

Parameter that is strictly positive.
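Strict positivity is usually enforced by optimising an unconstrained value and mapping it through a bijection onto (0, inf). Softplus is one such map, and it is the bijector paired with the `'positive'` tag in the `transform` example on this page. Below is a NumPy sketch of the forward map and its inverse. The function names are hypothetical.

```python
import numpy as np


def softplus(x):
    # log(1 + exp(x)), computed stably via logaddexp
    return np.logaddexp(0.0, x)


def softplus_inverse(y):
    # Solves softplus(x) = y for y > 0:
    # x = log(exp(y) - 1) = y + log(1 - exp(-y)), using expm1 for accuracy near 0
    y = np.asarray(y, dtype=float)
    return y + np.log(-np.expm1(-y))


print(softplus(1.0))  # approximately 1.3132617
```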

tag property

tag: ParameterTag

Return the parameter's constraint tag.

Real

Real(value: T, tag: ParameterTag = 'real', **kwargs)

Bases: Parameter[T]

Parameter that can take any real value.

tag property

tag: ParameterTag

Return the parameter's constraint tag.

SigmoidBounded

SigmoidBounded(
    value: T, tag: ParameterTag = "sigmoid", **kwargs
)

Bases: Parameter[T]

Parameter that is bounded between 0 and 1.
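A natural bijection between the real line and the open interval (0, 1) is the logistic sigmoid, with the logit as its inverse. This sketch assumes that pairing is what the `"sigmoid"` tag refers to. The tag name suggests it, but this page does not state it. Note that the sigmoid approaches 0 and 1 but never reaches them exactly.

```python
import numpy as np


def sigmoid(x):
    # Maps the real line onto the open interval (0, 1).
    return 1.0 / (1.0 + np.exp(-x))


def logit(p):
    # Inverse of the sigmoid for p in (0, 1): log(p / (1 - p)).
    p = np.asarray(p, dtype=float)
    return np.log(p) - np.log1p(-p)


print(sigmoid(0.0))  # 0.5
```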

tag property

tag: ParameterTag

Return the parameter's constraint tag.

LowerTriangular

LowerTriangular(
    value: T,
    tag: ParameterTag = "lower_triangular",
    **kwargs,
)

Bases: Parameter[T]

Parameter that is a lower triangular matrix.
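One common use of a lower-triangular parameter is as a Cholesky factor. For any lower-triangular L, the product L L^T is symmetric positive semi-definite, and it is positive definite when the diagonal of L is strictly positive. This is general linear algebra rather than a statement about how GPJax uses `LowerTriangular`, and the example matrix is arbitrary.

```python
import numpy as np

# Arbitrary lower-triangular matrix with a strictly positive diagonal.
L = np.array([
    [2.0, 0.0, 0.0],
    [0.5, 1.0, 0.0],
    [-1.0, 0.3, 1.5],
])
assert np.allclose(L, np.tril(L))

S = L @ L.T                         # symmetric positive definite
eigvals = np.linalg.eigvalsh(S)     # all strictly positive
L_recovered = np.linalg.cholesky(S) # unique factor with positive diagonal equals L
```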

tag property

tag: ParameterTag

Return the parameter's constraint tag.

CoregionalizationMatrix

CoregionalizationMatrix(
    num_outputs: int, rank: int, key: Array
)

Bases: Module

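Parameterises a positive semi-definite (PSD) output-correlation matrix B = WW^T + diag(kappa), where W is a low-rank factor and kappa is a vector of per-output diagonal terms.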

Parameters:

  • num_outputs (int) –

    Number of output dimensions (P).

  • rank (int) –

    Rank of the low-rank factor W (its number of columns). A higher rank lets B express richer correlations between outputs.

  • key (Array) –

    JAX PRNG key for W initialisation.
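The following is a NumPy sketch of the construction B = WW^T + diag(kappa). It assumes W has shape [P, rank] and kappa is a positive vector of length P. The random initialisation uses NumPy's generator instead of a JAX PRNG key and is not GPJax's scheme. Adding a positive diagonal to WW^T, which has rank at most `rank`, is what keeps B positive definite even when rank < P.

```python
import numpy as np


def coregionalization_matrix(W: np.ndarray, kappa: np.ndarray) -> np.ndarray:
    """B = W W^T + diag(kappa), for W of shape [P, rank] and kappa of shape [P]."""
    return W @ W.T + np.diag(kappa)


rng = np.random.default_rng(0)
num_outputs, rank = 4, 2
W = rng.normal(size=(num_outputs, rank))
kappa = np.full(num_outputs, 0.1)  # hypothetical positive per-output diagonal terms
B = coregionalization_matrix(W, kappa)
print(B.shape)  # (4, 4)
```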

B property

B: ndarray

PSD coregionalization matrix of shape [P, P].

transform

transform(
    params: State,
    params_bijection: dict[str, Transform],
    inverse: bool = False,
) -> nnx.State

Transforms parameters using a bijector.

Example

>>> from gpjax.parameters import PositiveReal, transform
>>> import jax.numpy as jnp
>>> import numpyro.distributions.transforms as npt
>>> from flax import nnx
>>> params = nnx.State(
...     {
...         "a": PositiveReal(jnp.array([1.0])),
...         "b": PositiveReal(jnp.array([2.0])),
...     }
... )
>>> params_bijection = {'positive': npt.SoftplusTransform()}
>>> transformed_params = transform(params, params_bijection)
>>> print(transformed_params["a"][...])
[1.3132617]

Parameters:

  • params (State) –

    An nnx.State object containing the parameters to be transformed.

  • params_bijection (dict[str, Transform]) –

    A dictionary mapping parameter tags (e.g. 'positive') to bijectors.

  • inverse (bool, default: False) –

    Whether to apply the inverse transformation.

Returns:

  • State (State) –

    A new nnx.State object containing the transformed parameters.
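The core logic of `transform` can be pictured as two steps: look up each parameter's tag in `params_bijection`, then apply either the forward or the inverse map. The plain-Python sketch below is hypothetical. It stores parameters as `(tag, value)` pairs and bijectors as `(forward, inverse)` function pairs instead of using `nnx.State` and NumPyro transforms. It also raises `KeyError` for a tag with no entry, and this page does not describe what GPJax does in that case.

```python
from typing import Any, Callable, Dict, Tuple

import numpy as np

# Hypothetical representations used only in this sketch.
Bijector = Tuple[Callable[[Any], Any], Callable[[Any], Any]]  # (forward, inverse)
Params = Dict[str, Tuple[str, Any]]  # name -> (tag, value)


def softplus(x):
    return np.logaddexp(0.0, x)


def softplus_inverse(y):
    y = np.asarray(y, dtype=float)
    return y + np.log(-np.expm1(-y))


def toy_transform(params: Params, bijection: Dict[str, Bijector], inverse: bool = False) -> Params:
    """Return a new dict with each value mapped by the bijector registered for its tag."""
    out: Params = {}
    for name, (tag, value) in params.items():
        forward, backward = bijection[tag]  # KeyError if the tag is unregistered
        fn = backward if inverse else forward
        out[name] = (tag, fn(value))
    return out


params = {"a": ("positive", np.array([1.0])), "b": ("positive", np.array([2.0]))}
bijection = {"positive": (softplus, softplus_inverse)}
transformed = toy_transform(params, bijection)
print(transformed["a"][1])  # approximately [1.3132617]
```

This reproduces the `1.3132617` value from the example above. Calling it again with `inverse=True` maps the transformed values back to the originals.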