evox.core.module

模块内容

ModuleBase

用于库中所有算法、问题和工作流的基础模块。

TransformGetSetItemToIndex

函数

Parameter

将值包装为参数,并设置requires_grad=False。这通常用于在算法中将某个值标记为超参数,这些超参数可以通过HPOProblemWrapper识别。

Mutable

将一个值包装为可变张量。这通常用于在算法中将某个值标记为可变张量,该值可能会在迭代过程中发生改变。

compile

修复 torch.compile 中的 getitemsetitem 问题,该问题将标量索引识别为 .item() 并导致图断裂。相关问题链接:https://github.com/pytorch/pytorch/issues/124423。

vmap

修复 torch.vmap 中与 __getitem____setitem__ 相关的问题。相关问题:https://github.com/pytorch/pytorch/issues/124423。

use_state

Transform a torch.nn.Module's method or a torch.nn.Module into a stateful function.

API

evox.core.module.Parameter(value: evox.core.module.ParameterT, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = False) evox.core.module.ParameterT[源代码]

将值包装为参数,并设置requires_grad=False。这通常用于在算法中将某个值标记为超参数,这些超参数可以通过HPOProblemWrapper识别。

参数:
  • value -- 参数值。

  • dtype -- 参数的数据类型。默认为 None。

  • device -- 参数的设备。默认值为 None。

  • requires_grad -- 参数是否需要梯度。默认值为 False。

返回:

参数。

evox.core.module.Mutable(value: torch.Tensor, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) torch.Tensor[源代码]

将一个值包装为可变张量。这通常用于在算法中将某个值标记为可变张量,该值可能会在迭代过程中发生改变。

参数:
  • value -- 要包装的值。

  • dtype -- 张量的 dtype。默认值为 None。

  • device -- 张量的设备。默认为 None。

返回:

被包装的张量。

class evox.core.module.ModuleBase(*args, **kwargs)[源代码]

Bases: torch.nn.Module

用于库中所有算法、问题和工作流的基础模块。

备注

为防止产生歧义,ModuleBase.eval() 已被禁用。

初始化

初始化 ModuleBase。

参数:
  • *args -- 可变长度参数列表,传递给父类的初始化函数。

  • **kwargs -- 任意关键字参数,传递给父类初始化器。

Attributes: static_names (list): A list to store static member names.

eval()[源代码]
class evox.core.module.TransformGetSetItemToIndex[源代码]

Bases: torch.overrides.TorchFunctionMode

__torch_function__(func, types, args, kwargs=None)[源代码]
evox.core.module.compile(*args, **kwargs) Callable[源代码]

修复 torch.compile 中的 getitemsetitem 问题,该问题将标量索引识别为 .item() 并导致图断裂。相关问题链接:https://github.com/pytorch/pytorch/issues/124423。

evox.core.module.vmap(*args, **kwargs) Callable

修复 torch.vmap 中与 __getitem____setitem__ 相关的问题。相关问题:https://github.com/pytorch/pytorch/issues/124423。

evox.core.module.use_state(stateful_func: Union[Callable, torch.nn.Module], tie_weights: bool = True, strict: bool = False) Callable[源代码]

Transform a torch.nn.Module's method or a torch.nn.Module into a stateful function.

When using torch.nn.Module, the stateful version of the default forward method will be created. The stateful function will have a signature of fn(params_and_buffers, *args, **kwargs) -> params_and_buffers | Tuple[params_and_buffers, <original_returns>]].

参数:
  • stateful_func -- The torch.nn.Module or a method of a torch.nn.Module to be transformed.

  • tie_weights -- If True, then parameters and buffers tied in the original model will be treated as tied in the reparameterized version. Therefore, if True and different values are passed for the tied parameters and buffers, it will error. If False, it will not respect the originally tied parameters and buffers unless the values passed for both weights are the same. Defaults to True.

  • strict -- If True, then the parameters and buffers passed in must match the parameters and buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will error. Defaults to False.

返回:

A new stateful function. It takes the module's state (a dictionary of parameters and buffers) as the first argument, followed by the original arguments. It returns the updated state. If the original function returned a value, it returns a tuple containing the updated state and the original return value.

Examples

.. code-block:: python

from evox import use_state, vmap
workflow = ... # define your workflow
stateful_step = use_state(workflow.step)
vmap_stateful_step = vmap(stateful_step)
batch_state = torch.func.stack_module_states([workflow] * 3)
new_batch_state = vmap_stateful_step(batch_state)