evox.core.module
Module Contents
Classes
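ModuleBase | The base module for all algorithms, problems, and workflows in the library.

Functions
Parameter | Wraps a value as a parameter with requires_grad=False.
Mutable | Wraps a value as a mutable tensor. This is often used to label a value in an algorithm as a mutable tensor that may change during iteration(s).
compile | Fixes the torch.compile issue with getitem and setitem that recognizes scalar indexes as .item() and causes graph breaks.
vmap | Fixes the torch.vmap issue with getitem and setitem.
use_state | Transforms a torch.nn.Module's method, or a torch.nn.Module itself, into a stateful function.
API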
- evox.core.module.Parameter(value: evox.core.module.ParameterT, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = False) -> evox.core.module.ParameterT
Wraps a value as a parameter with requires_grad=False. This is often used to label a value in an algorithm as a hyperparameter that can be identified by the HPOProblemWrapper.
- Parameters:
value – The parameter value.
dtype – The dtype of the parameter. Defaults to None.
device – The device of the parameter. Defaults to None.
requires_grad – Whether the parameter requires gradient. Defaults to False.
- Returns:
The parameter.
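A minimal sketch of why labeling hyperparameters this way is useful: a value stored as a non-trainable parameter can be told apart from other module state. It assumes (this page does not confirm it) that the wrapper behaves like torch.nn.Parameter with requires_grad=False, and that a lookup can scan named_parameters() for non-trainable entries. hypothetical_parameter, ToyAlgorithm, and list_hyperparameters are illustrative names, not part of EvoX, and this is not how HPOProblemWrapper is implemented.

```python
from typing import Dict, Optional

import torch


def hypothetical_parameter(value, dtype: Optional[torch.dtype] = None,
                           device: Optional[torch.device] = None,
                           requires_grad: bool = False) -> torch.nn.Parameter:
    # Assumption: convert to a tensor, then mark it as a parameter that,
    # by default, does not require gradients.
    return torch.nn.Parameter(torch.as_tensor(value, dtype=dtype, device=device),
                              requires_grad=requires_grad)


class ToyAlgorithm(torch.nn.Module):
    """Hypothetical algorithm with two tunable hyperparameters."""

    def __init__(self):
        super().__init__()
        # Assigning a Parameter to a module attribute registers it.
        self.mutation_rate = hypothetical_parameter(0.1)
        self.crossover_rate = hypothetical_parameter(0.9)


def list_hyperparameters(module: torch.nn.Module) -> Dict[str, torch.Tensor]:
    # Hypothetical lookup: collect registered parameters that do not require gradients.
    return {name: p for name, p in module.named_parameters() if not p.requires_grad}
```

For example, list_hyperparameters(ToyAlgorithm()) returns a dict keyed by "mutation_rate" and "crossover_rate".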
- evox.core.module.Mutable(value: torch.Tensor, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> torch.Tensor
Wraps a value as a mutable tensor. This is often used to label a value in an algorithm as a mutable tensor that may change during iteration(s).
- Parameters:
value – The value to be wrapped.
dtype – The dtype of the tensor. Defaults to None.
device – The device of the tensor. Defaults to None.
- Returns:
The wrapped tensor.
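The sketch below illustrates the kind of state that changes across iterations. It assumes a mutable tensor can be modeled as a registered buffer of a torch.nn.Module; the exact container Mutable returns is not specified on this page. ToyTracker and its step method are hypothetical.

```python
import torch


class ToyTracker(torch.nn.Module):
    """Hypothetical algorithm whose per-iteration state lives in buffers."""

    def __init__(self, dim: int):
        super().__init__()
        # Assumption: mutable state is modeled as buffers, so it appears in state_dict().
        self.register_buffer("iteration", torch.tensor(0))
        self.register_buffer("best", torch.full((dim,), float("inf")))

    def step(self, fitness: torch.Tensor) -> None:
        # Reassigning a registered buffer name replaces the stored tensor.
        self.iteration = self.iteration + 1
        self.best = torch.minimum(self.best, fitness)


tracker = ToyTracker(3)
tracker.step(torch.tensor([3.0, 1.0, 2.0]))
tracker.step(torch.tensor([1.0, 5.0, 0.0]))
```

After the two steps, tracker.iteration holds 2 and tracker.best holds the element-wise minimum [1.0, 1.0, 0.0].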
- class evox.core.module.ModuleBase(*args, **kwargs)
Bases: torch.nn.Module
The base module for all algorithms, problems, and workflows in the library.
Note
To prevent ambiguity, ModuleBase.eval() is disabled.
Initialization
Initialize the ModuleBase.
- Parameters:
*args – Variable length argument list, passed to the parent class initializer.
**kwargs – Arbitrary keyword arguments, passed to the parent class initializer.
- Attributes:
static_names (list) – A list to store static member names.
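A minimal sketch of a torch.nn.Module subclass that forwards its constructor arguments, keeps a static_names list, and disables eval(). SketchModuleBase is a hypothetical stand-in, not the EvoX class; in particular, raising RuntimeError is an assumption about how "disabled" is enforced.

```python
import torch


class SketchModuleBase(torch.nn.Module):
    """Hypothetical stand-in for a base class that disables eval()."""

    def __init__(self, *args, **kwargs):
        # Forward everything to torch.nn.Module's initializer.
        super().__init__(*args, **kwargs)
        # List of static member names, mirroring the documented attribute.
        self.static_names = []

    def eval(self):
        # Assumption: eval() is rejected outright instead of switching
        # the module into evaluation mode.
        raise RuntimeError("eval() is disabled on this module")


module = SketchModuleBase()
module.train(False)  # still allowed: torch.nn.Module.train does not call eval()
```

In this sketch, train(False) still clears the training flag, because torch.nn.Module.eval() is implemented in terms of train(False), not the other way around.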
- evox.core.module.compile(*args, **kwargs) -> Callable
Fixes the torch.compile issue with getitem and setitem that recognizes scalar indexes as .item() calls and causes graph breaks. Related issue: https://github.com/pytorch/pytorch/issues/124423.
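One general way to sidestep an .item() call on a scalar index is to keep the index as a tensor and express getitem and setitem through tensor ops such as torch.index_select and torch.index_put. The sketch below shows that idea under the assumption that the index is a 0-d integer tensor; get_row and set_row are hypothetical helpers, torch.compile itself is not exercised, and this is not how evox.core.module.compile is implemented.

```python
import torch


def get_row(x: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    # Tensor-only equivalent of x[i] for a 0-d integer tensor i:
    # the index is reshaped to 1-d instead of being turned into a Python int.
    return torch.index_select(x, 0, i.reshape(1)).squeeze(0)


def set_row(x: torch.Tensor, i: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Out-of-place, tensor-only equivalent of `x[i] = value`.
    return torch.index_put(x, (i.reshape(1),), value.unsqueeze(0))


x = torch.arange(12.0).reshape(3, 4)
i = torch.tensor(1)
row = get_row(x, i)                   # same values as x[1]
updated = set_row(x, i, torch.zeros(4))  # x itself is left unchanged
```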
- evox.core.module.vmap(*args, **kwargs) -> Callable
Fixes the torch.vmap issue with getitem and setitem. Related issue: https://github.com/pytorch/pytorch/issues/124423.
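Under torch.vmap, a per-sample index is a batched tensor, so an indexing path that converts it to a Python scalar cannot work. A minimal sketch of a tensor-only lookup that torch.vmap can batch is shown below; take_row is a hypothetical helper illustrating the general idea, not the implementation of evox.core.module.vmap.

```python
import torch


def take_row(table: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    # Same result as table[i] for a 0-d integer tensor i, but i stays a tensor
    # (no .item() call), so torch.vmap can batch over it.
    return torch.index_select(table, 0, i.reshape(1)).squeeze(0)


table = torch.arange(12.0).reshape(3, 4)
indices = torch.tensor([2, 0, 1])
# Map over the indices only; `table` is shared across the batch (in_dims=None).
rows = torch.vmap(take_row, in_dims=(None, 0))(table, indices)
```

The result matches plain advanced indexing, table[indices].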
- evox.core.module.use_state(stateful_func: Union[Callable, torch.nn.Module], tie_weights: bool = True, strict: bool = False) -> Callable
Transforms a torch.nn.Module's method, or a torch.nn.Module itself, into a stateful function. When a torch.nn.Module is passed, the stateful version of its default forward method is created. The stateful function has the signature fn(params_and_buffers, *args, **kwargs) -> params_and_buffers | Tuple[params_and_buffers, <original_returns>].
- Parameters:
stateful_func – The torch.nn.Module, or a method of a torch.nn.Module, to be transformed.
tie_weights – If True, parameters and buffers tied in the original model are treated as tied in the reparameterized version; therefore, passing different values for the tied parameters and buffers raises an error. If False, the original ties are not respected unless the values passed for both weights are the same. Defaults to True.
strict – If True, the parameters and buffers passed in must match those of the original module; therefore, any missing or unexpected keys raise an error. Defaults to False.
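The tie_weights and strict parameters mirror the keyword arguments of the same names on torch.func.functional_call. Assuming use_state forwards them there (an assumption, not confirmed on this page), the strict behavior can be seen directly with functional_call on a small torch.nn.Linear:

```python
import torch
from torch.func import functional_call

model = torch.nn.Linear(2, 1)
x = torch.ones(1, 2)

# Only the weight is supplied; the "bias" entry is missing.
partial_state = {"weight": torch.zeros(1, 2)}

# strict=False: the missing bias falls back to the module's own tensor,
# so with a zero weight the output equals the bias.
out = functional_call(model, partial_state, (x,), strict=False)

# strict=True: the missing "bias" key raises an error.
try:
    functional_call(model, partial_state, (x,), strict=True)
    strict_raised = False
except RuntimeError:
    strict_raised = True
```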
- Returns:
A new stateful function. It takes the module’s state (a dictionary of parameters and buffers) as the first argument, followed by the original arguments. It returns the updated state. If the original function returned a value, it returns a tuple containing the updated state and the original return value.
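To make the signature concrete, here is a minimal re-implementation sketch built on torch.func.functional_call. It is an assumption that use_state works this way; hypothetical_use_state, _CallAndCapture, and Counter are illustrative names only, and the sketch takes a method name instead of a bound method. It treats a None return as "no return value" and keeps mutable state in buffers, because PyTorch rejects assigning a plain tensor to a registered parameter.

```python
import torch
from torch.func import functional_call


class _CallAndCapture(torch.nn.Module):
    """Hypothetical helper: runs a method of the wrapped module, then
    returns its (possibly reassigned) parameters and buffers."""

    def __init__(self, module: torch.nn.Module, method_name: str):
        super().__init__()
        self.inner = module
        self.method_name = method_name

    def forward(self, *args, **kwargs):
        ret = getattr(self.inner, self.method_name)(*args, **kwargs)
        # Read the state while the passed-in tensors are still swapped in.
        state = dict(self.inner.named_parameters())
        state.update(dict(self.inner.named_buffers()))
        return state, ret


def hypothetical_use_state(module, method_name="forward", tie_weights=True, strict=False):
    wrapper = _CallAndCapture(module, method_name)

    def fn(state, *args, **kwargs):
        # The wrapper stores the module as "inner", so prefix the state keys.
        prefixed = {"inner." + k: v for k, v in state.items()}
        new_state, ret = functional_call(wrapper, prefixed, args, kwargs,
                                         tie_weights=tie_weights, strict=strict)
        return new_state if ret is None else (new_state, ret)

    return fn


class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("count", torch.tensor(0))

    def step(self, inc):
        self.count = self.count + inc  # returns None -> only the state comes back

    def peek(self):
        return self.count * 10  # returns a value -> (state, value)


counter = Counter()
step = hypothetical_use_state(counter, "step")
s0 = dict(counter.named_buffers())
s1 = step(s0, 2)
s2 = step(s1, 3)
s3, value = hypothetical_use_state(counter, "peek")(s2)
```

Because the state is passed in and returned explicitly, the original module's own buffer is left at its initial value.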
Examples
```python
import torch
from evox import use_state, vmap

workflow = ...  # define your workflow
stateful_step = use_state(workflow.step)
vmap_stateful_step = vmap(stateful_step)
# torch.func.stack_module_state returns (params, buffers); merge them into one state dict
params, buffers = torch.func.stack_module_state([workflow] * 3)
batch_state = {**params, **buffers}
new_batch_state = vmap_stateful_step(batch_state)
```