Optimizing Development for EvoX via Advanced PyTorch Techniques#
Basic Optimization Support for Functions in PyTorch#
PyTorch provides fundamental optimization support for functions, primarily through vectorizing map (vmap) operations and Just-In-Time (JIT) compilation. These techniques enable efficient batch processing and enhance execution performance, respectively. They are introduced in the following sections.
Batch Processing Support through Vectorizing Map in PyTorch#
Vectorizing map, implemented in PyTorch as torch.vmap, is a powerful tool that takes a callable function and returns a batched version of it. According to the specified strategy, this new function vectorizes the operations of the original one, facilitating efficient batch processing. In EvoX, for example, this feature plays a crucial role in hyperparameter optimization (HPO).
import torch

def dummy_evaluation(pop_x: torch.Tensor, y: torch.Tensor):
    return pop_x * y
batched_dummy_evaluation = torch.vmap(dummy_evaluation, (0, None))
population_size = 3
individual_vector_size = 9
pop_x = torch.arange(individual_vector_size).repeat(population_size, 1)
y = torch.arange(individual_vector_size)
batched_dummy_evaluation(pop_x, y)
tensor([[ 0, 1, 4, 9, 16, 25, 36, 49, 64],
[ 0, 1, 4, 9, 16, 25, 36, 49, 64],
[ 0, 1, 4, 9, 16, 25, 36, 49, 64]])
Just-In-Time (JIT) Support in PyTorch#
In PyTorch, torch.jit.trace and torch.jit.script provide two distinct types of JIT tools, supporting function performance optimization through tracing and scripting, respectively.
Based on the tracing strategy, the torch.jit.trace method offers higher parsing speed and broader compatibility, such as with torch.vmap operations. Although it provides excellent support for simple functions, it is not suitable for complex tasks involving dynamic if-else branches and loop control flow.
import functools

@functools.partial(torch.vmap, in_dims=(0, None))
def vmap_sample_func(x: torch.Tensor, y: torch.Tensor):
    return x.sum() + y
In the example below, the traced vmap function successfully returns the correct code representation:
traced_vmap_func = torch.jit.trace(vmap_sample_func, example_inputs=(pop_x, y))
print(traced_vmap_func.code)
def vmap_sample_func(x: Tensor,
    y: Tensor) -> Tensor:
  _0 = torch.add(torch.view(torch.sum(x, [1]), [3, 1]), y)
  return _0
However, dynamic Python control flow cannot be traced correctly, and a warning will be raised:
def dynamic_control_flow(pop_x: torch.Tensor, y: torch.Tensor):
    if y.flatten()[0] > 0:
        return pop_x + y[None, :]
    else:
        return pop_x * y[None, :]
traced_dynamic_control_flow_func = torch.jit.trace(dynamic_control_flow, example_inputs=(pop_x, y))
print(traced_dynamic_control_flow_func.code)
def dynamic_control_flow(pop_x: Tensor,
    y: Tensor) -> Tensor:
  y0 = torch.flatten(y)
  _0 = torch.slice(torch.unsqueeze(y0, 0), 1, 0, 9223372036854775807)
  return torch.mul(pop_x, _0)
C:\Users\skk77\AppData\Local\Temp\ipykernel_19564\349648791.py:2: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if y.flatten()[0] > 0:
Alternatively, the torch.jit.script method, which adopts a scripting strategy, is better suited for complex tasks that involve dynamic control flow, but it has limited compatibility.
In this example, the same vmap_sample_func function, after being scripted, returns an incorrect code representation:
scripted_vmap_func = torch.jit.script(vmap_sample_func)
print(scripted_vmap_func.code)
def vmap_sample_func(x: Tensor,
    y: Tensor) -> Tensor:
  return torch.add(torch.sum(x), y)
Yet it can correctly handle complex dynamic Python control flow:
def dynamic_control_flow(pop_x: torch.Tensor, y: torch.Tensor):
    if y.flatten()[0] > 0:
        return pop_x + y[None, :]
    else:
        return pop_x * y[None, :]
script_dynamic_control_flow_func = torch.jit.script(dynamic_control_flow)
print(script_dynamic_control_flow_func.code)
def dynamic_control_flow(pop_x: Tensor,
    y: Tensor) -> Tensor:
  _0 = torch.gt(torch.select(torch.flatten(y), 0, 0), 0)
  if bool(_0):
    _2 = torch.slice(torch.unsqueeze(y, 0), 1)
    _1 = torch.add(pop_x, _2)
  else:
    _3 = torch.slice(torch.unsqueeze(y, 0), 1)
    _1 = torch.mul(pop_x, _3)
  return _1
Note
torch.jit.script relies on type hints to work properly. For example, any unannotated input argument is treated as a torch.Tensor, while you can annotate some input arguments as Python types to make torch.jit.script work as intended.
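For example, a minimal sketch (the repeat_rows function below is purely illustrative): annotating n as a Python int lets the scripted function use it as an integer rather than a Tensor.

import torch

# Annotating `n` as a Python `int` tells torch.jit.script to treat it as an integer;
# leaving it unannotated would make the compiler assume it is a torch.Tensor.
@torch.jit.script
def repeat_rows(x: torch.Tensor, n: int) -> torch.Tensor:
    return x.repeat(n, 1)

print(repeat_rows(torch.arange(3), 2))  # tensor([[0, 1, 2], [0, 1, 2]])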
Combined Usage of JIT and Vectorizing Map in PyTorch#
Based on the introductions above, when torch.jit.trace and torch.jit.script are used in combination with torch.vmap, coordination is required due to compatibility considerations.
The figure below illustrates the relationship between torch.jit.script, torch.jit.trace, and torch.vmap, highlighting their mutual invocation paths: an arrow from module A to module B indicates that B can be called inside A.
[Figure: mutual invocation relationships among torch.jit.script, torch.jit.trace, and torch.vmap]
For detailed usage of JIT and vectorizing map in PyTorch, please refer to the official PyTorch documentation for TorchScript and torch.vmap.
Specific Optimization Support in EvoX#
Within EvoX, most functions are defined inside classes, particularly subclasses of ModuleBase. To provide more comprehensive optimization support, EvoX offers specific enhancements.
Using JIT with Subclasses of ModuleBase#
For a better understanding of this part, we need to explain three important functions in EvoX: jit_class, vmap, and jit.
jit_class Function#
jit_class is a helper function used to Just-In-Time (JIT) script (via torch.jit.script) or trace (via torch.jit.trace_module) all member methods of the input class.
jit_class has two parameters:
cls: the original class whose member methods are to be lazily JIT-compiled.
trace: whether to trace the module or to script the module. Defaults to False.
Note
In many cases, it is not necessary to wrap your custom algorithms or problems with jit_class; the workflow(s) will do the trick for you.
With trace=True, all member functions are effectively modified to additionally return self, since side effects cannot be traced. If you want to preserve the side effects, please set trace=False and use the use_state function to wrap the member method into a pure-functional version (the use_state function is explained in a later part).
Similarly, all module-wide operations like self.to(...) can only return the unwrapped module, which may not be desired. Since most of them are in-place operations, a simple module.to(...) can be used instead of module = module.to(...).
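A minimal usage sketch (the class and method below are illustrative, assuming the default trace=False):

import torch
from evox.core import ModuleBase, jit_class

# All non-magic member methods of this class will be lazily JIT compiled
@jit_class
class ScaleModule(ModuleBase):
    def __init__(self, k: float = 2.0):
        super().__init__()
        self.k = k

    def scale(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.k

m = ScaleModule()
print(m.scale(torch.arange(3, dtype=torch.float32)))  # tensor([0., 2., 4.])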
vmap Function#
The vmap function maps the given function to its vectorized (batched) version. It is based on torch.vmap, with many improvements; see torch.vmap for more information.
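A minimal sketch of its usage, assuming it accepts the same in_dims/out_dims arguments as torch.vmap (the function below is illustrative):

import torch
from evox.core import vmap

def weighted_sum(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return (x * w).sum()

# Map over dimension 0 of `x` while sharing the same `w` across the batch
batched_weighted_sum = vmap(weighted_sum, in_dims=(0, None))
print(batched_weighted_sum(torch.ones(4, 3), torch.arange(3.0)))  # tensor([3., 3., 3., 3.])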
jit Function#
jit compiles the given func via torch.jit.trace (trace=True) or torch.jit.script (trace=False).
This function wrapper effectively deals with nested JIT and vectorizing map (vmap) expressions like jit(func1) -> vmap -> jit(func2), preventing possible errors.
Note
With trace=True, torch.jit.trace cannot use the SAME example input arguments for DIFFERENT function parameters; e.g., you cannot pass tensor_a, tensor_a to the torch.jit.trace-d version of f(x: torch.Tensor, y: torch.Tensor).
With trace=False, torch.jit.script cannot contain vmap expressions directly; please wrap them with jit(..., trace=True) or torch.jit.trace.
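A hedged sketch of such a nested expression, combining the evox.core vmap and jit wrappers (the fitness function below is illustrative; lazy=True defers compilation to the first call, as in the use_state example later on this page):

import torch
from evox.core import jit, vmap

def sphere(x: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
    return ((x - shift) ** 2).sum()

# vmap the per-individual function over the population dimension,
# then trace-JIT the batched function as a whole (trace=True), per the note above
batched_sphere = jit(vmap(sphere, in_dims=(0, None)), trace=True, lazy=True)
print(batched_sphere(torch.rand(5, 3), torch.zeros(3)))  # 5 fitness values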
In Working with Module in EvoX, we briefly introduced some rules about the methods inside a subclass of ModuleBase. Now that jit_class, vmap, and jit have been explained, we will introduce more rules and provide some specific hints.
Definition of Static Methods Inside the Subclass#
Inside the subclass, static methods to be JIT-compiled shall be defined as follows:
# Import PyTorch
import torch

# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit

# A module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
    ...

    # One example of a static method defined in a Module
    @jit
    def func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y

    ...
Definition of Non-static Methods Inside the Subclass#
If a method with Python dynamic control flow (such as if) is to be JIT-compiled, a separate static method wrapped with jit(..., trace=False) or torch.jit.script_if_tracing shall be used:
# Import PyTorch
import torch

# Import partial for decorating with keyword arguments
from functools import partial

# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit

# A module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
    ...

    # An example of a method with Python dynamic control flow like "if",
    # defined as a separate static method using jit(..., trace=False)
    @partial(jit, trace=False)
    def static_func(x: torch.Tensor, threshold: float) -> torch.Tensor:
        if x.flatten()[0] > threshold:
            return torch.sin(x)
        else:
            return torch.tan(x)

    # The method to be JIT
    @jit
    def jit_func(self, p: torch.Tensor) -> torch.Tensor:
        return ExampleModule.static_func(p, self.threshold)

    ...
Note
Dynamic control flow in Python refers to control structures that change dynamically based on conditions at runtime. if...elif...else conditional statements, for loops, and while loops are all dynamic control flow. If you have to use them when defining non-static methods inside a subclass of ModuleBase, please follow the rule above.
Invocation of External Methods Inside the Subclass#
Inside the subclass, external JIT methods can be invoked by the class methods to be JIT-compiled:
# Import PyTorch
import torch

# Import partial for decorating with keyword arguments
from functools import partial

# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit

# One example of a JIT method defined outside the module
@jit
def external_func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y

# A module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
    ...

    # An internal method using jit(..., trace=False) that invokes the external JIT method
    @partial(jit, trace=False)
    def static_func(x: torch.Tensor, threshold: float) -> torch.Tensor:
        # A minimal body: scale by the threshold, then call the external JIT method
        return external_func(x * threshold, x)

    # The internal static method to be JIT
    @jit
    def jit_func(self, p: torch.Tensor) -> torch.Tensor:
        return external_func(p, p)

    ...
Automatic JIT for Subclasses Used with jit_class#
ModuleBase and its subclasses are usually used with jit_class to automatically JIT all non-magic member methods:
# Import PyTorch
import torch

# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit_class

@jit_class
class ExampleModule(ModuleBase):
    ...

    # This function will be automatically JIT
    def automatically_JIT_func1(self, x: torch.Tensor) -> torch.Tensor:
        pass

    # Use `torch.jit.ignore` to disable JIT and leave this function as a Python callback
    @torch.jit.ignore
    def no_JIT_func2(self, x: torch.Tensor) -> torch.Tensor:
        # You can implement pure Python logic here
        pass

    # JIT functions can invoke other JIT functions as well as non-JIT functions
    def automatically_JIT_func3(self, x: torch.Tensor) -> torch.Tensor:
        y = self.automatically_JIT_func1(x)
        z = self.no_JIT_func2(x)
        pass

    ...
Invocation of External Vmap-wrapped Methods Inside the Subclass#
Inside the subclass, external vmap-wrapped methods can be invoked by the class methods to be JIT-compiled:
# Import PyTorch
import torch

# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit, vmap

# The method to be vmap-wrapped
def external_func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y.sum()

external_vmap_func = vmap(external_func, in_dims=1, out_dims=1)

# A module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
    ...

    # The internal class method to be JIT
    @jit
    def jit_func(self, p: torch.Tensor) -> torch.Tensor:
        return external_vmap_func(p, p)

    ...
Note
If method A invokes vmap-wrapped method B, then A and all methods that invoke A cannot be vmap-wrapped again.
Internal Vmap-wrapped Methods Inside the Subclass#
Inside the subclass, internal vmap-wrapped methods can be JIT-compiled by using trace_impl:
# Import PyTorch
import torch

# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit, vmap, trace_impl

# A module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
    ...

    # The internal vmap-wrapped class method to be JIT
    @jit
    def jit_vmap_func(self, p: torch.Tensor) -> torch.Tensor:

        # The original method; we cannot vmap it directly
        def func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y

        # The method to be vmap-wrapped;
        # we need to use trace_impl to rewrite the original method
        @trace_impl(func)
        def trace_func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y  # the trace-mode rewrite; here it simply mirrors the original computation

        return vmap(func, in_dims=1, out_dims=1, trace=False)(p, p)

    ...
Note
If a class method uses trace_impl, it will only be available in trace mode. More details about trace_impl will be shown in the next part.
Using @trace_impl and @vmap_impl#
When designing a function or method, you may not always consider whether it is JIT-compatible. However, this property becomes crucial in specific scenarios, such as solving Hyperparameter Optimization (HPO) problems. For more details on deploying HPO with EvoX, refer to Efficient HPO with EvoX.
A typical characteristic of such problems is that only certain parts of the algorithm need modification, for instance, the step method of an algorithm. This allows you to avoid rewriting the entire algorithm. In such cases, you can use the @trace_impl or @vmap_impl decorator to rewrite the function as a trace-JIT-time or vmap-JIT-time proxy for the specified target method.
The decorators @trace_impl and @vmap_impl accept a single input parameter: the target method invoked when not tracing/vmapping JIT. These decorators are applicable only to member methods within a jit_class.
Since the annotated function serves as a rewritten version of the target function, it must maintain identical input/output signatures (e.g., number and types of arguments). Otherwise, the resulting behavior is undefined.
If the annotated function is intended for use with vmap, it must satisfy three additional constraints:
No In-Place Operations on Attributes: The algorithm must not include methods that perform in-place operations on its attributes.
class ExampleAlgorithm(Algorithm):
    def __init__(self, ...):
        self.pop = torch.rand(10, 10)  # Attribute of the algorithm

    def step_in_place(self):  # Method with in-place operations
        pop = torch.rand(10, 10)  # A newly generated population
        self.pop.copy_(pop)

    def step_out_of_place(self):  # Method without in-place operations
        pop = torch.rand(10, 10)  # A newly generated population
        self.pop = pop
Avoid Python Control Flow: The code logic must not rely on Python control flow structures. To handle Python control flow, use TracingCond, TracingWhile, and TracingSwitch.
@jit_class
class ExampleAlgorithm(Algorithm):
    def __init__(self, pop_size, ...):
        super().__init__()
        self.pop = torch.rand(pop_size, pop_size)

    def strategy_1(self):  # One update strategy
        new_pop = self.pop * self.pop
        self.pop = new_pop

    def strategy_2(self):  # Another update strategy
        new_pop = self.pop + self.pop
        self.pop = new_pop

    def step(self):
        control_number = torch.rand()
        if control_number < 0.5:  # Conditional control
            self.strategy_1()
        else:
            self.strategy_2()

    @trace_impl(step)  # Rewrite step function for vmap support
    def trace_step_without_operations_to_self(self):
        pop = torch.rand(self.pop_size, self.dim, dtype=self.lb.dtype, device=self.lb.device)
        pop = pop * (self.ub - self.lb)[None, :] + self.lb[None, :]
        pop = pop * self.hp[0]
        control_number = torch.rand()
        cond = control_number < 0.5
        branches = (self.strategy_1, self.strategy_2)
        state, names = self.prepare_control_flow(*branches)  # Utilize state to track self.pop
        _if_else_ = TracingCond(*branches)
        state = _if_else_.cond(state, cond, pop)
        self.after_control_flow(state, *names)

    @trace_impl(step)
    def trace_step_with_operations_to_self(self):
        pop = torch.rand(self.pop_size, self.dim, dtype=self.lb.dtype, device=self.lb.device)
        pop = pop * (self.ub - self.lb)[None, :] + self.lb[None, :]
        pop = pop * self.hp[0]
        control_number = torch.rand()
        cond = control_number < 0.5
        _if_else_ = TracingCond(lambda p: p * p, lambda p: p + p)  # No need to track self.pop
        pop = _if_else_.cond(cond, pop)
        self.pop = pop
Avoid In-Place Operations on self: Vectorized mapping of in-place operations on self is not well-defined and cannot be compiled. Even if it compiles successfully, you can still silently get incorrect results.
Using use_state#
use_state transforms a given stateful function (which performs in-place alterations on nn.Modules) into a pure-functional version that receives an additional state parameter (of type Dict[str, torch.Tensor]) and returns the altered state.
The input func is the stateful function to be transformed or its generator function, and is_generator specifies whether func is a function or a function generator (e.g., a lambda that returns the stateful function). It defaults to True.
Here is a simple example:
import torch
import torch.nn as nn

# The helpers below are provided by EvoX's core module (adjust the import if needed)
from evox.core import ModuleBase, jit_class, jit, trace_impl, use_state

@jit_class
class Example(ModuleBase):
    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold
        self.sub_mod = nn.Module()
        self.sub_mod.buf = nn.Buffer(torch.zeros(()))

    def h(self, q: torch.Tensor) -> torch.Tensor:
        if q.flatten()[0] > self.threshold:
            x = torch.sin(q)
        else:
            x = torch.tan(q)
        x += self.g(x).abs()
        x *= x.shape[1]
        self.sub_mod.buf = x.sum()
        return x

    @trace_impl(h)
    def th(self, q: torch.Tensor) -> torch.Tensor:
        # A trace-friendly rewrite of `h` that avoids Python control flow
        x = torch.where(q.flatten()[0] > self.threshold, torch.sin(q), torch.tan(q))
        x += self.g(x).abs()
        x *= x.shape[1]
        self.sub_mod.buf = x.sum()
        return x

    def g(self, p: torch.Tensor) -> torch.Tensor:
        x = torch.cos(p)
        return x * p.shape[0]

t = Example()
fn = use_state(lambda: t.h, is_generator=True)
jit_fn = jit(fn, trace=True, lazy=True)
results = jit_fn(fn.init_state(), torch.rand(10, 1))
print(results)  # ({"self.sub_mod.buf": torch.Tensor(5.6)}, torch.Tensor([[0.56], ...]))
# IN-PLACE update of all relevant variables using the given state
fn.set_state(results[0])
Using core._vmap_fix#
The module _vmap_fix provides useful functions. It is imported automatically, and it enables torch.vmap to be correctly traced by torch.jit.trace, while resolving issues such as random number handling that could not be properly traced during the vmap process. It also provides the debug_print function, which allows dynamic printing of Tensor values during both vmap and tracing.
Detailed information can be found in the _vmap_fix documentation.
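A purely illustrative sketch of debug_print; the exact call signature is an assumption here, so please consult the _vmap_fix documentation before relying on it:

import torch
from evox.core._vmap_fix import debug_print  # import path assumed from the module name above

def evaluate(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical usage: print the per-batch values of `x` even inside vmap/tracing,
    # where a plain Python print would only show wrapped or symbolic tensors
    debug_print("x = {}", x)
    return x.sum()

torch.vmap(evaluate)(torch.rand(3, 2))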