Module

gpjax.base.module

__all__ = ['Module', 'meta_leaves', 'meta_flatten', 'meta_map', 'meta', 'static_field'] module-attribute
Self = TypeVar('Self') module-attribute
Module

Bases: Pytree

__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self

Replace the values of the fields of the object.

Parameters:

Name      Type  Description                                             Default
**kwargs  Any   Keyword arguments to replace the fields of the object.  {}

Returns
Module: with the fields replaced.
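As a rough illustration of this immutable-update pattern, assuming dataclass semantics for Module subclasses, the sketch below uses a plain frozen dataclass and the standard library's dataclasses.replace. The Kernel class and its fields are hypothetical; this is not GPJax's implementation.

```python
import dataclasses


# Hypothetical stand-in for a Module subclass: an immutable (frozen) dataclass.
@dataclasses.dataclass(frozen=True)
class Kernel:
    lengthscale: float = 1.0
    variance: float = 1.0

    def replace(self, **kwargs):
        # Build a new instance with the given fields swapped out.
        return dataclasses.replace(self, **kwargs)


k1 = Kernel()
k2 = k1.replace(lengthscale=0.5)
print(k1.lengthscale, k2.lengthscale)  # 1.0 0.5
```

The original object is left unchanged; only the returned copy carries the new value.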
replace_meta(**kwargs: Any) -> Self

Replace the metadata of the fields.

Parameters:

Name      Type  Description                                                             Default
**kwargs  Any   Keyword arguments to replace the metadata of the fields of the object.  {}

Returns
Module: with the metadata of the fields replaced.
update_meta(**kwargs: Any) -> Self

Update the metadata of the fields. The metadata must already exist.

Parameters:

Name      Type  Description                                                            Default
**kwargs  Any   Keyword arguments giving the metadata to update for each named field.  {}

Returns
Module: with the metadata of the fields updated.
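The docstrings distinguish replacing metadata from updating metadata that must already exist. One plausible reading, sketched below with a hypothetical MetaStore class, is that replace_meta overwrites a field's metadata wholesale, while update_meta merges new entries into existing metadata and rejects fields that have none. The merge semantics are an assumption, not confirmed GPJax behaviour.

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass(frozen=True)
class MetaStore:
    """Hypothetical sketch: per-field metadata held alongside a module."""

    meta: Dict[str, Dict[str, Any]]

    def replace_meta(self, **kwargs: Dict[str, Any]) -> "MetaStore":
        # Overwrite each named field's metadata wholesale.
        new = dict(self.meta)
        new.update(kwargs)
        return MetaStore(new)

    def update_meta(self, **kwargs: Dict[str, Any]) -> "MetaStore":
        # Merge into existing metadata; refuse fields that have none yet.
        new = dict(self.meta)
        for name, updates in kwargs.items():
            if name not in new:
                raise ValueError(f"no metadata exists for field {name!r}")
            new[name] = {**new[name], **updates}
        return MetaStore(new)


store = MetaStore({"lengthscale": {"trainable": True, "bijector": "softplus"}})
replaced = store.replace_meta(lengthscale={"trainable": False})
updated = store.update_meta(lengthscale={"trainable": False})
print(replaced.meta)  # {'lengthscale': {'trainable': False}}
print(updated.meta)   # {'lengthscale': {'trainable': False, 'bijector': 'softplus'}}
```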
replace_trainable(**kwargs: Dict[str, bool]) -> Self

Replace the trainability status of local nodes of the Module.

replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self

Replace the bijectors of local nodes of the Module.
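Trainability flags are typically consumed by an optimiser that leaves frozen parameters untouched. The sketch below illustrates that idea with hypothetical Params and sgd_step names; it assumes a simple per-field boolean flag and is not how GPJax itself stores or applies trainability.

```python
import dataclasses
from typing import Dict


@dataclasses.dataclass(frozen=True)
class Params:
    """Hypothetical sketch: parameter values plus a per-field trainability flag."""

    values: Dict[str, float]
    trainable: Dict[str, bool]

    def replace_trainable(self, **kwargs: bool) -> "Params":
        # Change the trainability flag of the named fields only.
        return dataclasses.replace(self, trainable={**self.trainable, **kwargs})


def sgd_step(p: Params, grads: Dict[str, float], lr: float = 0.1) -> Params:
    # Only fields flagged as trainable are moved; the rest are held fixed.
    new_values = {
        k: v - lr * grads[k] if p.trainable[k] else v for k, v in p.values.items()
    }
    return dataclasses.replace(p, values=new_values)


p = Params({"lengthscale": 1.0, "variance": 2.0}, {"lengthscale": True, "variance": True})
p = p.replace_trainable(variance=False)
p = sgd_step(p, {"lengthscale": 1.0, "variance": 1.0})
print(p.values)  # {'lengthscale': 0.9, 'variance': 2.0}
```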

constrain() -> Self

Transform model parameters to the constrained space according to their defined bijectors.

Returns
Module: transformed to the constrained space.
unconstrain() -> Self

Transform model parameters to the unconstrained space according to their defined bijectors.

Returns
Module: transformed to the unconstrained space.
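To make the constrain/unconstrain round trip concrete, the sketch below uses a hand-written softplus bijector in place of a TensorFlow Probability bijector such as tfb.Softplus. Following the TFP convention, forward maps unconstrained values to constrained ones and inverse maps them back. The Softplus and Kernel classes are hypothetical stand-ins.

```python
import math
from dataclasses import dataclass, field, replace


@dataclass(frozen=True)
class Softplus:
    """Hypothetical stand-in for a positivity bijector such as tfb.Softplus."""

    def forward(self, x: float) -> float:
        # Unconstrained real -> positive real: log(1 + exp(x)).
        return math.log1p(math.exp(x))

    def inverse(self, y: float) -> float:
        # Positive real -> unconstrained real: log(exp(y) - 1).
        return math.log(math.expm1(y))


@dataclass(frozen=True)
class Kernel:
    lengthscale: float
    bijector: Softplus = field(default_factory=Softplus)

    def constrain(self) -> "Kernel":
        # Map the parameter into the space the model expects (positive reals).
        return replace(self, lengthscale=self.bijector.forward(self.lengthscale))

    def unconstrain(self) -> "Kernel":
        # Map the parameter onto the whole real line, where optimisers work freely.
        return replace(self, lengthscale=self.bijector.inverse(self.lengthscale))


k = Kernel(lengthscale=0.5)
u = k.unconstrain()
print(u.lengthscale)              # about -0.433
print(u.constrain().lengthscale)  # 0.5, up to floating-point rounding
```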
stop_gradient() -> Self

Stop gradients flowing through the Module.

Returns
Module: with gradients stopped.
trainables() -> Self
static_field(default: Any = dataclasses.MISSING, *, default_factory: Any = dataclasses.MISSING, init: bool = True, repr: bool = True, hash: Optional[bool] = None, compare: bool = True, metadata: Optional[Mapping[str, Any]] = None)
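static_field mirrors the signature of dataclasses.field. A plausible reading, sketched below, is that it forwards to dataclasses.field and tags the field so that it is treated as static structure rather than as a pytree leaf. The "pytree_node" metadata key (a convention used by Flax's struct.field) and the leaves_and_static helper are assumptions made for illustration, not documented GPJax behaviour.

```python
import dataclasses
from typing import Any, Mapping, Optional


def static_field(
    default: Any = dataclasses.MISSING,
    *,
    default_factory: Any = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: Optional[Mapping[str, Any]] = None,
):
    # Hypothetical sketch: forward to dataclasses.field and tag the field as static.
    metadata = dict(metadata or {})
    metadata["pytree_node"] = False  # assumed marker key, not confirmed GPJax API
    return dataclasses.field(
        default=default,
        default_factory=default_factory,
        init=init,
        repr=repr,
        hash=hash,
        compare=compare,
        metadata=metadata,
    )


def leaves_and_static(obj: Any):
    # Hypothetical helper: split fields into leaves and static (structural) values.
    leaves, static = {}, {}
    for f in dataclasses.fields(obj):
        target = static if f.metadata.get("pytree_node") is False else leaves
        target[f.name] = getattr(obj, f.name)
    return leaves, static


@dataclasses.dataclass
class Prior:
    mean: float = 0.0
    name: str = static_field("rbf")


print(leaves_and_static(Prior()))  # ({'mean': 0.0}, {'name': 'rbf'})
```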
meta_leaves(pytree: Module, *, is_leaf: Optional[Callable[[Any], bool]] = None) -> List[Tuple[Optional[Dict[str, Any]], Any]]

Returns the meta of the leaves of the pytree.

Parameters:

Name     Type                   Description                                  Default
pytree   Module                 Pytree to get the meta of.                   required
is_leaf  Callable[[Any], bool]  Predicate to determine if a node is a leaf.  None
Returns
List[Tuple[Optional[Dict[str, Any]], Any]]: (metadata, leaf) pairs for the leaves of the pytree; the metadata entry may be None.
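The sketch below shows one way to pair each leaf with the metadata of the dataclass field that holds it, recursing into nested modules. It uses plain dataclasses and field metadata; meta_leaves_sketch, RBF and GP are hypothetical names, and GPJax's own traversal may differ.

```python
import dataclasses
from typing import Any, Callable, Dict, List, Optional, Tuple


def meta_leaves_sketch(
    node: Any, *, is_leaf: Optional[Callable[[Any], bool]] = None
) -> List[Tuple[Optional[Dict[str, Any]], Any]]:
    """Hypothetical sketch: pair every leaf with the metadata of the field holding it."""

    def walk(meta: Optional[Dict[str, Any]], value: Any):
        # Stop at user-declared leaves or at anything that is not a dataclass.
        if (is_leaf is not None and is_leaf(value)) or not dataclasses.is_dataclass(value):
            yield (meta, value)
            return
        for f in dataclasses.fields(value):
            field_meta = dict(f.metadata) or None  # empty metadata becomes None
            yield from walk(field_meta, getattr(value, f.name))

    return list(walk(None, node))


@dataclasses.dataclass
class RBF:
    lengthscale: float = dataclasses.field(default=1.0, metadata={"trainable": True})
    variance: float = 1.0


@dataclasses.dataclass
class GP:
    kernel: RBF = dataclasses.field(default_factory=RBF)
    noise: float = dataclasses.field(default=0.1, metadata={"trainable": False})


print(meta_leaves_sketch(GP()))
# [({'trainable': True}, 1.0), (None, 1.0), ({'trainable': False}, 0.1)]
```

Passing is_leaf=lambda x: isinstance(x, RBF) would stop the walk at the kernel and return it as a single leaf.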
meta_flatten(pytree: Union[Module, Any], *, is_leaf: Optional[Callable[[Any], bool]] = None) -> Union[Module, Any]

Returns the meta of the Module.

Parameters:

Name     Type                   Description                                  Default
pytree   Module                 Module to get the meta of.                   required
is_leaf  Callable[[Any], bool]  Predicate to determine if a node is a leaf.  None
Returns
Module: meta of the Module.
meta_map(f: Callable[[Any, Dict[str, Any]], Any], pytree: Union[Module, Any], *rest: Any, is_leaf: Optional[Callable[[Any], bool]] = None) -> Union[Module, Any]

Apply a function to a Module, where the function's first argument is a pytree leaf and its second argument is the corresponding Module metadata leaf.

Parameters:

Name     Type                                  Description                                   Default
f        Callable[[Any, Dict[str, Any]], Any]  The function to apply to the pytree.          required
pytree   Module                                The pytree to apply the function to.          required
*rest    Any                                   Additional pytrees to apply the function to.  ()
is_leaf  Callable[[Any], bool]                 Predicate to determine if a node is a leaf.   None

Returns
Module: The transformed pytree.
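A minimal sketch of the meta_map idea, assuming nested dataclasses as the pytree: f receives each leaf together with the metadata of its field, and the tree is rebuilt from the results. meta_map_sketch, RBF and to_constrained are hypothetical names, and the sketch omits the *rest and is_leaf arguments of the documented function.

```python
import dataclasses
import math
from typing import Any, Callable, Dict, Optional

Meta = Optional[Dict[str, Any]]


def meta_map_sketch(f: Callable[[Any, Meta], Any], node: Any, meta: Meta = None) -> Any:
    """Hypothetical sketch: rebuild `node` with f(leaf, metadata) applied to every leaf."""
    if not dataclasses.is_dataclass(node):
        # A leaf: hand it to f together with the metadata of the field holding it.
        return f(node, meta)
    updates = {
        fld.name: meta_map_sketch(f, getattr(node, fld.name), dict(fld.metadata) or None)
        for fld in dataclasses.fields(node)
    }
    return dataclasses.replace(node, **updates)


@dataclasses.dataclass(frozen=True)
class RBF:
    lengthscale: float = dataclasses.field(default=0.0, metadata={"positive": True})
    shift: float = 0.0


def to_constrained(leaf: float, meta: Meta) -> float:
    # Apply softplus only to leaves whose metadata asks for positivity.
    if meta and meta.get("positive"):
        return math.log1p(math.exp(leaf))
    return leaf


out = meta_map_sketch(to_constrained, RBF())
print(out.lengthscale, out.shift)  # log(2) (about 0.693), and 0.0 unchanged
```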
meta(pytree: Module, *, is_leaf: Optional[Callable[[Any], bool]] = None) -> Module

Returns the metadata of the Module as a pytree.

Parameters:

Name     Type                   Description                                  Default
pytree   Module                 Pytree to get the metadata of.               required
is_leaf  Callable[[Any], bool]  Predicate to determine if a node is a leaf.  None
Returns
Module: metadata of the pytree.
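The returned object has the same structure as the input, with metadata in place of the leaves. The sketch below illustrates this with nested dataclasses; meta_sketch and RBF are hypothetical names, and mapping leaves whose field carries no metadata to None is an assumption.

```python
import dataclasses
from typing import Any


def meta_sketch(node: Any, meta=None) -> Any:
    """Hypothetical sketch: same tree shape as `node`, each leaf swapped for its metadata."""
    if not dataclasses.is_dataclass(node):
        return meta
    return dataclasses.replace(
        node,
        **{
            f.name: meta_sketch(getattr(node, f.name), dict(f.metadata) or None)
            for f in dataclasses.fields(node)
        },
    )


@dataclasses.dataclass
class RBF:
    lengthscale: Any = dataclasses.field(default=1.0, metadata={"trainable": True})
    variance: Any = 1.0


print(meta_sketch(RBF()))  # RBF(lengthscale={'trainable': True}, variance=None)
```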
save_tree(path: str, model: Module, overwrite: bool = False, iterate: int = None) -> None
load_tree(path: str, model: Module) -> Module