通过 PyTorch 高级技术优化 EvoX 的开发#

PyTorch中函数的基本优化支持#

PyTorch 提供了对函数的基本优化支持,主要通过矢量化映射(vmap)操作和即时编译(JIT)实现。这些技术分别使得批处理更加高效并提升执行性能。关于这些优化的介绍将在以下部分提供。

在 PyTorch 中通过矢量化映射支持批处理#

在PyTorch中实现的向量化映射torch.vmap是一个强大的工具,它接受一个可调用函数并返回其批处理版本。根据指定的策略,这个新函数对原始函数的操作进行向量化,从而促进高效的批处理。在EvoX中,例如,这一特性在超参数优化(HPO)中起着至关重要的作用。

import torch


def dummy_evaluation(pop_x: torch.Tensor, y: torch.Tensor):
    return pop_x * y


batched_dummy_evaluation = torch.vmap(dummy_evaluation, (0, None))

population_size = 3
individual_vector_size = 9
pop_x = torch.arange(individual_vector_size).repeat(population_size, 1)
y = torch.arange(individual_vector_size)

batched_dummy_evaluation(pop_x, y)
tensor([[ 0,  1,  4,  9, 16, 25, 36, 49, 64],
        [ 0,  1,  4,  9, 16, 25, 36, 49, 64],
        [ 0,  1,  4,  9, 16, 25, 36, 49, 64]])

在PyTorch中支持即时编译(Just-In-Time, JIT)#

在 PyTorch 中,torch.jit.tracetorch.jit.script 提供了两种不同类型的 JIT 工具,分别通过追踪和脚本化支持函数性能优化。

根据追踪策略,torch.jit.trace 方法提供了更高的解析速度和更广泛的兼容性,例如与 torch.vmap 操作的兼容性。尽管它对简单函数提供了出色的支持,但不适用于涉及动态 if-else 分支和循环控制流的复杂任务。

import functools


@functools.partial(torch.vmap, in_dims=(0, None))
def vmap_sample_func(x: torch.Tensor, y: torch.Tensor):
    return x.sum() + y

在下面的例子中,跟踪的 vmap 函数成功返回了正确的代码表示:

traced_vmap_func = torch.jit.trace(vmap_sample_func, example_inputs=(pop_x, y))
print(traced_vmap_func.code)
def vmap_sample_func(x: Tensor,
    y: Tensor) -> Tensor:
  _0 = torch.add(torch.view(torch.sum(x, [1]), [3, 1]), y)
  return _0

然而,动态的 Python 控制流无法被正确追踪,并且会发出警告:

def dynamic_control_flow(pop_x: torch.Tensor, y: torch.Tensor):
    if y.flatten()[0] > 0:
        return pop_x + y[None, :]
    else:
        return pop_x * y[None, :]

traced_dynamic_control_flow_func = torch.jit.trace(dynamic_control_flow, example_inputs=(pop_x, y))
print(traced_dynamic_control_flow_func.code)
def dynamic_control_flow(pop_x: Tensor,
    y: Tensor) -> Tensor:
  y0 = torch.flatten(y)
  _0 = torch.slice(torch.unsqueeze(y0, 0), 1, 0, 9223372036854775807)
  return torch.mul(pop_x, _0)
C:\Users\skk77\AppData\Local\Temp\ipykernel_19564\349648791.py:2: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if y.flatten()[0] > 0:

或者,采用脚本策略的 torch.jit.script 方法更适合涉及动态控制流的复杂任务,但兼容性有限。

在这个例子中,相同的 vmap_sample_func 函数,在被脚本化后,返回了一个不正确的代码表示:

scripted_vmap_func = torch.jit.script(vmap_sample_func)
print(scripted_vmap_func.code)
def vmap_sample_func(x: Tensor,
    y: Tensor) -> Tensor:
  return torch.add(torch.sum(x), y)

然而,它可以正确处理复杂的动态 Python 控制流:

def dynamic_control_flow(pop_x: torch.Tensor, y: torch.Tensor):
    if y.flatten()[0] > 0:
        return pop_x + y[None, :]
    else:
        return pop_x * y[None, :]

script_dynamic_control_flow_func = torch.jit.script(dynamic_control_flow)
print(script_dynamic_control_flow_func.code)
def dynamic_control_flow(pop_x: Tensor,
    y: Tensor) -> Tensor:
  _0 = torch.gt(torch.select(torch.flatten(y), 0, 0), 0)
  if bool(_0):
    _2 = torch.slice(torch.unsqueeze(y, 0), 1)
    _1 = torch.add(pop_x, _2)
  else:
    _3 = torch.slice(torch.unsqueeze(y, 0), 1)
    _1 = torch.mul(pop_x, _3)
  return _1

