evox.core.module
¶
模块内容¶
类¶
用于库中所有算法、问题和工作流的基础模块。 |
|
函数¶
将值包装为参数,并设置 |
|
将一个值包装为可变张量。这通常用于在算法中将某个值标记为可变张量,该值可能会在迭代过程中发生改变。 |
|
修复 |
|
修复 |
|
将一个 |
API¶
- evox.core.module.Parameter(value: evox.core.module.ParameterT, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = False) evox.core.module.ParameterT [源代码]¶
将值包装为参数,并设置
requires_grad=False
。这通常用于在算法中将某个值标记为超参数,这些超参数可以通过HPOProblemWrapper
识别。- 参数:
value -- 参数值。
dtype -- 参数的数据类型。默认为 None。
device -- 参数的设备。默认值为 None。
requires_grad -- 参数是否需要梯度。默认值为 False。
- 返回:
参数。
- evox.core.module.Mutable(value: torch.Tensor, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) torch.Tensor [源代码]¶
将一个值包装为可变张量。这通常用于在算法中将某个值标记为可变张量,该值可能会在迭代过程中发生改变。
- 参数:
value -- 要包装的值。
dtype -- 张量的 dtype。默认值为 None。
device -- 张量的设备。默认为 None。
- 返回:
被包装的张量。
- class evox.core.module.ModuleBase(*args, **kwargs)[源代码]¶
Bases:
torch.nn.Module
用于库中所有算法、问题和工作流的基础模块。
备注
为防止产生歧义,
ModuleBase.eval()
已被禁用。初始化
初始化 ModuleBase。
- 参数:
*args -- 可变长度参数列表,传递给父类的初始化函数。
**kwargs -- 任意关键字参数,传递给父类初始化器。
Attributes: static_names (list): A list to store static member names.
- evox.core.module.compile(*args, **kwargs) Callable [源代码]¶
修复
torch.compile
中的 getitem 和 setitem 问题,该问题将标量索引识别为.item()
并导致图断裂。相关问题链接:https://github.com/pytorch/pytorch/issues/124423。
- evox.core.module.vmap(*args, **kwargs) Callable ¶
修复
torch.vmap
中与__getitem__
和__setitem__
相关的问题。相关问题:https://github.com/pytorch/pytorch/issues/124423。
- evox.core.module.use_state(stateful_func: Union[Callable, torch.nn.Module]) Callable [源代码]¶
将一个
torch.nn.Module
的方法或一个torch.nn.Module
转换为状态函数。当使用torch.nn.Module
时,将会创建默认forward
方法的状态版本。状态函数的签名为fn(params_and_buffers, *args, **kwargs) -> params_and_buffers | Tuple[params_and_buffers, <original_returns>]]
。Examples
from evox import use_state, vmap workflow = ... # define your workflow stateful_step = use_state(workflow.step) vmap_stateful_step = vmap(stateful_step) batch_state = torch.func.stack_module_states([workflow] * 3) new_batch_state = vmap_stateful_step(batch_state)