# Public API of this module: the names exported by a star-import.
__all__ = [
"Parameter",
"Mutable",
"ModuleBase",
"TransformGetSetItemToIndex",
"compile",
"vmap",
"use_state",
]
from functools import wraps
from typing import Any, Callable, Dict, Optional, Sequence, TypeVar, Union
import torch
import torch.nn as nn
from torch.overrides import TorchFunctionMode
# Value types accepted by `Parameter`: a tensor or a plain Python number.
ParameterT = TypeVar("ParameterT", torch.Tensor, int, float, complex)


def Parameter(
    value: ParameterT,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = False,
) -> ParameterT:
    """Wrap a value in a ``torch.nn.Parameter`` (no gradient by default).

    This is often used to label a value in an algorithm as a hyperparameter that can be identified by the `HPOProblemWrapper`.

    :param value: The parameter value; either a tensor or a Python number.
    :param dtype: The dtype of the parameter. Defaults to None.
    :param device: The device of the parameter. Defaults to None.
    :param requires_grad: Whether the parameter requires gradient. Defaults to False.

    :return: The parameter.
    """
    if isinstance(value, torch.Tensor):
        # Tensors are converted with `.to`.
        data = value.to(dtype=dtype, device=device)
    else:
        # Plain Python numbers are turned into tensors first.
        data = torch.as_tensor(value, dtype=dtype, device=device)
    return nn.Parameter(data, requires_grad=requires_grad)
