通过 PyTorch 高级技术优化 EvoX 的开发#
PyTorch中函数的基本优化支持#
PyTorch 提供了对函数的基本优化支持,主要通过矢量化映射(vmap)操作和即时编译(JIT)实现。这些技术分别使得批处理更加高效并提升执行性能。关于这些优化的介绍将在以下部分提供。
在 PyTorch 中通过矢量化映射支持批处理#
在PyTorch中实现的向量化映射torch.vmap
是一个强大的工具,它接受一个可调用函数并返回其批处理版本。根据指定的策略,这个新函数对原始函数的操作进行向量化,从而促进高效的批处理。在EvoX中,例如,这一特性在超参数优化(HPO)中起着至关重要的作用。
import torch
def dummy_evaluation(pop_x: torch.Tensor, y: torch.Tensor):
return pop_x * y
batched_dummy_evaluation = torch.vmap(dummy_evaluation, (0, None))
population_size = 3
individual_vector_size = 9
pop_x = torch.arange(individual_vector_size).repeat(population_size, 1)
y = torch.arange(individual_vector_size)
batched_dummy_evaluation(pop_x, y)
tensor([[ 0, 1, 4, 9, 16, 25, 36, 49, 64],
[ 0, 1, 4, 9, 16, 25, 36, 49, 64],
[ 0, 1, 4, 9, 16, 25, 36, 49, 64]])
在PyTorch中支持即时编译(Just-In-Time, JIT)#
在 PyTorch 中,torch.jit.trace
和 torch.jit.script
提供了两种不同类型的 JIT 工具,分别通过追踪和脚本化支持函数性能优化。
根据追踪策略,torch.jit.trace
方法提供了更高的解析速度和更广泛的兼容性,例如与 torch.vmap
操作的兼容性。尽管它对简单函数提供了出色的支持,但不适用于涉及动态 if-else 分支和循环控制流的复杂任务。
import functools
@functools.partial(torch.vmap, in_dims=(0, None))
def vmap_sample_func(x: torch.Tensor, y: torch.Tensor):
return x.sum() + y
在下面的例子中,跟踪的 vmap
函数成功返回了正确的代码表示:
traced_vmap_func = torch.jit.trace(vmap_sample_func, example_inputs=(pop_x, y))
print(traced_vmap_func.code)
def vmap_sample_func(x: Tensor,
y: Tensor) -> Tensor:
_0 = torch.add(torch.view(torch.sum(x, [1]), [3, 1]), y)
return _0
然而,动态的 Python 控制流无法被正确追踪,并且会发出警告:
def dynamic_control_flow(pop_x: torch.Tensor, y: torch.Tensor):
if y.flatten()[0] > 0:
return pop_x + y[None, :]
else:
return pop_x * y[None, :]
traced_dynamic_control_flow_func = torch.jit.trace(dynamic_control_flow, example_inputs=(pop_x, y))
print(traced_dynamic_control_flow_func.code)
def dynamic_control_flow(pop_x: Tensor,
y: Tensor) -> Tensor:
y0 = torch.flatten(y)
_0 = torch.slice(torch.unsqueeze(y0, 0), 1, 0, 9223372036854775807)
return torch.mul(pop_x, _0)
C:\Users\skk77\AppData\Local\Temp\ipykernel_19564\349648791.py:2: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if y.flatten()[0] > 0:
或者,采用脚本策略的 torch.jit.script
方法更适合涉及动态控制流的复杂任务,但兼容性有限。
在这个例子中,相同的 vmap_sample_func
函数,在被脚本化后,返回了一个不正确的代码表示:
scripted_vmap_func = torch.jit.script(vmap_sample_func)
print(scripted_vmap_func.code)
def vmap_sample_func(x: Tensor,
y: Tensor) -> Tensor:
return torch.add(torch.sum(x), y)
然而,它可以正确处理复杂的动态 Python 控制流:
def dynamic_control_flow(pop_x: torch.Tensor, y: torch.Tensor):
if y.flatten()[0] > 0:
return pop_x + y[None, :]
else:
return pop_x * y[None, :]
script_dynamic_control_flow_func = torch.jit.script(dynamic_control_flow)
print(script_dynamic_control_flow_func.code)
def dynamic_control_flow(pop_x: Tensor,
y: Tensor) -> Tensor:
_0 = torch.gt(torch.select(torch.flatten(y), 0, 0), 0)
if bool(_0):
_2 = torch.slice(torch.unsqueeze(y, 0), 1)
_1 = torch.add(pop_x, _2)
else:
_3 = torch.slice(torch.unsqueeze(y, 0), 1)
_1 = torch.mul(pop_x, _3)
return _1
备注
torch.jit.script
依赖类型提示才能正常工作。例如,任何未注释的输入参数都会被视为 torch.Tensor
,而你可以将一些输入参数注释为 python 类型,以使 torch.jit.script
按预期工作。
在PyTorch中结合使用JIT和Vectorizing Map#
根据上述介绍,当 torch.jit.trace
和 torch.jit.script
与 torch.vmap
结合使用时,由于兼容性考虑,需要进行协调。
下图说明了 torch.jit.script
、torch.jit.trace
和 torch.vmap
之间的关系,突出显示了它们的相互调用路径。如果模块 A 调用模块 B,这意味着 B 可以被 A 调用。

