evox.core.module

模块内容

ModuleBase

用于库中所有算法、问题和工作流的基础模块。

TransformGetSetItemToIndex

函数

Parameter

将值包装为参数,并设置requires_grad=False。这通常用于在算法中将某个值标记为超参数,这些超参数可以通过HPOProblemWrapper识别。

Mutable

将一个值包装为可变张量。这通常用于在算法中将某个值标记为可变张量,该值可能会在迭代过程中发生改变。

compile

修复 torch.compile 中的 getitemsetitem 问题,该问题将标量索引识别为 .item() 并导致图断裂。相关问题链接:https://github.com/pytorch/pytorch/issues/124423。

vmap

修复 torch.vmap 中与 __getitem____setitem__ 相关的问题。相关问题:https://github.com/pytorch/pytorch/issues/124423。

use_state

将一个 torch.nn.Module 的方法或一个 torch.nn.Module 转换为状态函数。当使用 torch.nn.Module 时,将会创建默认 forward 方法的状态版本。状态函数的签名为 fn(params_and_buffers, *args, **kwargs) -> params_and_buffers | Tuple[params_and_buffers, <original_returns>]]

API

evox.core.module.Parameter(value: evox.core.module.ParameterT, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = False) evox.core.module.ParameterT[源代码]

将值包装为参数,并设置requires_grad=False。这通常用于在算法中将某个值标记为超参数,这些超参数可以通过HPOProblemWrapper识别。

参数:
  • value -- 参数值。

  • dtype -- 参数的数据类型。默认为 None。

  • device -- 参数的设备。默认值为 None。

  • requires_grad -- 参数是否需要梯度。默认值为 False。

返回:

参数。

evox.core.module.Mutable(value: torch.Tensor, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) torch.Tensor[源代码]

将一个值包装为可变张量。这通常用于在算法中将某个值标记为可变张量,该值可能会在迭代过程中发生改变。

参数:
  • value -- 要包装的值。

  • dtype -- 张量的 dtype。默认值为 None。

  • device -- 张量的设备。默认为 None。

返回:

被包装的张量。

class evox.core.module.ModuleBase(*args, **kwargs)[源代码]

Bases: torch.nn.Module

用于库中所有算法、问题和工作流的基础模块。

备注

为防止产生歧义,ModuleBase.eval() 已被禁用。

初始化

初始化 ModuleBase。

参数:
  • *args -- 可变长度参数列表,传递给父类的初始化函数。

  • **kwargs -- 任意关键字参数,传递给父类初始化器。

Attributes: static_names (list): A list to store static member names.

eval()[源代码]
class evox.core.module.TransformGetSetItemToIndex[源代码]

Bases: torch.overrides.TorchFunctionMode

__torch_function__(func, types, args, kwargs=None)[源代码]
evox.core.module.compile(*args, **kwargs) Callable[源代码]

修复 torch.compile 中的 getitemsetitem 问题,该问题将标量索引识别为 .item() 并导致图断裂。相关问题链接:https://github.com/pytorch/pytorch/issues/124423。

evox.core.module.vmap(*args, **kwargs) Callable

修复 torch.vmap 中与 __getitem____setitem__ 相关的问题。相关问题:https://github.com/pytorch/pytorch/issues/124423。

evox.core.module.use_state(stateful_func: Union[Callable, torch.nn.Module]) Callable[源代码]

将一个 torch.nn.Module 的方法或一个 torch.nn.Module 转换为状态函数。当使用 torch.nn.Module 时,将会创建默认 forward 方法的状态版本。状态函数的签名为 fn(params_and_buffers, *args, **kwargs) -> params_and_buffers | Tuple[params_and_buffers, <original_returns>]]

Examples

from evox import use_state, vmap
workflow = ... # define your workflow
stateful_step = use_state(workflow.step)
vmap_stateful_step = vmap(stateful_step)
batch_state = torch.func.stack_module_states([workflow] * 3)
new_batch_state = vmap_stateful_step(batch_state)