
Numpyro Extras

tree_path_to_name

tree_path_to_name(path: KeyPath, prefix: str = '') -> str

Convert a JAX tree path to a dotted parameter name.

For example, the lengthscale parameter of an RBF kernel instantiated with the name "kernel" is registered under the name "kernel.lengthscale".

Parameters:

  • path (KeyPath) –

    A JAX tree path (sequence of path keys).

  • prefix (str, default: '' ) –

    Optional prefix to prepend to the name.

Returns:

  • str –

    A dotted string representing the parameter name.
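The conversion can be pictured as joining the components of a JAX key path with dots. The sketch below is not GPJax's implementation. It assumes the path comes from `jax.tree_util.tree_flatten_with_path`, and `path_to_dotted_name` is a hypothetical helper name. GPJax applies the idea to `nnx.Module` state, whose path keys may render differently.

```python
import jax
from jax.tree_util import DictKey, FlattenedIndexKey, GetAttrKey, SequenceKey


def path_to_dotted_name(path, prefix: str = "") -> str:
    """Hypothetical sketch: turn a JAX key path into a dotted name."""
    parts = []
    for key in path:
        if isinstance(key, DictKey):  # dict entry, e.g. {"lengthscale": ...}
            parts.append(str(key.key))
        elif isinstance(key, GetAttrKey):  # attribute access, e.g. obj.lengthscale
            parts.append(key.name)
        elif isinstance(key, SequenceKey):  # list/tuple index
            parts.append(str(key.idx))
        elif isinstance(key, FlattenedIndexKey):  # index into flattened children
            parts.append(str(key.key))
        else:  # fall back to the key's string form
            parts.append(str(key))
    if prefix:
        parts.insert(0, prefix)
    return ".".join(parts)


# Flatten a toy parameter tree and name every leaf under the prefix "kernel".
params = {"lengthscale": 1.0, "variance": 2.0}
leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(params)
names = [path_to_dotted_name(path, prefix="kernel") for path, _ in leaves_with_paths]
print(names)  # ['kernel.lengthscale', 'kernel.variance']
```

Keeping the prefix separate from the path lets the same tree be registered under different top-level names without rebuilding it.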

resolve_prior

resolve_prior(
    name: str,
    param: Parameter,
    priors: dict[str, Distribution],
) -> dist.Distribution | None

Resolve which prior applies to a parameter.

A prior supplied in the priors dictionary takes precedence over a prior attached to the parameter itself. This lets you specify a prior in the model definition and then override it with a different prior at inference time.

Parameters:

  • name (str) –

    The parameter name.

  • param (Parameter) –

    The Parameter instance.

  • priors (dict[str, Distribution]) –

    Dictionary mapping parameter names to distributions.

Returns:

  • Distribution | None –

    The resolved distribution, or None if no prior is found.
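The precedence rule above can be sketched in a few lines. This is an illustration under assumptions, not GPJax's code. `ToyParameter` is a hypothetical stand-in for a GPJax `Parameter`, and it assumes the attached prior lives in a `prior` attribute. The placeholder objects stand in for NumPyro distributions so that the sketch runs on its own.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToyParameter:
    """Hypothetical stand-in for a GPJax Parameter with an optional attached prior."""

    value: float
    prior: Optional[Any] = None  # assumption: attached prior stored as `.prior`


def resolve_prior_sketch(name: str, param: Any, priors: dict) -> Optional[Any]:
    """Hypothetical sketch of the precedence rule: dictionary first, then attached."""
    # 1. An entry in the priors dictionary wins.
    if name in priors:
        return priors[name]
    # 2. Otherwise fall back to the attached prior, which may be None.
    return getattr(param, "prior", None)


# Placeholders; in practice these would be numpyro.distributions objects.
attached = object()
override = object()

lengthscale = ToyParameter(1.0, prior=attached)
print(resolve_prior_sketch("kernel.lengthscale", lengthscale, {}) is attached)  # True
print(resolve_prior_sketch("kernel.lengthscale", lengthscale,
                           {"kernel.lengthscale": override}) is override)  # True
print(resolve_prior_sketch("kernel.variance", ToyParameter(2.0), {}))  # None
```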

register_parameters

register_parameters(
    model: Module,
    priors: dict[str, Distribution] | None = None,
    prefix: str = "",
) -> nnx.Module

Register GPJax parameters with Numpyro.

Parameters:

  • model (Module) –

    The GPJax model containing the parameters; it must be a subclass of nnx.Module.

  • priors (dict[str, Distribution] | None, default: None ) –

    Optional dictionary mapping parameter names to Numpyro distributions.

  • prefix (str, default: '' ) –

    Optional prefix for parameter names.

Returns:

  • Module –

    The model with parameters updated from Numpyro samples.