Module
gpjax.base.module
__all__ = ['Module', 'meta_leaves', 'meta_flatten', 'meta_map', 'meta', 'static_field']  (module attribute)
Self = TypeVar('Self')  (module attribute)
Module
Bases: Pytree
__init_subclass__(mutable: bool = False)
replace(**kwargs: Any) -> Self
replace_meta(**kwargs: Any) -> Self
update_meta(**kwargs: Any) -> Self
replace_trainable(**kwargs: Dict[str, bool]) -> Self
Replace the trainability status of local nodes of the Module.
replace_bijector(**kwargs: Dict[str, tfb.Bijector]) -> Self
Replace the bijectors of local nodes of the Module.
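Here is a minimal, dependency-free sketch of what replacing per-node metadata can look like. ToyModule, its meta dictionary and the _replace_meta_key helper are hypothetical stand-ins, not GPJax's implementation. Bijectors are represented by plain string labels so the example runs without TensorFlow Probability.

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass(frozen=True)
class ToyModule:
    """Hypothetical stand-in for a Module with two parameter leaves."""

    lengthscale: float
    variance: float
    # Assumption: per-instance metadata is kept in a dict keyed by field name.
    meta: Dict[str, Dict[str, Any]] = dataclasses.field(default_factory=dict)

    def _replace_meta_key(self, key: str, **kwargs: Any) -> "ToyModule":
        # Only the parameter fields may carry metadata in this sketch.
        params = {f.name for f in dataclasses.fields(self)} - {"meta"}
        # Copy each inner dict so the original instance's metadata is never mutated.
        new_meta = {name: dict(m) for name, m in self.meta.items()}
        for name, value in kwargs.items():
            if name not in params:
                raise KeyError(f"unknown field: {name}")
            new_meta.setdefault(name, {})[key] = value
        # Return a new instance; leaf values are unchanged.
        return dataclasses.replace(self, meta=new_meta)

    def replace_trainable(self, **kwargs: bool) -> "ToyModule":
        return self._replace_meta_key("trainable", **kwargs)

    def replace_bijector(self, **kwargs: str) -> "ToyModule":
        # Placeholder: a real bijector object would go here, not a string label.
        return self._replace_meta_key("bijector", **kwargs)


m = ToyModule(1.0, 2.0, meta={"lengthscale": {"trainable": True}})
frozen = m.replace_trainable(lengthscale=False).replace_bijector(variance="softplus")
print(frozen.meta)
```

Returning a new instance instead of mutating in place matches the -> Self return annotations in the signatures above.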
constrain() -> Self
unconstrain() -> Self
stop_gradient() -> Self
trainables() -> Self
static_field(default: Any = dataclasses.MISSING, *, default_factory: Any = dataclasses.MISSING, init: bool = True, repr: bool = True, hash: Optional[bool] = None, compare: bool = True, metadata: Optional[Mapping[str, Any]] = None)
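static_field takes the same parameters as dataclasses.field. JAX pytree libraries such as simple_pytree and Flax's struct module commonly forward these arguments to dataclasses.field and record in the field's metadata that the attribute is not a pytree node. The sketch below assumes that convention. The "pytree_node" key, static_field_sketch and ToyKernel are assumptions for illustration, not confirmed GPJax internals.

```python
import dataclasses
from typing import Any, List, Mapping, Optional


def static_field_sketch(
    default: Any = dataclasses.MISSING,
    *,
    default_factory: Any = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: Optional[Mapping[str, Any]] = None,
) -> Any:
    # Copy caller metadata and flag the field as static (not a pytree leaf).
    merged = dict(metadata) if metadata is not None else {}
    merged["pytree_node"] = False
    return dataclasses.field(
        default=default,
        default_factory=default_factory,
        init=init,
        repr=repr,
        hash=hash,
        compare=compare,
        metadata=merged,
    )


@dataclasses.dataclass
class ToyKernel:
    variance: float
    name: str = static_field_sketch(default="rbf")


def leaf_field_names(obj: Any) -> List[str]:
    # Fields without the static flag are treated as leaves.
    return [f.name for f in dataclasses.fields(obj) if f.metadata.get("pytree_node", True)]


print(leaf_field_names(ToyKernel(variance=1.0)))
```

Because the flag lives in ordinary dataclass metadata, any code that walks dataclasses.fields can tell static attributes from leaves, as leaf_field_names does here.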
meta_leaves(pytree: Module, *, is_leaf: Optional[Callable[[Any], bool]] = None) -> List[Tuple[Optional[Dict[str, Any]], Any]]
Returns the meta of the leaves of the pytree.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pytree | Module | Pytree to get the meta of. | required |
is_leaf | Callable[[Any], bool] | Predicate to determine if a node is a leaf. | None |
Returns
List[Tuple[Optional[Dict[str, Any]], Any]]: (meta, leaf) pairs for the leaves of the pytree.
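To give a rough picture of the (meta, leaf) pairs described above, the sketch walks the leaf fields of a dataclass and pairs each value with its metadata. It uses None for leaves that have no recorded metadata, which is one way the Optional in the return type can arise. ToyKernel, its meta store and toy_meta_leaves are hypothetical. A real implementation would presumably flatten through JAX's pytree utilities, handling nested modules and the is_leaf predicate that this sketch omits.

```python
import dataclasses
from typing import Any, Dict, List, Optional, Tuple


@dataclasses.dataclass
class ToyKernel:
    lengthscale: float
    variance: float
    # Hypothetical per-instance metadata keyed by field name; flagged as
    # non-pytree so it is not itself treated as a leaf.
    meta: Dict[str, Dict[str, Any]] = dataclasses.field(
        default_factory=dict, metadata={"pytree_node": False}
    )


def toy_meta_leaves(module: Any) -> List[Tuple[Optional[Dict[str, Any]], Any]]:
    pairs = []
    for f in dataclasses.fields(module):
        if not f.metadata.get("pytree_node", True):
            continue  # static fields are skipped
        # Leaves with no recorded metadata are paired with None.
        pairs.append((module.meta.get(f.name), getattr(module, f.name)))
    return pairs


kernel = ToyKernel(0.5, 2.0, meta={"lengthscale": {"trainable": True}})
print(toy_meta_leaves(kernel))
```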
meta_flatten(pytree: Union[Module, Any], *, is_leaf: Optional[Callable[[Any], bool]] = None) -> Union[Module, Any]
Returns the meta of the Module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pytree | Module | Module to get the meta of. | required |
is_leaf | Callable[[Any], bool] | Predicate to determine if a node is a leaf. | None |
Returns
Module: meta of the Module.
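The documented return type of meta_flatten is loose. For illustration only, the sketch assumes a flatten step in the spirit of jax.tree_util.tree_flatten. It returns the (meta, leaf) pairs together with a function that rebuilds the module from new leaf values. toy_meta_flatten and ToyKernel are hypothetical names, not GPJax API.

```python
import dataclasses
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple


@dataclasses.dataclass
class ToyKernel:
    lengthscale: float
    variance: float
    meta: Dict[str, Dict[str, Any]] = dataclasses.field(
        default_factory=dict, metadata={"pytree_node": False}
    )


def toy_meta_flatten(
    module: Any,
) -> Tuple[List[Tuple[Optional[Dict[str, Any]], Any]], Callable[[Sequence[Any]], Any]]:
    # Names of the leaf (non-static) fields, in definition order.
    names = [f.name for f in dataclasses.fields(module) if f.metadata.get("pytree_node", True)]
    pairs = [(module.meta.get(n), getattr(module, n)) for n in names]

    def unflatten(new_leaves: Sequence[Any]) -> Any:
        if len(new_leaves) != len(names):
            raise ValueError("expected one value per leaf")
        # Static fields, including the metadata store, are carried over unchanged.
        return dataclasses.replace(module, **dict(zip(names, new_leaves)))

    return pairs, unflatten


kernel = ToyKernel(0.5, 2.0, meta={"lengthscale": {"trainable": True}})
pairs, unflatten = toy_meta_flatten(kernel)
rebuilt = unflatten([leaf * 10 for _, leaf in pairs])
```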
meta_map(f: Callable[[Any, Dict[str, Any]], Any], pytree: Union[Module, Any], *rest: Any, is_leaf: Optional[Callable[[Any], bool]] = None) -> Union[Module, Any]
Apply a function to a Module, where the first argument to f is a pytree leaf and the second is the corresponding Module metadata leaf.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
f | Callable[[Any, Dict[str, Any]], Any] | The function to apply to the pytree. | required |
pytree | Module | The pytree to apply the function to. | required |
*rest | Any | Additional pytrees to apply the function to. | () |
is_leaf | Callable[[Any], bool] | Predicate to determine if a node is a leaf. | None |
Returns
Module: The transformed pytree.
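The following is a minimal sketch of the meta_map contract as documented. f receives each leaf first and its metadata second, and the results are assembled into a pytree of the same structure. toy_meta_map, ToyKernel and double_if_trainable are hypothetical. For simplicity, the sketch omits *rest and is_leaf.

```python
import dataclasses
from typing import Any, Callable, Dict


@dataclasses.dataclass
class ToyKernel:
    lengthscale: float
    variance: float
    meta: Dict[str, Dict[str, Any]] = dataclasses.field(
        default_factory=dict, metadata={"pytree_node": False}
    )


def toy_meta_map(f: Callable[[Any, Dict[str, Any]], Any], module: Any) -> Any:
    updates = {}
    for field in dataclasses.fields(module):
        if not field.metadata.get("pytree_node", True):
            continue  # static fields are left untouched
        # Leaf first, metadata second; missing metadata becomes an empty dict.
        updates[field.name] = f(getattr(module, field.name), module.meta.get(field.name, {}))
    return dataclasses.replace(module, **updates)


def double_if_trainable(leaf: Any, meta: Dict[str, Any]) -> Any:
    return leaf * 2 if meta.get("trainable", False) else leaf


kernel = ToyKernel(
    0.5, 2.0, meta={"lengthscale": {"trainable": True}, "variance": {"trainable": False}}
)
updated = toy_meta_map(double_if_trainable, kernel)
```

Operations such as constrain or stop_gradient can be thought of in these terms. Each inspects a leaf's metadata (its bijector or trainability flag) to decide how to transform the leaf.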
meta(pytree: Module, *, is_leaf: Optional[Callable[[Any], bool]] = None) -> Module
Returns the metadata of the Module as a pytree.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pytree | Module | Pytree to get the metadata of. | required |
is_leaf | Callable[[Any], bool] | Predicate to determine if a node is a leaf. | None |
Returns
Module: metadata of the pytree.
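meta returns metadata with the same structure as the module. In the hypothetical sketch below, that means a copy of the module whose leaf fields hold their metadata (or None) in place of their values. ToyKernel and toy_meta are illustrative names, not GPJax API.

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass
class ToyKernel:
    lengthscale: Any
    variance: Any
    meta: Dict[str, Dict[str, Any]] = dataclasses.field(
        default_factory=dict, metadata={"pytree_node": False}
    )


def toy_meta(module: Any) -> Any:
    names = [f.name for f in dataclasses.fields(module) if f.metadata.get("pytree_node", True)]
    # Same structure as the input, with metadata substituted for each leaf.
    return dataclasses.replace(module, **{n: module.meta.get(n) for n in names})


kernel = ToyKernel(0.5, 2.0, meta={"lengthscale": {"trainable": True}})
meta_tree = toy_meta(kernel)
```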