在 EvoX 中使用模块#

一个模块是编程中的一个基本概念,指的是一个自包含的代码单元,旨在执行特定任务或一组相关任务。

此笔记本将介绍 EvoX 中的基本模块:ModuleBase

模块介绍#

用户指南快速入门文档中,我们提到了EvoX的基本运行过程:

Initiate an algorithm and a problem -- Set an monitor -- Initiate a workflow -- Run the workflow

在EvoX中,此过程需要四个基本类:

  • Algorithm

  • Problem

  • Monitor

  • 抱歉,我需要更多的上下文或具体的文本内容来进行翻译。请提供需要翻译的具体文本。

有必要为它们提供一个统一的模块。在EvoX中,这四个类都继承自基础模块——ModuleBase

模块 base

ModuleBase 类#

ModuleBase 类继承自 torch.nn.Module

在这个类中有许多方法,这里是一些重要的方法:

方法

签名

使用情况

__init__

(self, ...)

初始化模块。

设置

(self, ...) -> self

模块初始化行应写在 setup 的重写方法中,而不是 __init__ 中。

`load_state_dict

`(self, state_dict: Mapping[str, torch.Tensor], copy: bool = False, ...)

state_dict中的参数和缓冲区复制到此模块及其子模块中。它会覆盖torch.nn.Module.load_state_dict

`add_mutable

`(self, name: str, value: Union[torch.Tensor | nn.Module, Sequence[torch.Tensor | nn.Module], Dict[str, torch.Tensor | nn.Module]]) -> None

定义一个可变值,并将其在 self.[name] 中暴露出来,可以通过 self.[name] = [值] 来修改。

模块的作用#

在 EvoX 中,ModuleBase 可以帮助:

  • 包含可变值

该模块是一个面向对象的模块,可以包含可变值。

  • 支持函数式编程

支持函数式编程模型通过 self.state_dict() 和 `self.load_state_dict(...)

  • 标准化初始化:

基本上,预定义的子模块将被添加到此模块中,并在成员方法中访问,应该被视为“非静态成员”,而其他任何成员都应该被视为“静态成员”。

建议将非静态成员的模块初始化写在重写的 setup 方法(或任何其他成员方法)中,而不是 __init__ 中。

模块使用#

具体来说,在EvoX中使用ModuleBase有一些规则:

静态方法#

要将静态方法定义为JIT,请这样定义:

# One example of the static method defined in a Module

@jit
def func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y

非静态方法#

如果一个方法使用了像 if 这样的 Python 动态控制流并希望进行 JIT,则应使用一个单独的静态方法,并使用 jit(..., trace=False) 或 `torch.jit.script_if_tracing

# Set an module inherited from the ModuleBase class
class ExampleModule(ModuleBase):

    ...

    # An example of one method with python dynamic control flows like "if"
    # The method using jit(..., trace=False)
    @partial(jit, trace=False)
    def static_func(x: torch.Tensor, threshold: float) -> torch.Tensor:
        if x.flatten()[0] > threshold:
            return torch.sin(x)
        else:
            return torch.tan(x)

    # The method to be JIT
    @jit
    def jit_func(self, p: torch.Tensor) -> torch.Tensor:
        return ExampleModule.static_func(p, self.threshold)

    ...

支持 JIT 和非 JIT 函数#

ModuleBase 通常与 jit_class 一起使用,以自动 JIT 所有非魔法成员方法:

@jit_class
class ExampleModule(ModuleBase):
    # This function will be automatically JIT
    def func1(self, x: torch.Tensor) -> torch.Tensor:
        pass

    # Use `torch.jit.ignore` to disable JIT and leave this function as Python callback
    @torch.jit.ignore
    def func2(self, x: torch.Tensor) -> torch.Tensor:
        # you can implement pure Python logic here
        pass

    # JIT functions can invoke other JIT functions as well as non-JIT functions
    def func3(self, x: torch.Tensor) -> torch.Tensor:
        y = self.func1(x)
        z = self.func2(x)
        pass

示例#

ModuleBase继承的一个模块示例如下:

class ExampleModule(ModuleBase):
        def setup(self, mut: torch.Tensor):
            self.add_mutable("mut", mut)
            # or
            self.mut = Mutable(mut)
            return self

        @partial(jit, trace=False)
        def static_func(x: torch.Tensor, threshold: float) -> torch.Tensor:
            if x.flatten()[0] > threshold:
                return torch.sin(x)
            else:
                return torch.tan(x)
        @jit
        def jit_func(self, p: torch.Tensor) -> torch.Tensor:
            x = ExampleModule.static_func(p, self.threshold)
            ...

有关更多详细信息,请查看the Module in EvoX