Numpyro Extras
tree_path_to_name
Convert a JAX tree path to a dotted parameter name.
For example, the lengthscale parameter of an RBF kernel stored under the name "kernel" is registered as "kernel.lengthscale".
Parameters:
- path (KeyPath): A JAX tree path (sequence of path keys).
- prefix (str, default: ''): Optional prefix to prepend to the name.
Returns:
- str: A dotted string representing the parameter name.
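A minimal sketch of the path-to-name conversion, assuming the paths come from `jax.tree_util.tree_flatten_with_path` and that each key is rendered as its dictionary key, attribute name, or sequence index. The helper `path_to_name` is hypothetical and is not the GPJax implementation, which may render keys differently.

```python
import jax.numpy as jnp
from jax.tree_util import (
    DictKey,
    GetAttrKey,
    SequenceKey,
    tree_flatten_with_path,
)


def path_to_name(path, prefix: str = "") -> str:
    """Join the keys of a JAX tree path with dots (hypothetical helper)."""
    parts = [prefix] if prefix else []
    for key in path:
        if isinstance(key, DictKey):
            parts.append(str(key.key))  # dict entry: use the dict key
        elif isinstance(key, GetAttrKey):
            parts.append(key.name)  # object attribute: use the attribute name
        elif isinstance(key, SequenceKey):
            parts.append(str(key.idx))  # list/tuple entry: use the index
        else:
            parts.append(str(key))  # any other key type: fall back to str()
    return ".".join(parts)


# A nested dict standing in for a model whose kernel is stored under "kernel".
params = {"kernel": {"lengthscale": jnp.array(1.0), "variance": jnp.array(0.5)}}
leaves_with_paths, _ = tree_flatten_with_path(params)
names = [path_to_name(path) for path, _ in leaves_with_paths]
print(names)  # ['kernel.lengthscale', 'kernel.variance']
```

Passing `prefix="model"` would yield names such as "model.kernel.lengthscale".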
resolve_prior
resolve_prior(
name: str,
param: Parameter,
priors: dict[str, Distribution],
) -> dist.Distribution | None
Resolve which prior applies to a parameter.
A prior given explicitly in the priors dictionary takes precedence over a prior attached to the parameter itself. This lets a prior be specified in the model definition and then overridden with a different prior at inference time.
Parameters:
- name (str): The parameter name.
- param (Parameter): The Parameter instance.
- priors (dict[str, Distribution]): Dictionary mapping parameter names to distributions.
Returns:
- Distribution | None: The resolved distribution, or None if no prior is found.
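The precedence rule can be sketched in plain Python. The `Param` dataclass and its `prior` attribute are hypothetical stand-ins for a GPJax `Parameter` carrying an attached prior, and strings stand in for NumPyro distributions because the lookup logic does not depend on the distribution type.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Param:
    """Hypothetical stand-in for a parameter with an optionally attached prior."""

    value: float
    prior: Optional[Any] = None


def resolve_prior_sketch(name: str, param: Param, priors: dict) -> Optional[Any]:
    # 1. An explicit entry in the priors dictionary wins.
    if name in priors:
        return priors[name]
    # 2. Otherwise fall back to the prior attached to the parameter, if any.
    return getattr(param, "prior", None)


attached = Param(value=1.0, prior="LogNormal(0, 1)")
print(resolve_prior_sketch("kernel.lengthscale", attached, {}))
# LogNormal(0, 1)
print(resolve_prior_sketch("kernel.lengthscale", attached, {"kernel.lengthscale": "Gamma(2, 1)"}))
# Gamma(2, 1)
print(resolve_prior_sketch("kernel.variance", Param(value=0.5), {}))
# None
```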
register_parameters
register_parameters(
model: Module,
priors: dict[str, Distribution] | None = None,
prefix: str = "",
) -> nnx.Module
Register GPJax parameters with Numpyro.
Parameters:
- model (Module): The GPJax model containing the parameters; a subclass of nnx.Module.
- priors (dict[str, Distribution] | None, default: None): Optional dictionary mapping parameter names to Numpyro distributions.
- prefix (str, default: ''): Optional prefix for parameter names.
Returns:
- Module: The model with parameters updated from Numpyro samples.