备注

torch.jit.script 依赖类型提示才能正常工作。例如,任何未注释的输入参数都会被视为 torch.Tensor,而你可以将一些输入参数注释为 python 类型,以使 torch.jit.script 按预期工作。

在PyTorch中结合使用JIT和Vectorizing Map#

根据上述介绍,当 torch.jit.tracetorch.jit.scripttorch.vmap 结合使用时,由于兼容性考虑,需要进行协调。

下图说明了 torch.jit.scripttorch.jit.tracetorch.vmap 之间的关系,突出显示了它们的相互调用路径。如果模块 A 调用模块 B,这意味着 B 可以被 A 调用。

JIT 介绍

有关在 PyTorch 上使用 JIT 和矢量化映射的详细信息,请参阅 PyTorch 官方文档中的 TorchScripttorch.vmap

EvoX中的特定优化支持#

在 EvoX 中,大多数函数是在类中定义的,特别是 ModuleBase 的子类中。为了提供更全面的优化支持,EvoX 提供了特定的增强功能。

使用 JIT 到 ModuleBase 的子类#

为了更好地理解这一部分,我们需要解释EvoX中的三个重要函数:jit_classvmapjit

jit_class 函数#

jit_class 是一个辅助函数,用于对输入类的所有成员方法进行 Just-In-Time (JIT) 脚本化的 torch.jit.script 或跟踪 (torch.jit.trace_module)。

jit_class 有两个参数:

  • cls: 原始类,其成员方法将被懒惰 JIT。

  • trace: 是否追踪模块或将模块脚本化。默认值为 False

备注

  1. 在许多情况下,不需要用jit_class包装您的自定义算法或问题,工作流将为您解决问题。

  2. trace=True 的情况下,所有成员函数都会被有效地修改为额外返回 self,因为副作用无法被追踪。如果你想保留副作用,请设置 trace=False 并使用 use_state 函数来包装成员方法以生成纯函数(use_state 函数将在下一部分中解释)。

  3. 类似地,所有模块级的操作如 self.to(...) 只能返回未包装的模块,这可能不是我们想要的。由于它们大多是就地操作,可以使用简单的 module.to(...) 来替代 module = module.to(...)

vmap 函数#

vmap 函数将给定函数矢量化映射到其映射版本。基于 torch.vmap,我们进行了许多改进,您可以查看 torch.vmap 以获取更多信息。

jit 函数#

jit 使用 torch.jit.trace (trace=True) 或 torch.jit.script (trace=False) 编译给定的 `func

该函数包装器有效处理嵌套的 JIT 和向量映射 (vmap) 表达式,如 jit(func1) -> vmap -> jit(func2),从而防止可能出现的错误。

备注

  1. trace=True 的情况下,torch.jit.trace 不能对具有不同参数的函数使用相同的示例输入参数,例如,你不能将 tensor_a, tensor_a 传递给 torch.jit.trace 版本的 `f(x: torch.Tensor, y: torch.Tensor)

  2. trace=False 的情况下,torch.jit.script 不能直接包含 vmap 表达式,请使用 jit(..., trace=True)torch.jit.trace 包装它们。

EvoX 中的模块工作 中,我们简要介绍了关于 ModuleBase 子类中方法的一些规则。现在 jit_classvmapjit 已经被解释,我们将解释更多规则并提供一些具体提示。

子类中静态方法的定义#

在子类中,静态方法要被 JIT 定义为:

# Import Pytorch
import torch

# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit

# Set an module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
    
    ...
    
    # One example of the static method defined in a Module 
    @jit
    def func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y
    
    ...
    

在子类中定义非静态方法#

