evox.core.module#

模块内容#

#

ModuleBase

该库中所有算法和问题的基础模块。

_WrapClassBase

UseStateFunc

函数#

_if_none

_is_magic

Parameter

将一个值包装为参数,requires_grad=False

Mutable

将值包装为可变张量。

assign_load_state_dict

将参数和缓冲区从 state_dict 复制到此模块及其后代。

use_state_context

一个上下文管理器,用于临时设置 using_state 的值。

trace_caching_state_context

一个上下文管理器,用于暂时设置 trace_caching_state 的值。

is_using_state

获取当前的 using_state 状态。

is_trace_caching_state

获取当前的 trace_caching_state 状态。

tracing_or_using_state

检查我们当前是否在进行 JIT 跟踪(在 torch.jit.trace 内),在 use_state_context 中,或者在 trace_caching_state 中。

_get_vars

use_state

将给定的有状态函数(在原地更改 nn.Modules)转换为一个纯函数版本,该版本接收一个额外的 state 参数(类型为 Dict[str, torch.Tensor]),并额外返回更改后的状态。

trace_impl

用于注释的辅助函数,表明被包装的方法应被视为给定target方法的trace-JIT-time代理。

vmap_impl

一个辅助函数,用于标记被包装的方法应被视为给定 target 方法的 vmap-JIT 时间代理。

jit_class

用于 JIT 脚本 (torch.jit.script) 或跟踪 (torch.jit.trace_module) 类 cls 的所有成员方法的辅助函数。

数据#

API#

evox.core.module._WRAPPING_MODULE_NAME#

'wrapping_module'

evox.core.module._if_none(a, b)[源代码]#
evox.core.module._is_magic(name: str)[源代码]#
evox.core.module.ParameterT#

TypeVar(...)

evox.core.module.Parameter(value: evox.core.module.ParameterT, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = False) evox.core.module.ParameterT[源代码]#

将一个值包装为参数,requires_grad=False

参数:
  • value -- 参数值。

  • dtype -- 参数的数据类型。默认为 None。

  • device -- 参数的设备。默认值为 None。

  • requires_grad -- 参数是否需要梯度。默认值为 False。

返回:

参数。

evox.core.module.Mutable(value: torch.Tensor, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) torch.Tensor[源代码]#

将值包装为可变张量。

参数:
  • value -- 要包装的值。

  • dtype -- 张量的 dtype。默认值为 None。

  • device -- 张量的设备。默认为 None。

返回:

被包装的张量。

evox.core.module.assign_load_state_dict(self: torch.nn.Module, state_dict: Mapping[str, torch.Tensor])[源代码]#

将参数和缓冲区从 state_dict 复制到此模块及其后代。

该方法用来模仿 ModuleBase.load_state_dict 的行为,使得一个普通的 nn.Module 能被用于 vmap 中。

Usage:

import types
# ...
model = ... # define your model
model.load_state_dict = types.MethodType(assign_load_state_dict, model)
vmap_forward = vmap(use_state(model.forward))
jit_forward = jit(vmap_forward, trace=True, example_inputs=(vmap_forward.init_state(), ...)) # JIT trace forward pass of the model
class evox.core.module.ModuleBase(*args, **kwargs)[源代码]#

Bases: torch.nn.Module

该库中所有算法和问题的基础模块。

Notice

  1. 该模块是一种面向对象的模块,可以包含可变值。

  2. 支持功能编程模型是通过 self.state_dict(...)self.load_state_dict(...) 来实现的。

  3. 建议将非静态成员的模块初始化写在重写的 setup(或其他成员方法)中,而不是 __init__ 中。

  4. 基本上,预定义的子模块应被视为“非静态成员”,这些模块将被添加到此模块并在成员方法中稍后访问,而其他任何成员应视为“静态成员”。

Usage

  1. 静态方法将按原样定义为 JIT,例如,

@jit
def func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    pass
  1. 如果一个类成员函数具有像 if 这样的 Python 动态控制流,并且需要 JIT,则应使用一个单独的静态方法,格式为 jit(..., trace=False)torch.jit.script_if_tracing