# [docs]
def Mutable(value: torch.Tensor, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> torch.Tensor:
"""Wraps a value as a mutable tensor.
This is often used to label a value in an algorithm as a mutable tensor that may changes during iteration(s).
:param value: The value to be wrapped.
:param dtype: The dtype of the tensor. Defaults to None.
:param device: The device of the tensor. Defaults to None.
:return: The wrapped tensor.
"""
return nn.Buffer(value.to(dtype=dtype, device=device))
# [docs]
class ModuleBase(nn.Module):
    """
    The base module for all algorithms, problems, and workflows in the library.

    ```{note}
    To prevent ambiguity, `ModuleBase.eval()` is disabled.
    ```
    """

    def __init__(self, *args, **kwargs):
        """Initialize the ModuleBase.

        :param *args: Variable length argument list, passed to the parent class initializer.
        :param **kwargs: Arbitrary keyword arguments, passed to the parent class initializer.
        """
        # NOTE(review): an earlier docstring mentioned a `__static_names__`
        # attribute, but this initializer does not set one -- confirm whether
        # it is defined elsewhere.
        super().__init__(*args, **kwargs)
        # Switch training mode off via `train(False)`, because `eval()` is
        # disabled on this class.
        self.train(False)

    def eval(self):
        """Disabled on purpose; calling it always fails.

        :raises AssertionError: Always.
        """
        # Raise explicitly instead of `assert False`, which is stripped when
        # Python runs with `-O` and would let the call silently succeed.
        raise AssertionError("`ModuleBase.eval()` shall never be invoked to prevent ambiguity.")
def _transform_scalar_index(ori_index: Sequence[Any | torch.Tensor] | Any | torch.Tensor):
if isinstance(ori_index, Sequence):
index = tuple(ori_index)
else:
index = (ori_index,)
any_scalar_tensor = False
new_index = []
for idx in index:
if isinstance(idx, torch.Tensor) and idx.ndim == 0:
new_index.append(idx[None])
any_scalar_tensor = True
else:
new_index.append(idx)
if not isinstance(ori_index, Sequence):
new_index = new_index[0]
return new_index, any_scalar_tensor
# We still need a fix for the vmap
# related issue: https://github.com/pytorch/pytorch/issues/124423
@wraps(torch.compile)
def compile(*args, **kwargs) -> Callable:
    """Fix the `torch.compile`'s issue with __getitem__ and __setitem__
    that recognizes scalar indexes as `.item()` and causes graph breaks.

    Related issue: https://github.com/pytorch/pytorch/issues/124423.

    All arguments are forwarded to `torch.compile`. Both the compilation and
    every later call of the compiled object run inside
    `TransformGetSetItemToIndex`.

    :return: A wrapper around the compiled callable, exposed as its
        `__wrapped__` attribute. If no model is given (decorator-factory form,
        e.g. `@compile(fullgraph=True)`), a decorator that applies this
        function to the model it receives.
    """
    # Decorator-factory form: with no model, `torch.compile` returns a
    # decorator. Wrapping that decorator would only cover the decoration step,
    # and calls of the compiled function would run outside the mode. Defer
    # until the model is supplied instead.
    if not args and kwargs.get("model") is None:
        options = {key: val for key, val in kwargs.items() if key != "model"}

        def decorator(model: Callable) -> Callable:
            return compile(model, **options)

        return decorator

    with TransformGetSetItemToIndex():
        compiled = torch.compile(*args, **kwargs)

    def wrapper(*args, **kwargs):
        with TransformGetSetItemToIndex():
            return compiled(*args, **kwargs)

    # Expose the underlying compiled object to callers.
    wrapper.__wrapped__ = compiled
    return wrapper
@wraps(torch.vmap)
def vmap(*args, **kwargs) -> Callable:
    """Fix the `torch.vmap`'s issue with __getitem__ and __setitem__.

    Related issue: https://github.com/pytorch/pytorch/issues/124423.

    All arguments are forwarded to `torch.vmap`.

    :return: A wrapper that calls the vmapped function inside
        `TransformGetSetItemToIndex`.
    """
    batched_fn = torch.vmap(*args, **kwargs)

    def wrapper(*call_args, **call_kwargs):
        # Only the call is placed under the mode; `torch.vmap(...)` above is not.
        with TransformGetSetItemToIndex():
            return batched_fn(*call_args, **call_kwargs)

    return wrapper
class _ReplaceForwardModule(nn.Module):
    """Module whose `forward` delegates to a replacement function.

    The replacement is called with the wrapped module as its first argument,
    followed by the arguments given to `forward`.
    """

    def __init__(self, module: nn.Module, new_forward: Callable):
        """Store the wrapped module and the replacement forward function.

        :param module: The module handed to `new_forward` as its first argument.
        :param new_forward: Callable invoked as `new_forward(module, *args, **kwargs)`.
        """
        super().__init__()
        # Assigned as an attribute named `_inner_module`; `use_state` relies on
        # this name when it prefixes state keys.
        self._inner_module = module
        self.new_forward = new_forward

    def forward(self, *args, **kwargs):
        """Call the replacement function on the wrapped module."""
        replacement = self.new_forward
        return replacement(self._inner_module, *args, **kwargs)
# [docs]
def use_state(stateful_func: Union[Callable, nn.Module], tie_weights: bool = True, strict: bool = False) -> Callable:
    """Turn a `torch.nn.Module`, or one of its bound methods, into a stateful function.

    When a `torch.nn.Module` is given, the stateful version of its `forward` method is created.
    The resulting function has the signature
    `fn(params_and_buffers, *args, **kwargs) -> params_and_buffers | Tuple[params_and_buffers, <original_returns>]`.

    :param stateful_func: The ``torch.nn.Module`` or a method of a ``torch.nn.Module`` to be transformed.
    :param tie_weights: If True, then parameters and buffers tied in the original model will be
        treated as tied in the reparameterized version. Therefore, if True and
        different values are passed for the tied parameters and buffers, it will error.
        If False, it will not respect the originally tied parameters and buffers
        unless the values passed for both weights are the same. Defaults to True.
    :param strict: If True, then the parameters and buffers passed in must match the parameters
        and buffers in the original module. Therefore, if True and there are any
        missing or unexpected keys, it will error. Defaults to False.

    :return: A new stateful function. It takes the module's state (a dictionary of parameters
        and buffers) as the first argument, followed by the original arguments. It returns
        the updated state. If the original function returned a value, it returns a tuple
        containing the updated state and the original return value.

    ## Examples

    .. code-block:: python

        from evox import use_state, vmap
        workflow = ... # define your workflow
        stateful_step = use_state(workflow.step)
        vmap_stateful_step = vmap(stateful_step)
        batch_state = torch.func.stack_module_states([workflow] * 3)
        new_batch_state = vmap_stateful_step(batch_state)
    """
    if isinstance(stateful_func, torch.nn.Module):
        module: torch.nn.Module = stateful_func
        # Unbound `forward`, so it can be called with the module passed explicitly.
        new_forward = module.forward.__func__
    else:
        # Bound method: recover the owning module and the underlying function.
        module = stateful_func.__self__
        assert isinstance(
            module, torch.nn.Module
        ), "`stateful_func` must be a `torch.nn.Module` or a method of a `torch.nn.Module`"
        new_forward = stateful_func.__func__

    module = _ReplaceForwardModule(module, new_forward)
    # The original module is stored on the wrapper as `_inner_module`, so state
    # keys are prefixed accordingly before the functional call.
    prefix = "_inner_module."

    def wrapper(params_and_buffers: Dict[str, torch.Tensor], *args, **kwargs):
        prefixed_state = {(prefix + name): tensor for name, tensor in params_and_buffers.items()}
        output = torch.func.functional_call(module, prefixed_state, args, kwargs, tie_weights=tie_weights, strict=strict)
        # The prefixed dict is read back after the call; presumably
        # `functional_call` writes updated tensors into it -- confirm against
        # the PyTorch version in use.
        new_state = {name[len(prefix) :]: tensor for name, tensor in prefixed_state.items()}
        if output is None:
            return new_state
        return new_state, output

    return wrapper