如果一个方法使用了Python动态控制流,例如if,并且需要进行JIT,那么应使用一个独立的静态方法,并使用jit(..., trace=False)或`torch.jit.script_if_tracing

# Import Pytorch
import torch

# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit

# Set an module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
    
    ...
    
    # An example of one method with python dynamic control flows like "if"
    # The method using jit(..., trace=False)
    @partial(jit, trace=False)
    def static_func(x: torch.Tensor, threshold: float) -> torch.Tensor:
        if x.flatten()[0] > threshold:
            return torch.sin(x)
        else:
            return torch.tan(x)
        
    # The method to be JIT   
    @jit
    def jit_func(self, p: torch.Tensor) -> torch.Tensor:
        return ExampleModule.static_func(p, self.threshold)
    
    ...
    

备注

在Python中,动态控制流是指根据运行时的条件动态变化的控制结构。if...elif...else 条件语句、for 循环和 while 循环都是动态控制流。如果在定义 ModuleBase 子类中的非静态方法时必须使用它们,请遵循上述规则。

调用子类中的外部方法#

在子类中,外部 JIT 方法可以通过类方法调用以实现 JIT:

# Import the ModuleBase class from EvoX
from evox.core import ModuleBase

# One example of the JIT method defined outside the module 
@jit
def external_func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y

# Set an module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
    
    ...

    # The internal method using jit(..., trace=False)
    @partial(jit, trace=False)
    def static_func(x: torch.Tensor, threshold: float) -> torch.Tensor:

        
    # The internal static method to be JIT   
    @jit
    def jit_func(self, p: torch.Tensor) -> torch.Tensor:
        return external_func(p, p)
    
    ...
    

自动为使用 jit_class 的子类进行 JIT 处理#

ModuleBase 及其子类通常与 jit_class 一起使用,以自动JIT所有非魔术成员方法:

# Import Pytorch
import torch

# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit_class

@jit_class
class ExampleModule(ModuleBase):
    
    ...
    
    # This function will be automatically JIT
    def automatically_JIT_func1(self, x: torch.Tensor) -> torch.Tensor:
        pass

    # Use `torch.jit.ignore` to disable JIT and leave this function as Python callback
    @torch.jit.ignore
    def no_JIT_func2(self, x: torch.Tensor) -> torch.Tensor:
        # you can implement pure Python logic here
        pass

    # JIT functions can invoke other JIT functions as well as non-JIT functions
    def automatically_JIT_func3(self, x: torch.Tensor) -> torch.Tensor:
        y = self.automatically_JIT_func1(x)
        z = self.no_JIT_func2(x)
        pass
    
    ...
    

在子类中调用外部 Vmap 包装的方法#

在子类内部,外部 vmap 包装的方法可以通过类方法调用以实现 JIT:

# Import Pytorch
import torch

# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit, vmap


# The method to be vmap-wrapped
def external_func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y.sum()

external_vmap_func = vmap(external_func, in_dims=1, out_dims=1)


# Set an module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
    
    ...    
    
    # The internal class method to be JIT   
    @jit
    def jit_func(self, p: torch.Tensor) -> torch.Tensor:
        return external_vmap_func(p, p)
    
    ...
    

备注

如果方法 A 调用 vmap 包装的方法 B,那么 A 和所有调用方法 A 的方法不能再次被 vmap 包装。

子类内部的 Vmap 包装方法#

在子类内部,可以使用 trace_impl 对内部 vmap 包装的方法进行 JIT。

# Import Pytorch
import torch

# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit, vmap, trace_impl


# Set an module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
    
    ...    
    
    # The internal vmap-wrapped class method to be JIT   
    @jit
    def jit_vmap_func(self, p: torch.Tensor) -> torch.Tensor:
        
        # The original method
        # We can not vmap it
    	def func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    		return x + y
        
        
        # The method to be vmap-wrapped
        # We need to use trace_impl to rewrite the original method
        @trace_impl(func)
    	def trace_func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    		pass
        
        return vmap(func, in_dims=1, out_dims=1, trace=False)(p, p)
    
    ...
    

备注

如果一个类方法使用了trace_impl,它将仅在追踪模式下可用。关于trace_impl的更多细节将在下一部分展示。

使用 @trace_impl 和 `@vmap_impl#

在设计函数或方法时,您可能并不总是考虑它是否与JIT兼容。然而,这一特性在特定场景中变得至关重要,例如解决超参数优化(HPO)问题。有关使用EvoX部署HPO的更多详细信息,请参阅Efficient HPO with EvoX

这类问题的一个典型特征是,只需要修改算法的某些部分——例如,算法的 step 方法。这使您可以避免重写整个算法。在这种情况下,您可以使用 @trace_impl@vmap_impl 装饰器,将函数重写为指定 target 方法的 trace-JIT-time 或 vmap-JIT-time 代理。

装饰器 @trace_impl@vmap_impl 接受一个输入参数:在不进行追踪/vmap JIT 时调用的目标方法。这些装饰器适用于 jit_class 中的成员方法。

由于注释函数作为目标函数的重写版本,它必须保持相同的输入/输出签名(例如,参数的数量和类型)。否则,结果行为是未定义的。

如果注释的函数打算与vmap一起使用,它必须满足三个额外的约束:

  1. 禁止对属性进行原地操作: 算法不得包含对其属性执行原地操作的方法。

