evox.core.module

Module Contents

Classes

ModuleBase

The base module for all algorithms, problems, and workflows in the library.

TransformGetSetItemToIndex

Functions

Parameter

Wraps a value as a parameter, with requires_grad=False by default. This is often used to label a value in an algorithm as a hyperparameter that the HPOProblemWrapper can identify.

Mutable

Wraps a value as a mutable tensor. This is often used to label a value in an algorithm as a tensor that may change between iterations.

compile

Works around a torch.compile issue in which getitem and setitem treat scalar indexes as .item() calls and cause graph breaks. Related issue: https://github.com/pytorch/pytorch/issues/124423.

vmap

Works around a torch.vmap issue with getitem and setitem. Related issue: https://github.com/pytorch/pytorch/issues/124423.

use_state

Transform a torch.nn.Module’s method or a torch.nn.Module into a stateful function.

API

evox.core.module.Parameter(value: evox.core.module.ParameterT, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = False) evox.core.module.ParameterT[source]

Wraps a value as a parameter, with requires_grad=False by default. This is often used to label a value in an algorithm as a hyperparameter that the HPOProblemWrapper can identify.

Parameters:
  • value – The parameter value.

  • dtype – The dtype of the parameter. Defaults to None.

  • device – The device of the parameter. Defaults to None.

  • requires_grad – Whether the parameter requires gradient. Defaults to False.

Returns:

The parameter.

evox.core.module.Mutable(value: torch.Tensor, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) torch.Tensor[source]

Wraps a value as a mutable tensor. This is often used to label a value in an algorithm as a tensor that may change between iterations.

Parameters:
  • value – The value to be wrapped.

  • dtype – The dtype of the tensor. Defaults to None.

  • device – The device of the tensor. Defaults to None.

Returns:

The wrapped tensor.

class evox.core.module.ModuleBase(*args, **kwargs)[source]

Bases: torch.nn.Module

The base module for all algorithms, problems, and workflows in the library.

Note

To prevent ambiguity, ModuleBase.eval() is disabled.

Initialization

Initialize the ModuleBase.

Parameters:
  • *args – Variable length argument list, passed to the parent class initializer.

  • **kwargs – Arbitrary keyword arguments, passed to the parent class initializer.

Attributes:
  • static_names (list) – A list that stores static member names.

eval()[source]

class evox.core.module.TransformGetSetItemToIndex[source]

Bases: torch.overrides.TorchFunctionMode

__torch_function__(func, types, args, kwargs=None)[source]

evox.core.module.compile(*args, **kwargs) Callable[source]

Works around a torch.compile issue in which getitem and setitem treat scalar indexes as .item() calls and cause graph breaks. Related issue: https://github.com/pytorch/pytorch/issues/124423.

evox.core.module.vmap(*args, **kwargs) Callable

Works around a torch.vmap issue with getitem and setitem. Related issue: https://github.com/pytorch/pytorch/issues/124423.

evox.core.module.use_state(stateful_func: Union[Callable, torch.nn.Module], tie_weights: bool = True, strict: bool = False) Callable[source]

Transform a torch.nn.Module’s method or a torch.nn.Module into a stateful function.

When a torch.nn.Module is passed, the stateful version of its default forward method is created. The stateful function has the signature fn(params_and_buffers, *args, **kwargs) -> params_and_buffers | Tuple[params_and_buffers, <original_returns>].

Parameters:
  • stateful_func – The torch.nn.Module or a method of a torch.nn.Module to be transformed.

  • tie_weights – If True, then parameters and buffers tied in the original model will be treated as tied in the reparameterized version. Therefore, if True and different values are passed for the tied parameters and buffers, it will error. If False, it will not respect the originally tied parameters and buffers unless the values passed for both weights are the same. Defaults to True.

  • strict – If True, then the parameters and buffers passed in must match the parameters and buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will error. Defaults to False.

Returns:

A new stateful function. It takes the module’s state (a dictionary of parameters and buffers) as the first argument, followed by the original arguments. It returns the updated state. If the original function returned a value, it returns a tuple containing the updated state and the original return value.
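The calling convention described above can be illustrated with a small pure-Python sketch. This is not how use_state is implemented: the Counter class and the to_stateful helper are hypothetical names invented for this illustration, and plain attributes stand in for parameters and buffers. The sketch only shows the shape of the contract, in which state is passed in as a dictionary and the updated state comes back, paired with the original return value when there is one.

```python
from typing import Any, Callable, Dict, Tuple, Union

State = Dict[str, Any]


class Counter:
    """Toy stand-in for a module; not an EvoX or PyTorch class."""

    def __init__(self) -> None:
        self.count = 0

    def step(self, increment: int) -> int:
        self.count += increment
        return self.count


def to_stateful(method: Callable) -> Callable:
    """Hypothetical helper that mimics the use_state calling convention."""
    owner = method.__self__  # the object the bound method belongs to

    def stateful(state: State, *args, **kwargs) -> Union[State, Tuple[State, Any]]:
        saved = dict(vars(owner))      # remember the live attributes
        vars(owner).update(state)      # load the caller-supplied state
        try:
            result = method(*args, **kwargs)
            new_state = {key: vars(owner)[key] for key in state}
        finally:
            vars(owner).clear()
            vars(owner).update(saved)  # leave the original object untouched
        # Return only the state when the wrapped method returns nothing.
        return new_state if result is None else (new_state, result)

    return stateful


counter = Counter()
stateful_step = to_stateful(counter.step)
new_state, returned = stateful_step({"count": 10}, 5)
```

Because the state travels through the arguments and the return value rather than through the object, the same function can be called with different states. In this sketch, the original counter is left unchanged after the call.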

Examples

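.. code-block:: python

    import torch
    from evox import use_state, vmap

    workflow = ...  # define your workflow
    # Turn the bound method into a function that takes the state explicitly.
    stateful_step = use_state(workflow.step)
    # Vectorize the stateful function over a leading batch dimension.
    vmap_stateful_step = vmap(stateful_step)
    # Stack three copies of the workflow's state into one batched state.
    batch_state = torch.func.stack_module_states([workflow] * 3)
    new_batch_state = vmap_stateful_step(batch_state)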