class ExampleModule(ModuleBase):
    def setup(self, mut: torch.Tensor):
        self.add_mutable("mut", mut)
        # or
        self.mut = Mutable(mut)
        return self

    @partial(jit, trace=False)
    def static_func(x: torch.Tensor, threshold: float) -> torch.Tensor:
        if x.flatten()[0] > threshold:
            return torch.sin(x)
        else:
            return torch.tan(x)
    @jit
    def jit_func(self, p: torch.Tensor) -> torch.Tensor:
        x = ExampleModule.static_func(p, self.threshold)
        ...
  1. ModuleBase 通常与 jit_class 一起使用,以自动 JIT 所有非魔法成员方法:

@jit_class
class ExampleModule(ModuleBase):
    # This function will be automatically JIT
    def func1(self, x: torch.Tensor) -> torch.Tensor:
        pass

    # Use `torch.jit.ignore` to disable JIT and leave this function as Python callback
    @torch.jit.ignore
    def func2(self, x: torch.Tensor) -> torch.Tensor:
        # you can implement pure Python logic here
        pass

    # JIT functions can invoke other JIT functions as well as non-JIT functions
    def func3(self, x: torch.Tensor) -> torch.Tensor:
        y = self.func1(x)
        z = self.func2(x)
        pass

初始化

初始化 ModuleBase。

参数:
  • *args -- 可变长度参数列表,传递给父类的初始化函数。

  • **kwargs -- 任意关键字参数,传递给父类初始化器。

Attributes: static_names (list): A list to store static member names.

eval()[源代码]#
setup(*args, **kwargs)[源代码]#

设置模块。模块初始化行应写在重写的 setup 方法中,而不是 __init__ 中。

返回:

此模块。

Notice

静态初始化仍然可以在 __init__ 中编写,而可变初始化则不可以。因此,对于多个初始化,可以多次调用 setup

prepare_control_flow(*target_functions: Callable, keep_vars: bool = True) Tuple[Dict[str, torch.Tensor], Tuple[List[str], List[str]]][源代码]#

准备通过收集和合并指定目标函数的状态和非本地变量来准备模块的控制流状态。

用于控制流操作(控制分支等)的前向操作

参数:
  • target_functions -- 收集非本地变量的函数。

  • keep_vars -- See torch.nn.Module.state_dict(..., keep_vars). 默认为 True.

返回:

一个包含合并状态字典、状态键列表和非局部变量名称列表的元组。

抛出:

AssertionError -- 如果不是所有目标函数都是局部的、全局的或这个类的成员函数

警告

此处收集的非本地变量只能用作只读变量。对这些变量的就地修改可能不会引发任何错误,并在不经意间产生不正确的结果。

Usage

