在 EvoX 中使用模块#
一个模块是编程中的一个基本概念,指的是一个自包含的代码单元,旨在执行特定任务或一组相关任务。
此笔记本将介绍 EvoX 中的基本模块:ModuleBase
。
模块介绍#
在用户指南的快速入门文档中,我们提到了EvoX的基本运行过程:
在EvoX中,此过程需要四个基本类:
有必要为它们提供一个统一的模块。在EvoX中,这四个类都继承自基础模块——ModuleBase
。

ModuleBase 类#
ModuleBase
类继承自 torch.nn.Module
。
在这个类中有许多方法,这里是一些重要的方法:
方法 |
签名 |
使用情况 |
---|---|---|
|
|
初始化模块。 |
|
|
模块初始化行应写在 |
`load_state_dict |
`(self, state_dict: Mapping[str, torch.Tensor], copy: bool = False, ...) |
将 |
`add_mutable |
`(self, name: str, value: Union[torch.Tensor | nn.Module, Sequence[torch.Tensor | nn.Module], Dict[str, torch.Tensor | nn.Module]]) -> None |
定义一个可变值,并将其在 |
模块的作用#
在 EvoX 中,ModuleBase
可以帮助:
包含可变值
该模块是一个面向对象的模块,可以包含可变值。
支持函数式编程
支持函数式编程模型通过 self.state_dict()
和 `self.load_state_dict(...)
标准化初始化:
基本上,预定义的子模块将被添加到此模块中,并在成员方法中访问,应该被视为“非静态成员”,而其他任何成员都应该被视为“静态成员”。
建议将非静态成员的模块初始化写在重写的 setup
方法(或任何其他成员方法)中,而不是 __init__
中。
模块使用#
具体来说,在EvoX中使用ModuleBase
有一些规则:
静态方法#
要将静态方法定义为JIT,请这样定义:
# One example of the static method defined in a Module
@jit
def func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
非静态方法#
如果一个方法使用了像 if
这样的 Python 动态控制流并希望进行 JIT,则应使用一个单独的静态方法,并使用 jit(..., trace=False)
或 `torch.jit.script_if_tracing
# Set an module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
...
# An example of one method with python dynamic control flows like "if"
# The method using jit(..., trace=False)
@partial(jit, trace=False)
def static_func(x: torch.Tensor, threshold: float) -> torch.Tensor:
if x.flatten()[0] > threshold:
return torch.sin(x)
else:
return torch.tan(x)
# The method to be JIT
@jit
def jit_func(self, p: torch.Tensor) -> torch.Tensor:
return ExampleModule.static_func(p, self.threshold)
...
支持 JIT 和非 JIT 函数#
ModuleBase
通常与 jit_class
一起使用,以自动 JIT 所有非魔法成员方法:
@jit_class
class ExampleModule(ModuleBase):
# This function will be automatically JIT
def func1(self, x: torch.Tensor) -> torch.Tensor:
pass
# Use `torch.jit.ignore` to disable JIT and leave this function as Python callback
@torch.jit.ignore
def func2(self, x: torch.Tensor) -> torch.Tensor:
# you can implement pure Python logic here
pass
# JIT functions can invoke other JIT functions as well as non-JIT functions
def func3(self, x: torch.Tensor) -> torch.Tensor:
y = self.func1(x)
z = self.func2(x)
pass
示例#
从ModuleBase
继承的一个模块示例如下:
class ExampleModule(ModuleBase):
def setup(self, mut: torch.Tensor):
self.add_mutable("mut", mut)
# or
self.mut = Mutable(mut)
return self
@partial(jit, trace=False)
def static_func(x: torch.Tensor, threshold: float) -> torch.Tensor:
if x.flatten()[0] > threshold:
return torch.sin(x)
else:
return torch.tan(x)
@jit
def jit_func(self, p: torch.Tensor) -> torch.Tensor:
x = ExampleModule.static_func(p, self.threshold)
...
有关更多详细信息,请查看the Module in EvoX。