Optimizing Development for EvoX via PyTorch Advanced Techniques#

Basic Optimization Support for Functions in PyTorch#

PyTorch provides fundamental optimization support for functions, primarily through vectorizing map (vmap) operations and Just-In-Time (JIT) compilation. These techniques enable efficient batch processing and enhance execution performance, respectively. They are introduced in the following sections.

Batch Processing Support through Vectorizing Map in PyTorch#

Vectorizing map, implemented in PyTorch as torch.vmap, is a powerful tool that takes a callable function and returns a batched version of it. According to the specified strategy, this new function vectorizes the operations of the original one, facilitating efficient batch processing. In EvoX, for example, this feature plays a crucial role in hyperparameter optimization (HPO).

import torch


def dummy_evaluation(pop_x: torch.Tensor, y: torch.Tensor):
    return pop_x * y


batched_dummy_evaluation = torch.vmap(dummy_evaluation, (0, None))

population_size = 3
individual_vector_size = 9
pop_x = torch.arange(individual_vector_size).repeat(population_size, 1)
y = torch.arange(individual_vector_size)

batched_dummy_evaluation(pop_x, y)
tensor([[ 0,  1,  4,  9, 16, 25, 36, 49, 64],
        [ 0,  1,  4,  9, 16, 25, 36, 49, 64],
        [ 0,  1,  4,  9, 16, 25, 36, 49, 64]])

Just-In-Time (JIT) Support in PyTorch#

In PyTorch, torch.jit.trace and torch.jit.script provide two distinct types of JIT tools, supporting function performance optimization through tracing and scripting, respectively.

Based on the tracing strategy, the torch.jit.trace method offers higher parsing speed and broader compatibility, such as with torch.vmap operations. Although it provides excellent support for simple functions, it is not suitable for complex tasks involving dynamic if-else branches and loop control flows.

import functools


@functools.partial(torch.vmap, in_dims=(0, None))
def vmap_sample_func(x: torch.Tensor, y: torch.Tensor):
    return x.sum() + y

In the example below, the traced vmap function successfully returns the correct code representation:

traced_vmap_func = torch.jit.trace(vmap_sample_func, example_inputs=(pop_x, y))
print(traced_vmap_func.code)
def vmap_sample_func(x: Tensor,
    y: Tensor) -> Tensor:
  _0 = torch.add(torch.view(torch.sum(x, [1]), [3, 1]), y)
  return _0

However, dynamic Python control flow cannot be traced correctly, and a warning will be raised:

def dynamic_control_flow(pop_x: torch.Tensor, y: torch.Tensor):
    if y.flatten()[0] > 0:
        return pop_x + y[None, :]
    else:
        return pop_x * y[None, :]

traced_dynamic_control_flow_func = torch.jit.trace(dynamic_control_flow, example_inputs=(pop_x, y))
print(traced_dynamic_control_flow_func.code)
def dynamic_control_flow(pop_x: Tensor,
    y: Tensor) -> Tensor:
  y0 = torch.flatten(y)
  _0 = torch.slice(torch.unsqueeze(y0, 0), 1, 0, 9223372036854775807)
  return torch.mul(pop_x, _0)
C:\Users\skk77\AppData\Local\Temp\ipykernel_19564\349648791.py:2: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if y.flatten()[0] > 0:

Alternatively, the torch.jit.script method, which adopts a scripting strategy, is better suited for complex tasks that involve dynamic control flows but has limited compatibility.

In this example, the same vmap_sample_func function, after being scripted, returns an incorrect code representation:

scripted_vmap_func = torch.jit.script(vmap_sample_func)
print(scripted_vmap_func.code)
def vmap_sample_func(x: Tensor,
    y: Tensor) -> Tensor:
  return torch.add(torch.sum(x), y)

Yet, it can correctly deal with complex dynamic Python control flow:

def dynamic_control_flow(pop_x: torch.Tensor, y: torch.Tensor):
    if y.flatten()[0] > 0:
        return pop_x + y[None, :]
    else:
        return pop_x * y[None, :]

script_dynamic_control_flow_func = torch.jit.script(dynamic_control_flow)
print(script_dynamic_control_flow_func.code)
def dynamic_control_flow(pop_x: Tensor,
    y: Tensor) -> Tensor:
  _0 = torch.gt(torch.select(torch.flatten(y), 0, 0), 0)
  if bool(_0):
    _2 = torch.slice(torch.unsqueeze(y, 0), 1)
    _1 = torch.add(pop_x, _2)
  else:
    _3 = torch.slice(torch.unsqueeze(y, 0), 1)
    _1 = torch.mul(pop_x, _3)
  return _1

Note

torch.jit.script relies on type hints to work properly. Any unannotated input argument is treated as a torch.Tensor by default, while you can annotate input arguments with Python types to make torch.jit.script work as intended.
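For instance, here is a minimal sketch (not taken from EvoX) of annotating a Python-typed argument so that torch.jit.script compiles the loop as intended:

import torch

# Annotating `n` as a Python int makes torch.jit.script treat it as an integer
# instead of the default torch.Tensor, so the range-based loop compiles correctly.
@torch.jit.script
def repeat_add(x: torch.Tensor, n: int) -> torch.Tensor:
    total = torch.zeros_like(x)
    for i in range(n):
        total = total + x
    return total

print(repeat_add(torch.ones(3), 4))  # tensor([4., 4., 4.])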

Combined Usage of JIT and Vectorizing Map in PyTorch#

Based on the introductions above, when torch.jit.trace and torch.jit.script are used in combination with torch.vmap, coordination is required due to compatibility considerations.

The figure below illustrates the relationship between torch.jit.script, torch.jit.trace, and torch.vmap, highlighting their mutual invocation paths: if module A invokes module B in the figure, it means that B can be called inside A.

[Figure: JIT introduction - invocation paths between torch.jit.script, torch.jit.trace, and torch.vmap]

For detailed usage of JIT and vectorizing map on PyTorch, please refer to the official PyTorch documentation for TorchScript and torch.vmap.
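As a quick illustration of one allowed invocation path (a minimal sketch, not from EvoX): torch.jit.trace can wrap a torch.vmap-ed function, while torch.jit.script cannot, as the scripted vmap_sample_func above shows.

import torch

def inner_sum(x: torch.Tensor) -> torch.Tensor:
    return x.sum()

# Tracing a torch.vmap-ed function works and yields a correct batched graph.
traced_batched_sum = torch.jit.trace(torch.vmap(inner_sum), example_inputs=(torch.rand(3, 4),))
print(traced_batched_sum(torch.rand(3, 4)))  # one sum per row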

Specific Optimization Support in EvoX#

Within EvoX, most functions are defined inside classes, particularly subclasses of ModuleBase. To provide more comprehensive optimization support, EvoX offers specific enhancements.

Applying JIT to Subclasses of ModuleBase#

For a better understanding of this part, we first need to explain three important functions in EvoX: jit_class, vmap, and jit.

jit_class Function#

jit_class is a helper function used to Just-In-Time (JIT) script (via torch.jit.script) or trace (via torch.jit.trace_module) all member methods of the input class.

jit_class has two parameters:

  • cls: the original class whose member methods are to be lazily JIT compiled.

  • trace: whether to trace the module or to script the module. Defaults to False.

Note

  1. In many cases, it is not necessary to wrap your custom algorithms or problems with jit_class; the workflow(s) will do the trick for you.

  2. With trace=True, all the member functions are effectively modified to additionally return self, since side effects cannot be traced. If you want to preserve the side effects, please set trace=False and use the use_state function to wrap the member method into a pure-functional version (the use_state function is explained in a later section).

  3. Similarly, module-wide operations like self.to(...) can only return the unwrapped module, which may not be desired. Since most of them are in-place operations, a simple module.to(...) can be used instead of module = module.to(...).
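A minimal usage sketch (illustrative only; the MyExample class below is hypothetical): jit_class can be applied as a decorator or called as a plain function, and trace defaults to False.

import torch
from evox.core import ModuleBase, jit_class

class MyExample(ModuleBase):
    def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y

ScriptedExample = jit_class(MyExample)             # member methods are lazily scripted (default)
TracedExample = jit_class(MyExample, trace=True)   # member methods are traced instead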

vmap Function#

The vmap function maps the given function over its inputs and returns the vectorized (batched) version of it. It is based on torch.vmap with many improvements; see torch.vmap for more information.
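A minimal sketch of its use, assuming EvoX's vmap accepts the in_dims/out_dims keywords shown later on this page (mirroring torch.vmap):

import torch
from evox.core import vmap

def row_sum(x: torch.Tensor) -> torch.Tensor:
    return x.sum()

# Map the scalar-output function over the first dimension of its input.
batched_row_sum = vmap(row_sum, in_dims=0, out_dims=0)
print(batched_row_sum(torch.arange(9.0).reshape(3, 3)))  # tensor([ 3., 12., 21.])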

jit Function#

jit compiles the given func via torch.jit.trace (trace=True) or torch.jit.script (trace=False).

This function wrapper effectively deals with nested JIT and vector map (vmap) expressions like jit(func1) -> vmap -> jit(func2), preventing possible errors.

Note

  1. With trace=True, torch.jit.trace cannot use the SAME example input argument for DIFFERENT parameters of a function; e.g., you cannot pass tensor_a, tensor_a as example inputs to the torch.jit.trace-d version of f(x: torch.Tensor, y: torch.Tensor).

  2. With trace=False, torch.jit.script cannot contain vmap expressions directly; please wrap them with jit(..., trace=True) or torch.jit.trace.
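Putting these pieces together, here is a minimal sketch (assuming the keyword forms used later on this page, trace=True and lazy=True) that wraps a torch.vmap expression with jit; note that x and y receive DIFFERENT example tensors at the first call:

import torch
from evox.core import jit

def sum_and_shift(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x.sum() + y

# Lazy tracing defers compilation until the first call, so no example inputs
# are required up front.
lazy_traced_func = jit(torch.vmap(sum_and_shift, in_dims=(0, None)), trace=True, lazy=True)
print(lazy_traced_func(torch.rand(3, 4), torch.rand(4)))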

In Working with Module in EvoX, we briefly introduced some rules about the methods inside a subclass of ModuleBase. Now that jit_class, vmap, and jit have been explained, we will introduce more rules and provide some specific hints.

Definition of Static Methods Inside the Subclass#

Inside the subclass, static methods to be JIT compiled shall be defined as follows:

# Import PyTorch
import torch

# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit

# Define a module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
    
    ...
    
    # One example of the static method defined in a Module 
    @jit
    def func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y
    
    ...
    

Definition of Non-static Methods Inside the Subclass#

If a method containing Python dynamic control flow such as if needs to be JIT compiled, a separate static method wrapped with jit(..., trace=False) or torch.jit.script_if_tracing shall be used:

# Import PyTorch
import torch

# Import partial for passing keyword arguments to the jit decorator
from functools import partial

# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit

# Define a module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
    
    ...
    
    # An example of a method with Python dynamic control flow like "if"
    # The method uses jit(..., trace=False)
    @partial(jit, trace=False)
    def static_func(x: torch.Tensor, threshold: float) -> torch.Tensor:
        if x.flatten()[0] > threshold:
            return torch.sin(x)
        else:
            return torch.tan(x)
        
    # The method to be JIT   
    @jit
    def jit_func(self, p: torch.Tensor) -> torch.Tensor:
        return ExampleModule.static_func(p, self.threshold)
    
    ...
    

Note

Dynamic control flow in Python refers to control structures that change dynamically based on conditions at runtime. if...elif...else conditional statements, for loops, and while loops are all dynamic control flow. If you have to use them when defining non-static methods inside a subclass of ModuleBase, please follow the rule above.
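As a minimal sketch of the torch.jit.script_if_tracing alternative mentioned above (not taken from EvoX), the decorated helper is scripted only when it is reached during torch.jit.trace, so its if/else branch survives tracing:

import torch

# The helper stays a regular Python function until it is hit during tracing,
# at which point it is scripted, preserving the dynamic branch.
@torch.jit.script_if_tracing
def select_op(x: torch.Tensor, threshold: float) -> torch.Tensor:
    if x.flatten()[0] > threshold:
        return torch.sin(x)
    else:
        return torch.tan(x)

def wrapper(x: torch.Tensor) -> torch.Tensor:
    return select_op(x, 0.5)

traced_wrapper = torch.jit.trace(wrapper, example_inputs=(torch.rand(2, 3),))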

Invocation of External Methods Inside the Subclass#

Inside the subclass, external JIT methods can be invoked by the class methods to be JIT compiled:

# Import PyTorch
import torch

# Import partial for passing keyword arguments to the jit decorator
from functools import partial

# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit

# One example of the JIT method defined outside the module
@jit
def external_func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y

# Define a module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
    
    ...

    # The internal method using jit(..., trace=False)
    @partial(jit, trace=False)
    def static_func(x: torch.Tensor, threshold: float) -> torch.Tensor:
        if x.flatten()[0] > threshold:
            return torch.sin(x)
        else:
            return torch.tan(x)

    # The internal method to be JIT, which invokes the external JIT method
    @jit
    def jit_func(self, p: torch.Tensor) -> torch.Tensor:
        return external_func(p, p)
    
    ...
    

Automatic JIT for Subclasses Used with jit_class#

ModuleBase and its subclasses are usually used with jit_class to automatically JIT all non-magic member methods:

# Import PyTorch
import torch

# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit_class

@jit_class
class ExampleModule(ModuleBase):
    
    ...
    
    # This function will be automatically JIT compiled
    def automatically_JIT_func1(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1.0

    # Use `torch.jit.ignore` to disable JIT and leave this function as a Python callback
    @torch.jit.ignore
    def no_JIT_func2(self, x: torch.Tensor) -> torch.Tensor:
        # you can implement pure Python logic here
        return x * 2.0

    # JIT functions can invoke other JIT functions as well as non-JIT functions
    def automatically_JIT_func3(self, x: torch.Tensor) -> torch.Tensor:
        y = self.automatically_JIT_func1(x)
        z = self.no_JIT_func2(x)
        return y + z
    
    ...
    

Invocation of External Vmap-wrapped Methods Inside the Subclass#

Inside the subclass, external vmap-wrapped methods can be invoked by the class methods to be JIT compiled:

# Import PyTorch
import torch

# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit, vmap


# The method to be vmap-wrapped
def external_func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y.sum()

external_vmap_func = vmap(external_func, in_dims=1, out_dims=1)


# Define a module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
    
    ...    
    
    # The internal class method to be JIT   
    @jit
    def jit_func(self, p: torch.Tensor) -> torch.Tensor:
        return external_vmap_func(p, p)
    
    ...
    

Note

If method A invokes vmap-wrapped method B, then A and all methods that invoke A cannot be vmap-wrapped again.

Internal Vmap-wrapped Methods Inside the Subclass#

Inside the subclass, internal vmap-wrapped methods can be JIT compiled by using trace_impl:

# Import PyTorch
import torch

# Import the ModuleBase class from EvoX
from evox.core import ModuleBase, jit, vmap, trace_impl


# Define a module inherited from the ModuleBase class
class ExampleModule(ModuleBase):
    
    ...    
    
    # The internal vmap-wrapped class method to be JIT
    @jit
    def jit_vmap_func(self, p: torch.Tensor) -> torch.Tensor:

        # The original method
        # We cannot vmap it directly
        def func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y

        # The method to be vmap-wrapped
        # We need to use trace_impl to rewrite the original method
        @trace_impl(func)
        def trace_func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            # an illustrative traceable rewrite of the original method
            return x + y

        return vmap(func, in_dims=1, out_dims=1, trace=False)(p, p)
    
    ...
    

Note

If a class method uses trace_impl, it will only be available in trace mode. More details about trace_impl are given in the next part.

Using @trace_impl and @vmap_impl#

When designing a function or method, you may not always consider whether it is JIT-compatible. However, this property becomes crucial in specific scenarios, such as solving Hyperparameter Optimization (HPO) problems. For more details on deploying HPO with EvoX, refer to Efficient HPO with EvoX.

A typical characteristic of such problems is that only certain parts of the algorithm need modification—for instance, the step method of an algorithm. This allows you to avoid rewriting the entire algorithm. In such cases, you can use the @trace_impl or @vmap_impl decorator to rewrite the function as a trace-JIT-time or vmap-JIT-time proxy for the specified target method.

The decorators @trace_impl and @vmap_impl accept a single input parameter: the target method that is invoked when the code is not being trace-JIT or vmap-JIT compiled. These decorators are applicable only to member methods within a jit_class.

Since the annotated function serves as a rewritten version of the target function, it must maintain identical input/output signatures (e.g., number and types of arguments). Otherwise, the resulting behavior is undefined.

If the annotated function is intended for use with vmap, it must satisfy three additional constraints:

  1. No In-Place Operations on Attributes: The algorithm must not include methods that perform in-place operations on its attributes.

class ExampleAlgorithm(Algorithm):
    def __init__(self, ...):
        self.pop = torch.rand(10, 10)  # Attribute of the algorithm

    def step_in_place(self):  # Method with in-place operations
        new_pop = torch.rand(10, 10)
        self.pop.copy_(new_pop)

    def step_out_of_place(self):  # Method without in-place operations
        new_pop = torch.rand(10, 10)
        self.pop = new_pop
  2. Avoid Python Control Flow: The code logic must not rely on Python control flow structures. To handle Python control flow, use TracingCond, TracingWhile, and TracingSwitch.

@jit_class
class ExampleAlgorithm(Algorithm):
    def __init__(self, pop_size, ...):
        super().__init__()
        self.pop = torch.rand(pop_size, pop_size)

    def strategy_1(self):  # One update strategy
        new_pop = self.pop * self.pop
        self.pop = new_pop

    def strategy_2(self):  # Another update strategy
        new_pop = self.pop + self.pop
        self.pop = new_pop

    def step(self):
        control_number = torch.rand()
        if control_number < 0.5:  # Conditional control
            self.strategy_1()
        else:
            self.strategy_2()

    @trace_impl(step)  # Rewrite step function for vmap support
    def trace_step_without_operations_to_self(self):
        pop = torch.rand(self.pop_size, self.dim, dtype=self.lb.dtype, device=self.lb.device)
        pop = pop * (self.ub - self.lb)[None, :] + self.lb[None, :]
        pop = pop * self.hp[0]
        control_number = torch.rand()
        cond = control_number < 0.5
        branches = (self.strategy_1, self.strategy_2)
        state, names = self.prepare_control_flow(*branches)  # Utilize state to track self.pop
        _if_else_ = TracingCond(*branches)
        state = _if_else_.cond(state, cond, pop)
        self.after_control_flow(state, *names)

    @trace_impl(step)
    def trace_step_with_operations_to_self(self):
        pop = torch.rand(self.pop_size, self.dim, dtype=self.lb.dtype, device=self.lb.device)
        pop = pop * (self.ub - self.lb)[None, :] + self.lb[None, :]
        pop = pop * self.hp[0]
        control_number = torch.rand()
        cond = control_number < 0.5
        _if_else_ = TracingCond(lambda p: p * p, lambda p: p + p)  # No need to track self.pop
        pop = _if_else_.cond(cond, pop)
        self.pop = pop
  3. Avoid In-Place Operations on self: Vectorized map in-place operations on self are not well-defined and cannot be compiled. Even if the compilation succeeds, you can still silently get incorrect results.

Using use_state#

use_state transforms a given stateful function (which performs in-place alterations on nn.Modules) into a pure-functional version that receives an additional state parameter (of type Dict[str, torch.Tensor]) and returns the altered state.

The input func is the stateful function to be transformed or its generator function, and is_generator specifies whether func is a function or a function generator (e.g., a lambda that returns the stateful function). It defaults to True.

Here is a simple example:

import torch
import torch.nn as nn

from evox.core import ModuleBase, jit_class, jit, trace_impl, use_state

@jit_class
class Example(ModuleBase):
    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold
        self.sub_mod = nn.Module()
        self.sub_mod.buf = nn.Buffer(torch.zeros(()))

    def h(self, q: torch.Tensor) -> torch.Tensor:
        if q.flatten()[0] > self.threshold:
            x = torch.sin(q)
        else:
            x = torch.tan(q)
        x += self.g(x).abs()
        x *= x.shape[1]
        self.sub_mod.buf = x.sum()
        return x

    @trace_impl(h)
    def th(self, q: torch.Tensor) -> torch.Tensor:
        # dynamic if/else replaced by a traceable torch.where
        x = torch.where(q.flatten()[0] > self.threshold, torch.sin(q), torch.tan(q))
        x += self.g(x).abs()
        x *= x.shape[1]
        self.sub_mod.buf = x.sum()
        return x

    def g(self, p: torch.Tensor) -> torch.Tensor:
        x = torch.cos(p)
        return x * p.shape[0]

t = Example()
fn = use_state(lambda: t.h, is_generator=True)
jit_fn = jit(fn, trace=True, lazy=True)
results = jit_fn(fn.init_state(), torch.rand(10, 1))
print(results)  # ({"self.sub_mod.buf": torch.Tensor(5.6)}, torch.Tensor([[0.56], ...]))

# IN-PLACE update all relevant variables using the given state
fn.set_state(results[0])

Using core._vmap_fix#

The module _vmap_fix provides useful functions. It is imported automatically, and once imported, it enables torch.vmap to be correctly traced by torch.jit.trace, while resolving issues such as random number handling that could not otherwise be traced properly during the vmap process. It also provides the debug_print function, which allows dynamic printing of Tensor values during both vmap and tracing.

Detailed information can be found in the _vmap_fix documentation.