@jit_class
def ExampleModule(ModuleBase):
    # define the normal `__init__` and `test` functions, etc.

    @trace_impl(test)
    def trace_test(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        self.q = self.q + 1
        local_q = self.q * 2

        def false_branch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            # nonlocal local_q ## These two lines may silently produce incorrect results
            # local_q *= 1.5
            return x * y * local_q # However, using read-only nonlocals is allowed

        state, keys = self.prepare_control_flow(self.true_branch, false_branch)
        if not hasattr(self, "_switch_"):
            self._switch_ = TracingSwitch(self.true_branch, false_branch)
        state, ret = self._switch_.switch(state, (x.flatten()[0] > 0).to(dtype=torch.int), x, y)
        self.after_control_flow(state, *keys)
        return ret
after_control_flow(state: Dict[str, torch.Tensor], state_keys: List[str], nonlocal_keys: List[str]) Dict[str, torch.Tensor][源代码]#

将模块状态恢复到在给定 state 之前的状态,并返回在 prepare_control_flow 中收集的非本地变量。

此函数与 prepare_control_flow() 一起使用,以使您的控制流操作 (utils.control_flow.*) 正确处理副作用。如果控制流操作没有副作用,您可以安全地忽略此函数和 prepare_control_flow()

参数:
  • state -- 用于从中恢复模块状态的状态字典。

  • state_keys -- 表示模块状态的状态字典的键。

  • nonlocal_keys -- 表示非局部变量的状态字典的键。

返回:

在prepare_control_flow中收集的非局部变量字典。

Usage

查看 prepare_control_flow()

load_state_dict(state_dict: Mapping[str, torch.Tensor], copy: bool = False, **kwargs)[源代码]#

将参数和缓冲区从 state_dict 复制到此模块及其子模块中。覆盖 torch.nn.Module.load_state_dict

参数:
  • state_dict -- 一个包含用于更新该模块的参数和缓冲区的字典。请参阅 torch.nn.Module.load_state_dict。

  • copy -- 使用原始的torch.nn.Module.load_state_dict来复制state_dict到当前状态(copy=True)或者使用此实现来将此模块的值分配给state_dict中的值(copy=False)。默认为False。

  • **kwargs -- torch.nn.Module.load_state_dict 的原始参数。 如果 copy=False,会忽略这些参数。

返回:

如果 copy=True, 则返回 torch.nn.Module.load_state_dict 的返回值; 否则,不返回。

add_mutable(name: str, value: Union[torch.Tensor | torch.nn.Module, Sequence[torch.Tensor | torch.nn.Module], Dict[str, torch.Tensor | torch.nn.Module]]) None[源代码]#

定义一个可变值,并将其在 self.[name] 中暴露出来,可以通过 self.[name] = [值] 来修改。

参数:
  • name -- 可变值的名称。

  • value -- 可变值可以是一个元组、列表或一个 torch.Tensor 的字典。

抛出:
  • NotImplementedError -- 如果可变值的类型尚不受支持。

  • AssertionError -- 如果名称无效。

to(*args, **kwargs) evox.core.module.ModuleBase[源代码]#
__getattribute__(name)[源代码]#
__getattr_inner__(name)[源代码]#
__delattr__(name)[源代码]#
__delattr_inner__(name)[源代码]#
__setattr__(name, value)[源代码]#
__setattr_inner__(name, value)[源代码]#
__getitem__(key: Union[int, slice, str]) Union[torch.Tensor, List[torch.Tensor]][源代码]#

获取存储在此类列表模块中的可变值。

参数:

key -- 用于索引可变值的关键。

抛出:
  • IndexError -- 如果键超出范围。

  • TypeError -- 如果 key 类型错误。

返回:

索引可变值。

__setitem__(value: Union[torch.Tensor, List[torch.Tensor]], key: Union[slice, int]) None[源代码]#

设置存储在此类列表模块中的可变值。

参数:
  • value -- 新的可变值。

  • key -- 用于索引可变值的关键。

iter() Tuple[torch.Tensor][源代码]#
__sync_with__(jit_module: torch.jit.ScriptModule | None)[源代码]#
evox.core.module._using_state: contextvars.ContextVar[bool]#

'ContextVar(...)'

evox.core.module._trace_caching_state: contextvars.ContextVar[bool]#

'ContextVar(...)'

evox.core.module.use_state_context(new_use_state: bool = True)[源代码]#

一个上下文管理器,用于临时设置 using_state 的值。

当进入上下文时,using_state 的值被设置为 new_use_state,并获得一个令牌。当退出上下文时,using_state 的值被重置为之前的值。

参数:

new_use_state -- 使用 using_state 的新值。默认为 True。

Examples:

>>> with use_state_context(True):
...     assert is_using_state()
>>> assert not is_using_state()
evox.core.module.trace_caching_state_context(new_trace_caching_state: bool = True)[源代码]#

一个上下文管理器,用于暂时设置 trace_caching_state 的值。

在进入上下文时,trace_caching_state 的值被设置为 new_trace_caching_state 并获取一个令牌。当退出上下文时,trace_caching_state 的值被重置为其先前的值。

参数:

new_trace_caching_state -- trace_caching_state 的新值。默认为 True。

Examples:

>>> with trace_caching_state_context(True):
...     assert is_trace_caching_state()
>>> assert not is_trace_caching_state()
evox.core.module.is_using_state() bool[源代码]#

获取当前的 using_state 状态。

返回:

使用状态的当前状态。

evox.core.module.is_trace_caching_state() bool[源代码]#

获取当前的 trace_caching_state 状态。

返回:

当前的 trace_caching_state 状态。

evox.core.module.tracing_or_using_state()[源代码]#

检查我们当前是否在进行 JIT 跟踪(在 torch.jit.trace 内),在 use_state_context 中,或者在 trace_caching_state 中。

返回:

如果任一条件为真,则返回 True;否则返回 False。

evox.core.module._SUBMODULE_PREFIX#

'_submodule'

class evox.core.module._WrapClassBase(inner: evox.core.module.ModuleBase)[源代码]#

初始化

__str__() str[源代码]#
__repr__() str[源代码]#
__hash__() int[源代码]#
__format__(format_spec: str) str[源代码]#
__getitem__(key)[源代码]#
__setitem__(value, key)[源代码]#
__setattr__(name, value)[源代码]#
__delattr__(name)[源代码]#
__sync__()[源代码]#
evox.core.module._USE_STATE_NAME#

'use_state'

evox.core.module._STATE_ARG_NAME#

'state'

class evox.core.module.UseStateFunc[源代码]#

Bases: typing.Protocol

is_empty_state: bool#

没有可翻译的文本。

init_state(clone: bool = True) Dict[str, torch.Tensor][源代码]#

获取函数被 use_state 包装时闭包的克隆状态。

参数:

clone -- 是否克隆原始状态。默认为 True。

返回:

闭包的克隆状态。

set_state(state: Optional[Dict[str, torch.Tensor]] = None) None[源代码]#

将函数的闭包设置为给定状态。

参数:

state -- 要设置的新状态。如果 state=None,那么新状态将是当函数被 use_state 包裹时的原始状态。默认值为 None。

__call__(state: Dict[str, torch.Tensor], *args, **kwargs) Dict[str, torch.Tensor] | Tuple[Dict[str, torch.Tensor], Any][源代码]#
evox.core.module._EMPTY_NAME#

'empty'

evox.core.module._get_vars(func: Callable, *exclude, is_generator: bool = True)[源代码]#
evox.core.module.use_state(func: Callable[[], Callable] | Callable, is_generator: bool = True) evox.core.module.UseStateFunc[源代码]#

将给定的有状态函数(在原地更改 nn.Modules)转换为一个纯函数版本,该版本接收一个额外的 state 参数(类型为 Dict[str, torch.Tensor]),并额外返回更改后的状态。

参数:
  • func -- 要转换的有状态函数或其生成函数。

  • is_generator -- func 是一个函数还是一个函数生成器(例如,返回有状态函数的 lambda)。默认为 True。

返回:

func 的纯函数版本。它包含一个 init_state() -> state 属性,该属性返回 func 使用的当前状态的副本,并可用作附加状态参数的示例输入。它还包含一个 set_state(state) 属性,用于将全局状态设置为给定状态(当然不兼容 JIT)。

Notice

由于 PyTorch 不能对输入为空字典、列表或元组的函数进行 JIT 编译或矢量映射,该函数将给定函数转换为一个没有额外 state 参数(类型为 Dict[str, torch.Tensor])的函数,并且不额外返回修改后的状态。

Usage:

@jit_class
class Example(ModuleBase):

    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold
        self.sub_mod = nn.Module()
        self.sub_mod.buf = nn.Buffer(torch.zeros(()))

    def h(self, q: torch.Tensor) -> torch.Tensor:
        if q.flatten()[0] > self.threshold:
            x = torch.sin(q)
        else:
            x = torch.tan(q)
        x += self.g(x).abs()
        x *= x.shape[1]
        self.sub_mod.buf = x.sum()
        return x

    @trace_impl(h)
    def th(self, q: torch.Tensor) -> torch.Tensor:
        x += self.g(x).abs()
        x *= x.shape[1]
        self.sub_mod.buf = x.sum()
        return x

    def g(self, p: torch.Tensor) -> torch.Tensor:
        x = torch.cos(p)
        return x * p.shape[0]

t = Test()
fn = use_state(lambda: t.h, is_generator=True)
jit_fn = jit(fn, trace=True, lazy=True)
results = jit_fn(fn.init_state(), torch.rand(10, 1)
# print results, may be
print(results) # ({"self.sub_mod.buf": torch.Tensor(5.6)}, torch.Tensor([[0.56], ...]))
# IN-PLACE update all relevant variables using the given state, which is the variable `t` here.
fn.set_state(results[0])
evox.core.module._TORCHSCRIPT_MODIFIER#

'_torchscript_modifier'

evox.core.module._TRACE_WRAP_NAME#

'trace_wrapped'

evox.core.module.T#

TypeVar(...)

evox.core.module.trace_impl(target: Callable)[源代码]#

用于注释的辅助函数,表明被包装的方法应被视为给定target方法的trace-JIT-time代理。

只能在 jit_class 中的成员方法中使用。

参数:

target -- 未追踪 JIT 时调用的目标方法。

返回:

用于注解成员方法的包装函数。

Notice

  1. 目标函数和注解函数 MUST拥有同样的输入/输出签名(例如参数个数和类型);否则产生的行为是未定义的。

  2. 如果注解的函数是 vmap,则它不能包含对 self 的任何原地操作,因为这样的操作是不明确的,并且不能被编译。

Usage:

请参阅 use_state

evox.core.module._VMAP_WRAP_NAME#

vmap_wrapped

evox.core.module.vmap_impl(target: Callable)[源代码]#

一个辅助函数,用于标记被包装的方法应被视为给定 target 方法的 vmap-JIT 时间代理。

只能在 jit_class 中的成员方法中使用。

参数:

target -- 未追踪 JIT 时调用的目标方法。

返回:

用于注解成员方法的包装函数。

Notice

  1. 目标函数和注解函数 MUST拥有同样的输入/输出签名(例如参数个数和类型);否则产生的行为是未定义的。

  2. 如果注解的函数是 vmap,则它不能包含对 self 的任何原地操作,因为这样的操作是不明确的,并且不能被编译。

Usage:

请参阅 use_state

evox.core.module.ClassT#

TypeVar(...)

evox.core.module._BASE_NAME#

基类

evox.core.module.jit_class(cls: evox.core.module.ClassT, trace: bool = False) evox.core.module.ClassT[源代码]#

用于 JIT 脚本 (torch.jit.script) 或跟踪 (torch.jit.trace_module) 类 cls 的所有成员方法的辅助函数。

参数:
  • cls -- 原始类,其成员方法将被延迟 JIT。

  • trace -- 是否追踪模块或将模块脚本化。默认为 False。

返回: 被封装的类。

Notice

  1. 在许多情况下,您不需要使用 jit_class 来包装您的自定义算法或问题,工作流将为您完成此任务。

  2. 设置 trace=True 时,所有成员函数都会被有效地修改,以额外返回 self,因为副作用无法被追踪。如果您想保留副作用,请将 trace=False,并使用 use_state 函数来包装成员方法以生成纯函数。

  3. 类似地,所有模块级的操作如 self.to(...) 只能返回未包装的模块,这可能不是我们想要的。由于它们大多是就地操作,可以使用简单的 module.to(...) 来替代 module = module.to(...)

Usage:

@jit_class
class Example(ModuleBase):
    # magic methods are ignored in JIT
    def __init__(self, threshold = 0.5):
        super().__init__()
        self.threshold = threshold

    # `torch.jit.export` is automatically added to this member method
    def h(self, q: torch.Tensor) -> torch.Tensor:
        if q.flatten()[0] > self.threshold:
            x = torch.sin(q)
        else:
            x = torch.tan(q)
        return x * x.shape[1]

exp = Example(0.75)
print(exp.h(torch.rand(10, 2)))
# equivalent to
exp = torch.jit.trace_module(Example(0.75))
print(exp.h(torch.rand(10, 2)))