有关在 PyTorch 上使用 JIT 和矢量化映射的详细信息,请参阅 PyTorch 官方文档中的 TorchScript 和 torch.vmap
。
EvoX中的特定优化支持#
在 EvoX 中,大多数函数是在类中定义的,特别是 ModuleBase
的子类中。为了提供更全面的优化支持,EvoX 提供了特定的增强功能。
使用 JIT 到 ModuleBase
的子类#
为了更好地理解这一部分,我们需要解释EvoX中的三个重要函数:jit_class
、vmap
和jit
。
jit_class
函数#
jit_class
是一个辅助函数,用于对输入类的所有成员方法进行 Just-In-Time (JIT) 脚本化的 torch.jit.script
或跟踪 (torch.jit.trace_module
)。
jit_class
有两个参数:
cls
: 原始类,其成员方法将被懒惰 JIT。trace
: 是否追踪模块或将模块脚本化。默认值为False
。
备注
在许多情况下,不需要用
jit_class
包装您的自定义算法或问题,工作流将为您解决问题。在
trace=True
的情况下,所有成员函数都会被有效地修改为额外返回self
,因为副作用无法被追踪。如果你想保留副作用,请设置trace=False
并使用use_state
函数来包装成员方法以生成纯函数(use_state
函数将在下一部分中解释)。类似地,所有模块级的操作如
self.to(...)
只能返回未包装的模块,这可能不是我们想要的。由于它们大多是就地操作,可以使用简单的module.to(...)
来替代module = module.to(...)
。
vmap
函数#
vmap
函数将给定函数矢量化映射到其映射版本。基于 torch.vmap
,我们进行了许多改进,您可以查看 torch.vmap
以获取更多信息。
jit
函数#
jit
使用 torch.jit.trace
(trace=True
) 或 torch.jit.script
(trace=False
) 编译给定的 `func
该函数包装器有效处理嵌套的 JIT 和向量映射 (vmap
) 表达式,如 jit(func1)
-> vmap
-> jit(func2)
,从而防止可能出现的错误。
备注
在
trace=True
的情况下,torch.jit.trace
不能对具有不同参数的函数使用相同的示例输入参数,例如,你不能将tensor_a, tensor_a
传递给torch.jit.trace
版本的 `f(x: torch.Tensor, y: torch.Tensor)在
trace=False
的情况下,torch.jit.script
不能直接包含vmap
表达式,请使用jit(..., trace=True)
或torch.jit.trace
包装它们。
在 EvoX 中的模块工作 中,我们简要介绍了关于 ModuleBase
子类中方法的一些规则。现在 jit_class
、vmap
和 jit
已经被解释,我们将解释更多规则并提供一些具体提示。
子类中静态方法的定义#
在子类中,静态方法要被 JIT 定义为:
# Import Pytorch
import torch
# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit
# Set an module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
...
# One example of the static method defined in a Module
@jit
def func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
...
在子类中定义非静态方法#
如果一个方法使用了Python动态控制流,例如if
,并且需要进行JIT,那么应使用一个独立的静态方法,并使用jit(..., trace=False)
或`torch.jit.script_if_tracing
# Import Pytorch
import torch
# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit
# Set an module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
...
# An example of one method with python dynamic control flows like "if"
# The method using jit(..., trace=False)
@partial(jit, trace=False)
def static_func(x: torch.Tensor, threshold: float) -> torch.Tensor:
if x.flatten()[0] > threshold:
return torch.sin(x)
else:
return torch.tan(x)
# The method to be JIT
@jit
def jit_func(self, p: torch.Tensor) -> torch.Tensor:
return ExampleModule.static_func(p, self.threshold)
...
备注
在Python中,动态控制流是指根据运行时的条件动态变化的控制结构。if...elif...else
条件语句、for
循环和 while
循环都是动态控制流。如果在定义 ModuleBase
子类中的非静态方法时必须使用它们,请遵循上述规则。
调用子类中的外部方法#
在子类中,外部 JIT 方法可以通过类方法调用以实现 JIT:
# Import the ModuleBase class from EvoX
from evox.core import ModuleBase
# One example of the JIT method defined outside the module
@jit
def external_func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
# Set an module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
...
# The internal method using jit(..., trace=False)
@partial(jit, trace=False)
def static_func(x: torch.Tensor, threshold: float) -> torch.Tensor:
# The internal static method to be JIT
@jit
def jit_func(self, p: torch.Tensor) -> torch.Tensor:
return external_func(p, p)
...
自动为使用 jit_class
的子类进行 JIT 处理#
ModuleBase
及其子类通常与 jit_class
一起使用,以自动JIT所有非魔术成员方法:
# Import Pytorch
import torch
# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit_class
@jit_class
class ExampleModule(ModuleBase):
...
# This function will be automatically JIT
def automatically_JIT_func1(self, x: torch.Tensor) -> torch.Tensor:
pass
# Use `torch.jit.ignore` to disable JIT and leave this function as Python callback
@torch.jit.ignore
def no_JIT_func2(self, x: torch.Tensor) -> torch.Tensor:
# you can implement pure Python logic here
pass
# JIT functions can invoke other JIT functions as well as non-JIT functions
def automatically_JIT_func3(self, x: torch.Tensor) -> torch.Tensor:
y = self.automatically_JIT_func1(x)
z = self.no_JIT_func2(x)
pass
...
在子类中调用外部 Vmap 包装的方法#
在子类内部,外部 vmap 包装的方法可以通过类方法调用以实现 JIT:
# Import Pytorch
import torch
# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit, vmap
# The method to be vmap-wrapped
def external_func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y.sum()
external_vmap_func = vmap(external_func, in_dims=1, out_dims=1)
# Set an module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
...
# The internal class method to be JIT
@jit
def jit_func(self, p: torch.Tensor) -> torch.Tensor:
return external_vmap_func(p, p)
...
备注
如果方法 A 调用 vmap 包装的方法 B,那么 A 和所有调用方法 A 的方法不能再次被 vmap 包装。
子类内部的 Vmap 包装方法#
在子类内部,可以使用 trace_impl
对内部 vmap 包装的方法进行 JIT。
# Import Pytorch
import torch
# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit, vmap, trace_impl
# Set an module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
...
# The internal vmap-wrapped class method to be JIT
@jit
def jit_vmap_func(self, p: torch.Tensor) -> torch.Tensor:
# The original method
# We can not vmap it
def func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
# The method to be vmap-wrapped
# We need to use trace_impl to rewrite the original method
@trace_impl(func)
def trace_func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
pass
return vmap(func, in_dims=1, out_dims=1, trace=False)(p, p)
...
备注
如果一个类方法使用了trace_impl
,它将仅在追踪模式下可用。关于trace_impl
的更多细节将在下一部分展示。
使用 @trace_impl
和 `@vmap_impl#
在设计函数或方法时,您可能并不总是考虑它是否与JIT
兼容。然而,这一特性在特定场景中变得至关重要,例如解决超参数优化(HPO)问题。有关使用EvoX部署HPO的更多详细信息,请参阅Efficient HPO with EvoX。
这类问题的一个典型特征是,只需要修改算法的某些部分——例如,算法的 step
方法。这使您可以避免重写整个算法。在这种情况下,您可以使用 @trace_impl
或 @vmap_impl
装饰器,将函数重写为指定 target
方法的 trace-JIT-time 或 vmap-JIT-time 代理。
装饰器 @trace_impl
和 @vmap_impl
接受一个输入参数:在不进行追踪/vmap JIT 时调用的目标方法。这些装饰器仅适用于 jit_class
中的成员方法。
由于注释函数作为目标函数的重写版本,它必须保持相同的输入/输出签名(例如,参数的数量和类型)。否则,结果行为是未定义的。
如果注释的函数打算与vmap
一起使用,它必须满足三个额外的约束:
禁止对属性进行原地操作: 算法不得包含对其属性执行原地操作的方法。
class ExampleAlgorithm(Algorithm):
def __init__(self, ...):
self.pop = torch.rand(10, 10) # Attribute of the algorithm
def step_in_place(self): # Method with in-place operations
self.pop.copy_(pop)
def step_out_of_place(self): # Method without in-place operations
self.pop = pop
避免使用 Python 控制流: 代码逻辑不能依赖于 Python 控制流结构。要处理 Python 控制流,请使用
TracingCond
、TracingWhile
和TracingSwitch
。
@jit_class
class ExampleAlgorithm(Algorithm):
def __init__(self, pop_size, ...):
super().__init__()
self.pop = torch.rand(pop_size, pop_size)
def strategy_1(self): # One update strategy
new_pop = self.pop * self.pop
self.pop = new_pop
def strategy_2(self): # Another update strategy
new_pop = self.pop + self.pop
self.pop = new_pop
def step(self):
control_number = torch.rand()
if control_number < 0.5: # Conditional control
self.strategy_1()
else:
self.strategy_2()
@trace_impl(step) # Rewrite step function for vmap support
def trace_step_without_operations_to_self(self):
pop = torch.rand(self.pop_size, self.dim, dtype=self.lb.dtype, device=self.lb.device)
pop = pop * (self.ub - self.lb)[None, :] + self.lb[None, :]
pop = pop * self.hp[0]
control_number = torch.rand()
cond = control_number < 0.5
branches = (self.strategy_1, self.strategy_2)
state, names = self.prepare_control_flow(*branches) # Utilize state to track self.pop
_if_else_ = TracingCond(*branches)
state = _if_else_.cond(state, cond, pop)
self.after_control_flow(state, *names)
@trace_impl(step)
def trace_step_with_operations_to_self(self):
pop = torch.rand(self.pop_size, self.dim, dtype=self.lb.dtype, device=self.lb.device)
pop = pop * (self.ub - self.lb)[None, :] + self.lb[None, :]
pop = pop * self.hp[0]
control_number = torch.rand()
cond = control_number < 0.5
_if_else_ = TracingCond(lambda p: p * p, lambda p: p + p) # No need to track self.pop
pop = _if_else_.cond(cond, pop)
self.pop = pop
避免对
self
进行就地操作: 向量化映射对self
的就地操作定义不明确,无法编译。即使成功编译,您仍可能会悄无声息地得到错误的结果。
使用 use_state
#
use_state
将给定的有状态函数(对 nn.Module
进行就地更改)转换为纯函数版本,该版本接收一个额外的 state
参数(类型为 Dict[str, torch.Tensor]
),并返回更改后的状态。
输入 func
是要转换的有状态函数或其生成器函数,而 is_generator
指定 func
是一个函数还是一个函数生成器(例如,返回有状态函数的 lambda)。默认值为 `True
以下是一个简单的例子:
@jit_class
class Example(ModuleBase):
def __init__(self, threshold=0.5):
super().__init__()
self.threshold = threshold
self.sub_mod = nn.Module()
self.sub_mod.buf = nn.Buffer(torch.zeros(()))
def h(self, q: torch.Tensor) -> torch.Tensor:
if q.flatten()[0] > self.threshold:
x = torch.sin(q)
else:
x = torch.tan(q)
x += self.g(x).abs()
x *= x.shape[1]
self.sub_mod.buf = x.sum()
return x
@trace_impl(h)
def th(self, q: torch.Tensor) -> torch.Tensor:
x += self.g(x).abs()
x *= x.shape[1]
self.sub_mod.buf = x.sum()
return x
def g(self, p: torch.Tensor) -> torch.Tensor:
x = torch.cos(p)
return x * p.shape[0]
fn = use_state(lambda: t.h, is_generator=True)
jit_fn = jit(fn, trace=True, lazy=True)
results = jit_fn(fn.init_state(), torch.rand(10, 1))
print(results) # ({"self.sub_mod.buf": torch.Tensor(5.6)}, torch.Tensor([[0.56], ...]))
# IN-PLACE update all relevant variables using the given state
fn.set_state(results[0])
使用 `core._vmap_fix#
模块 _vmap_fix
提供了有用的函数。在自动导入后,_vmap_fix
使 torch.vmap
能够被 torch.jit.trace
正确追踪,同时解决了在 vmap
过程中无法正确追踪的随机数处理等问题。它还提供了 debug_print
函数,允许在 vmap
和追踪过程中动态打印 Tensor 值。
详细信息可以在_vmap_fix
文档中找到。