evox.core.module
#
模块内容#
类#
该库中所有算法和问题的基础模块。 |
|
函数#
将一个值包装为参数, |
|
将值包装为可变张量。 |
|
将参数和缓冲区从 state_dict 复制到此模块及其后代。 |
|
一个上下文管理器,用于临时设置 |
|
一个上下文管理器,用于暂时设置 |
|
获取当前的 |
|
获取当前的 |
|
检查我们当前是否在进行 JIT 跟踪(在 |
|
将给定的有状态函数(在原地更改 |
|
用于注释的辅助函数,表明被包装的方法应被视为给定 |
|
一个辅助函数,用于标记被包装的方法应被视为给定 |
|
用于 JIT 脚本 ( |
数据#
API#
- evox.core.module._WRAPPING_MODULE_NAME#
'wrapping_module'
- evox.core.module.ParameterT#
TypeVar(...)
- evox.core.module.Parameter(value: evox.core.module.ParameterT, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = False) evox.core.module.ParameterT [源代码]#
将一个值包装为参数,
requires_grad=False
。- 参数:
value -- 参数值。
dtype -- 参数的数据类型。默认为 None。
device -- 参数的设备。默认值为 None。
requires_grad -- 参数是否需要梯度。默认值为 False。
- 返回:
参数。
- evox.core.module.Mutable(value: torch.Tensor, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) torch.Tensor [源代码]#
将值包装为可变张量。
- 参数:
value -- 要包装的值。
dtype -- 张量的 dtype。默认值为 None。
device -- 张量的设备。默认为 None。
- 返回:
被包装的张量。
- evox.core.module.assign_load_state_dict(self: torch.nn.Module, state_dict: Mapping[str, torch.Tensor])[源代码]#
将参数和缓冲区从 state_dict 复制到此模块及其后代。
该方法用来模仿
ModuleBase.load_state_dict
的行为,使得一个普通的nn.Module
能被用于vmap
中。Usage:
import types # ... model = ... # define your model model.load_state_dict = types.MethodType(assign_load_state_dict, model) vmap_forward = vmap(use_state(model.forward)) jit_forward = jit(vmap_forward, trace=True, example_inputs=(vmap_forward.init_state(), ...)) # JIT trace forward pass of the model
- class evox.core.module.ModuleBase(*args, **kwargs)[源代码]#
Bases:
torch.nn.Module
该库中所有算法和问题的基础模块。
Notice
该模块是一种面向对象的模块,可以包含可变值。
支持功能编程模型是通过
self.state_dict(...)
和self.load_state_dict(...)
来实现的。建议将非静态成员的模块初始化写在重写的
setup
(或其他成员方法)中,而不是__init__
中。基本上,预定义的子模块应被视为“非静态成员”,这些模块将被添加到此模块并在成员方法中稍后访问,而其他任何成员应视为“静态成员”。
Usage
静态方法将按原样定义为 JIT,例如,
@jit def func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: pass
如果一个类成员函数具有像
if
这样的 Python 动态控制流,并且需要 JIT,则应使用一个单独的静态方法,格式为jit(..., trace=False)
或torch.jit.script_if_tracing
:
class ExampleModule(ModuleBase): def setup(self, mut: torch.Tensor): self.add_mutable("mut", mut) # or self.mut = Mutable(mut) return self @partial(jit, trace=False) def static_func(x: torch.Tensor, threshold: float) -> torch.Tensor: if x.flatten()[0] > threshold: return torch.sin(x) else: return torch.tan(x) @jit def jit_func(self, p: torch.Tensor) -> torch.Tensor: x = ExampleModule.static_func(p, self.threshold) ...
ModuleBase
通常与jit_class
一起使用,以自动 JIT 所有非魔法成员方法:
@jit_class class ExampleModule(ModuleBase): # This function will be automatically JIT def func1(self, x: torch.Tensor) -> torch.Tensor: pass # Use `torch.jit.ignore` to disable JIT and leave this function as Python callback @torch.jit.ignore def func2(self, x: torch.Tensor) -> torch.Tensor: # you can implement pure Python logic here pass # JIT functions can invoke other JIT functions as well as non-JIT functions def func3(self, x: torch.Tensor) -> torch.Tensor: y = self.func1(x) z = self.func2(x) pass
初始化
初始化 ModuleBase。
- 参数:
*args -- 可变长度参数列表,传递给父类的初始化函数。
**kwargs -- 任意关键字参数,传递给父类初始化器。
Attributes: static_names (list): A list to store static member names.
- setup(*args, **kwargs)[源代码]#
设置模块。模块初始化行应写在重写的
setup
方法中,而不是__init__
中。- 返回:
此模块。
Notice
静态初始化仍然可以在
__init__
中编写,而可变初始化则不可以。因此,对于多个初始化,可以多次调用setup
。
- prepare_control_flow(*target_functions: Callable, keep_vars: bool = True) Tuple[Dict[str, torch.Tensor], Tuple[List[str], List[str]]] [源代码]#
准备通过收集和合并指定目标函数的状态和非本地变量来准备模块的控制流状态。
用于控制流操作(控制分支等)的前向操作
- 参数:
target_functions -- 收集非本地变量的函数。
keep_vars -- See torch.nn.Module.state_dict(..., keep_vars). 默认为 True.
- 返回:
一个包含合并状态字典、状态键列表和非局部变量名称列表的元组。
- 抛出:
AssertionError -- 如果不是所有目标函数都是局部的、全局的或这个类的成员函数
警告
此处收集的非本地变量只能用作只读变量。对这些变量的就地修改可能不会引发任何错误,并在不经意间产生不正确的结果。
Usage
@jit_class def ExampleModule(ModuleBase): # define the normal `__init__` and `test` functions, etc. @trace_impl(test) def trace_test(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: self.q = self.q + 1 local_q = self.q * 2 def false_branch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # nonlocal local_q ## These two lines may silently produce incorrect results # local_q *= 1.5 return x * y * local_q # However, using read-only nonlocals is allowed state, keys = self.prepare_control_flow(self.true_branch, false_branch) if not hasattr(self, "_switch_"): self._switch_ = TracingSwitch(self.true_branch, false_branch) state, ret = self._switch_.switch(state, (x.flatten()[0] > 0).to(dtype=torch.int), x, y) self.after_control_flow(state, *keys) return ret
- after_control_flow(state: Dict[str, torch.Tensor], state_keys: List[str], nonlocal_keys: List[str]) Dict[str, torch.Tensor] [源代码]#
将模块状态恢复到在给定
state
之前的状态,并返回在prepare_control_flow
中收集的非本地变量。此函数与
prepare_control_flow()
一起使用,以使您的控制流操作 (utils.control_flow.*
) 正确处理副作用。如果控制流操作没有副作用,您可以安全地忽略此函数和prepare_control_flow()
。- 参数:
state -- 用于从中恢复模块状态的状态字典。
state_keys -- 表示模块状态的状态字典的键。
nonlocal_keys -- 表示非局部变量的状态字典的键。
- 返回:
在prepare_control_flow中收集的非局部变量字典。
Usage
查看
prepare_control_flow()
。
- load_state_dict(state_dict: Mapping[str, torch.Tensor], copy: bool = False, **kwargs)[源代码]#
将参数和缓冲区从 state_dict 复制到此模块及其子模块中。覆盖
torch.nn.Module.load_state_dict
。- 参数:
state_dict -- 一个包含用于更新该模块的参数和缓冲区的字典。请参阅 torch.nn.Module.load_state_dict。
copy -- 使用原始的torch.nn.Module.load_state_dict来复制state_dict到当前状态(copy=True)或者使用此实现来将此模块的值分配给state_dict中的值(copy=False)。默认为False。
**kwargs -- torch.nn.Module.load_state_dict 的原始参数。 如果 copy=False,会忽略这些参数。
- 返回:
如果 copy=True, 则返回 torch.nn.Module.load_state_dict 的返回值; 否则,不返回。
- add_mutable(name: str, value: Union[torch.Tensor | torch.nn.Module, Sequence[torch.Tensor | torch.nn.Module], Dict[str, torch.Tensor | torch.nn.Module]]) None [源代码]#
定义一个可变值,并将其在
self.[name]
中暴露出来,可以通过self.[name] = [值]
来修改。- 参数:
name -- 可变值的名称。
value -- 可变值可以是一个元组、列表或一个 torch.Tensor 的字典。
- 抛出:
NotImplementedError -- 如果可变值的类型尚不受支持。
AssertionError -- 如果名称无效。
- to(*args, **kwargs) evox.core.module.ModuleBase [源代码]#
- __getitem__(key: Union[int, slice, str]) Union[torch.Tensor, List[torch.Tensor]] [源代码]#
获取存储在此类列表模块中的可变值。
- 参数:
key -- 用于索引可变值的关键。
- 抛出:
IndexError -- 如果键超出范围。
TypeError -- 如果 key 类型错误。
- 返回:
索引可变值。
- evox.core.module._using_state: contextvars.ContextVar[bool]#
'ContextVar(...)'
- evox.core.module._trace_caching_state: contextvars.ContextVar[bool]#
'ContextVar(...)'
- evox.core.module.use_state_context(new_use_state: bool = True)[源代码]#
一个上下文管理器,用于临时设置
using_state
的值。当进入上下文时,
using_state
的值被设置为new_use_state
,并获得一个令牌。当退出上下文时,using_state
的值被重置为之前的值。- 参数:
new_use_state -- 使用
using_state
的新值。默认为 True。
Examples:
>>> with use_state_context(True): ... assert is_using_state() >>> assert not is_using_state()
- evox.core.module.trace_caching_state_context(new_trace_caching_state: bool = True)[源代码]#
一个上下文管理器,用于暂时设置
trace_caching_state
的值。在进入上下文时,
trace_caching_state
的值被设置为new_trace_caching_state
并获取一个令牌。当退出上下文时,trace_caching_state
的值被重置为其先前的值。- 参数:
new_trace_caching_state -- trace_caching_state 的新值。默认为 True。
Examples:
>>> with trace_caching_state_context(True): ... assert is_trace_caching_state() >>> assert not is_trace_caching_state()
- evox.core.module.is_trace_caching_state() bool [源代码]#
获取当前的
trace_caching_state
状态。- 返回:
当前的 trace_caching_state 状态。
- evox.core.module.tracing_or_using_state()[源代码]#
检查我们当前是否在进行 JIT 跟踪(在
torch.jit.trace
内),在use_state_context
中,或者在trace_caching_state
中。- 返回:
如果任一条件为真,则返回 True;否则返回 False。
- evox.core.module._SUBMODULE_PREFIX#
'_submodule'
- class evox.core.module._WrapClassBase(inner: evox.core.module.ModuleBase)[源代码]#
初始化
- evox.core.module._USE_STATE_NAME#
'use_state'
- evox.core.module._STATE_ARG_NAME#
'state'
- class evox.core.module.UseStateFunc[源代码]#
Bases:
typing.Protocol
- is_empty_state: bool#
没有可翻译的文本。
- init_state(clone: bool = True) Dict[str, torch.Tensor] [源代码]#
获取函数被
use_state
包装时闭包的克隆状态。- 参数:
clone -- 是否克隆原始状态。默认为 True。
- 返回:
闭包的克隆状态。
- evox.core.module._EMPTY_NAME#
'empty'
- evox.core.module.use_state(func: Callable[[], Callable] | Callable, is_generator: bool = True) evox.core.module.UseStateFunc [源代码]#
将给定的有状态函数(在原地更改
nn.Module
s)转换为一个纯函数版本,该版本接收一个额外的state
参数(类型为Dict[str, torch.Tensor]
),并额外返回更改后的状态。- 参数:
func -- 要转换的有状态函数或其生成函数。
is_generator -- func 是一个函数还是一个函数生成器(例如,返回有状态函数的 lambda)。默认为 True。
- 返回:
func
的纯函数版本。它包含一个init_state() -> state
属性,该属性返回func
使用的当前状态的副本,并可用作附加状态参数的示例输入。它还包含一个set_state(state)
属性,用于将全局状态设置为给定状态(当然不兼容 JIT)。
Notice
由于 PyTorch 不能对输入为空字典、列表或元组的函数进行 JIT 编译或矢量映射,该函数将给定函数转换为一个没有额外
state
参数(类型为Dict[str, torch.Tensor]
)的函数,并且不额外返回修改后的状态。Usage:
@jit_class class Example(ModuleBase): def __init__(self, threshold=0.5): super().__init__() self.threshold = threshold self.sub_mod = nn.Module() self.sub_mod.buf = nn.Buffer(torch.zeros(())) def h(self, q: torch.Tensor) -> torch.Tensor: if q.flatten()[0] > self.threshold: x = torch.sin(q) else: x = torch.tan(q) x += self.g(x).abs() x *= x.shape[1] self.sub_mod.buf = x.sum() return x @trace_impl(h) def th(self, q: torch.Tensor) -> torch.Tensor: x += self.g(x).abs() x *= x.shape[1] self.sub_mod.buf = x.sum() return x def g(self, p: torch.Tensor) -> torch.Tensor: x = torch.cos(p) return x * p.shape[0] t = Test() fn = use_state(lambda: t.h, is_generator=True) jit_fn = jit(fn, trace=True, lazy=True) results = jit_fn(fn.init_state(), torch.rand(10, 1) # print results, may be print(results) # ({"self.sub_mod.buf": torch.Tensor(5.6)}, torch.Tensor([[0.56], ...])) # IN-PLACE update all relevant variables using the given state, which is the variable `t` here. fn.set_state(results[0])
- evox.core.module._TORCHSCRIPT_MODIFIER#
'_torchscript_modifier'
- evox.core.module._TRACE_WRAP_NAME#
'trace_wrapped'
- evox.core.module.T#
TypeVar(...)
- evox.core.module.trace_impl(target: Callable)[源代码]#
用于注释的辅助函数,表明被包装的方法应被视为给定
target
方法的trace-JIT-time代理。只能在
jit_class
中的成员方法中使用。- 参数:
target -- 未追踪 JIT 时调用的目标方法。
- 返回:
用于注解成员方法的包装函数。
Notice
目标函数和注解函数 MUST拥有同样的输入/输出签名(例如参数个数和类型);否则产生的行为是未定义的。
如果注解的函数是
vmap
,则它不能包含对self
的任何原地操作,因为这样的操作是不明确的,并且不能被编译。
Usage:
请参阅
use_state
。
- evox.core.module._VMAP_WRAP_NAME#
vmap_wrapped
- evox.core.module.vmap_impl(target: Callable)[源代码]#
一个辅助函数,用于标记被包装的方法应被视为给定
target
方法的 vmap-JIT 时间代理。只能在
jit_class
中的成员方法中使用。- 参数:
target -- 未追踪 JIT 时调用的目标方法。
- 返回:
用于注解成员方法的包装函数。
Notice
目标函数和注解函数 MUST拥有同样的输入/输出签名(例如参数个数和类型);否则产生的行为是未定义的。
如果注解的函数是
vmap
,则它不能包含对self
的任何原地操作,因为这样的操作是不明确的,并且不能被编译。
Usage:
请参阅
use_state
。
- evox.core.module.ClassT#
TypeVar(...)
- evox.core.module._BASE_NAME#
基类
- evox.core.module.jit_class(cls: evox.core.module.ClassT, trace: bool = False) evox.core.module.ClassT [源代码]#
用于 JIT 脚本 (
torch.jit.script
) 或跟踪 (torch.jit.trace_module
) 类cls
的所有成员方法的辅助函数。- 参数:
cls -- 原始类,其成员方法将被延迟 JIT。
trace -- 是否追踪模块或将模块脚本化。默认为 False。
返回: 被封装的类。
Notice
在许多情况下,您不需要使用
jit_class
来包装您的自定义算法或问题,工作流将为您完成此任务。设置
trace=True
时,所有成员函数都会被有效地修改,以额外返回self
,因为副作用无法被追踪。如果您想保留副作用,请将trace=False
,并使用use_state
函数来包装成员方法以生成纯函数。类似地,所有模块级的操作如
self.to(...)
只能返回未包装的模块,这可能不是我们想要的。由于它们大多是就地操作,可以使用简单的module.to(...)
来替代module = module.to(...)
。
Usage:
@jit_class class Example(ModuleBase): # magic methods are ignored in JIT def __init__(self, threshold = 0.5): super().__init__() self.threshold = threshold # `torch.jit.export` is automatically added to this member method def h(self, q: torch.Tensor) -> torch.Tensor: if q.flatten()[0] > self.threshold: x = torch.sin(q) else: x = torch.tan(q) return x * x.shape[1] exp = Example(0.75) print(exp.h(torch.rand(10, 2))) # equivalent to exp = torch.jit.trace_module(Example(0.75)) print(exp.h(torch.rand(10, 2)))