class ExampleAlgorithm(Algorithm):
    def __init__(self, ...):
        self.pop = torch.rand(10, 10)  # Attribute of the algorithm

    def step_in_place(self):  # Method with in-place operations
        self.pop.copy_(pop)

    def step_out_of_place(self):  # Method without in-place operations
        self.pop = pop
  1. 避免使用 Python 控制流: 代码逻辑不能依赖于 Python 控制流结构。要处理 Python 控制流,请使用 TracingCondTracingWhileTracingSwitch

@jit_class
class ExampleAlgorithm(Algorithm):
    def __init__(self, pop_size, ...):
        super().__init__()
        self.pop = torch.rand(pop_size, pop_size)

    def strategy_1(self):  # One update strategy
        new_pop = self.pop * self.pop
        self.pop = new_pop

    def strategy_2(self):  # Another update strategy
        new_pop = self.pop + self.pop
        self.pop = new_pop

    def step(self):
        control_number = torch.rand()
        if control_number < 0.5:  # Conditional control
            self.strategy_1()
        else:
            self.strategy_2()

    @trace_impl(step)  # Rewrite step function for vmap support
    def trace_step_without_operations_to_self(self):
        pop = torch.rand(self.pop_size, self.dim, dtype=self.lb.dtype, device=self.lb.device)
        pop = pop * (self.ub - self.lb)[None, :] + self.lb[None, :]
        pop = pop * self.hp[0]
        control_number = torch.rand()
        cond = control_number < 0.5
        branches = (self.strategy_1, self.strategy_2)
        state, names = self.prepare_control_flow(*branches)  # Utilize state to track self.pop
        _if_else_ = TracingCond(*branches)
        state = _if_else_.cond(state, cond, pop)
        self.after_control_flow(state, *names)

    @trace_impl(step)
    def trace_step_with_operations_to_self(self):
        pop = torch.rand(self.pop_size, self.dim, dtype=self.lb.dtype, device=self.lb.device)
        pop = pop * (self.ub - self.lb)[None, :] + self.lb[None, :]
        pop = pop * self.hp[0]
        control_number = torch.rand()
        cond = control_number < 0.5
        _if_else_ = TracingCond(lambda p: p * p, lambda p: p + p)  # No need to track self.pop
        pop = _if_else_.cond(cond, pop)
        self.pop = pop
  1. 避免对 self 进行就地操作: 向量化映射对 self 的就地操作定义不明确,无法编译。即使成功编译,您仍可能会悄无声息地得到错误的结果。

使用 use_state#

use_state 将给定的有状态函数(对 nn.Module 进行就地更改)转换为纯函数版本,该版本接收一个额外的 state 参数(类型为 Dict[str, torch.Tensor]),并返回更改后的状态。

输入 func 是要转换的有状态函数或其生成器函数,而 is_generator 指定 func 是一个函数还是一个函数生成器(例如,返回有状态函数的 lambda)。默认值为 `True

以下是一个简单的例子:

@jit_class
class Example(ModuleBase):
    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold
        self.sub_mod = nn.Module()
        self.sub_mod.buf = nn.Buffer(torch.zeros(()))

    def h(self, q: torch.Tensor) -> torch.Tensor:
        if q.flatten()[0] > self.threshold:
            x = torch.sin(q)
        else:
            x = torch.tan(q)
        x += self.g(x).abs()
        x *= x.shape[1]
        self.sub_mod.buf = x.sum()
        return x

    @trace_impl(h)
    def th(self, q: torch.Tensor) -> torch.Tensor:
        x += self.g(x).abs()
        x *= x.shape[1]
        self.sub_mod.buf = x.sum()
        return x

    def g(self, p: torch.Tensor) -> torch.Tensor:
        x = torch.cos(p)
        return x * p.shape[0]

fn = use_state(lambda: t.h, is_generator=True)
jit_fn = jit(fn, trace=True, lazy=True)
results = jit_fn(fn.init_state(), torch.rand(10, 1))
print(results)  # ({"self.sub_mod.buf": torch.Tensor(5.6)}, torch.Tensor([[0.56], ...]))

# IN-PLACE update all relevant variables using the given state
fn.set_state(results[0])

使用 `core._vmap_fix#

模块 _vmap_fix 提供了有用的函数。在自动导入后,_vmap_fix 使 torch.vmap 能够被 torch.jit.trace 正确追踪,同时解决了在 vmap 过程中无法正确追踪的随机数处理等问题。它还提供了 debug_print 函数,允许在 vmap 和追踪过程中动态打印 Tensor 值。

详细信息可以在_vmap_fix文档中找到。