evox.core.module#

Module Contents#

Classes#

ModuleBase

The base module for all algorithms and problems in the library.

_WrapClassBase

UseStateFunc

Functions#

_if_none

_is_magic

Parameter

Wraps a value as a parameter with requires_grad=False by default.

Mutable

Wraps a value as a mutable tensor.

assign_load_state_dict

Copy parameters and buffers from state_dict into this module and its descendants.

use_state_context

A context manager to set the value of using_state temporarily.

trace_caching_state_context

A context manager to set the value of trace_caching_state temporarily.

is_using_state

Get the current value of using_state.

is_trace_caching_state

Get the current value of trace_caching_state.

tracing_or_using_state

Check if we are currently JIT tracing (inside a torch.jit.trace), in a use_state_context, or in a trace_caching_state.

_get_vars

use_state

Transform the given stateful function (which in-place alters nn.Modules) to a pure-functional version that receives an additional state parameter (of type Dict[str, torch.Tensor]) and returns the altered state additionally.

trace_impl

A helper function used to annotate that the wrapped method shall be treated as a trace-JIT-time proxy of the given target method.

vmap_impl

A helper function used to annotate that the wrapped method shall be treated as a vmap-JIT-time proxy of the given target method.

jit_class

A helper function used to JIT script (torch.jit.script) or trace (torch.jit.trace_module) all member methods of class cls.

Data#

API#

evox.core.module._WRAPPING_MODULE_NAME#

wrapping_module

evox.core.module._if_none(a, b)[source]#
evox.core.module._is_magic(name: str)[source]#
evox.core.module.ParameterT#

‘TypeVar(…)’

evox.core.module.Parameter(value: evox.core.module.ParameterT, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = False) evox.core.module.ParameterT[source]#

Wraps a value as a parameter with requires_grad=False by default.

Parameters:
  • value – The parameter value.

  • dtype – The dtype of the parameter. Defaults to None.

  • device – The device of the parameter. Defaults to None.

  • requires_grad – Whether the parameter requires gradient. Defaults to False.

Returns:

The parameter.
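A minimal illustrative sketch (the variable name is hypothetical; the import path follows this page):

import torch
from evox.core.module import Parameter

# Wrap a hyperparameter value as a parameter tensor with requires_grad=False,
# so that assigning it to a module attribute registers it as a (fixed) parameter.
learning_rate = Parameter(torch.tensor(0.1))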

evox.core.module.Mutable(value: torch.Tensor, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) torch.Tensor[source]#

Wraps a value as a mutable tensor.

Parameters:
  • value – The value to be wrapped.

  • dtype – The dtype of the tensor. Defaults to None.

  • device – The device of the tensor. Defaults to None.

Returns:

The wrapped tensor.
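A minimal illustrative sketch (the variable name is hypothetical; the import path follows this page):

import torch
from evox.core.module import Mutable

# Wrap a tensor as a mutable tensor; when assigned to a ModuleBase attribute
# (e.g. in setup), it is treated as mutable state that may be modified in-place,
# equivalent to registering it via add_mutable (see ModuleBase below).
best_fitness = Mutable(torch.full((), float("inf")))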

evox.core.module.assign_load_state_dict(self: torch.nn.Module, state_dict: Mapping[str, torch.Tensor])[source]#

Copy parameters and buffers from state_dict into this module and its descendants.

This method is used to mimic the behavior of ModuleBase.load_state_dict so that a regular nn.Module can be used with vmap.

Usage:

import types
# ...
model = ... # define your model
model.load_state_dict = types.MethodType(assign_load_state_dict, model)
vmap_forward = vmap(use_state(model.forward))
jit_forward = jit(vmap_forward, trace=True, example_inputs=(vmap_forward.init_state(), ...)) # JIT trace forward pass of the model

class evox.core.module.ModuleBase(*args, **kwargs)[source]#

Bases: torch.nn.Module

The base module for all algorithms and problems in the library.

Notice

  1. This module is an object-oriented one that can contain mutable values.

  2. A functional programming model is supported via self.state_dict(...) and self.load_state_dict(...).

  3. Initialization of non-static members should be written in an overridden setup method (or any other member method) rather than in __init__.

  4. As a rule of thumb, predefined submodule(s) which will be ADDED to this module and accessed later in member method(s) should be treated as “non-static members”, while any other member(s) should be treated as “static members”.

Usage

  1. Static methods to be JIT compiled shall be defined as is, e.g.,

@jit
def func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    pass
  2. If a class member function with Python dynamic control flow (e.g. if) is to be JIT compiled, a separate static method decorated with jit(..., trace=False) or torch.jit.script_if_tracing shall be used:

class ExampleModule(ModuleBase):
    def setup(self, mut: torch.Tensor):
        self.add_mutable("mut", mut)
        # or
        self.mut = Mutable(mut)
        return self

    @partial(jit, trace=False)
    def static_func(x: torch.Tensor, threshold: float) -> torch.Tensor:
        if x.flatten()[0] > threshold:
            return torch.sin(x)
        else:
            return torch.tan(x)

    @jit
    def jit_func(self, p: torch.Tensor) -> torch.Tensor:
        x = ExampleModule.static_func(p, self.threshold)
        ...
  3. ModuleBase is usually used with jit_class to automatically JIT all non-magic member methods:

@jit_class
class ExampleModule(ModuleBase):
    # This function will be automatically JIT
    def func1(self, x: torch.Tensor) -> torch.Tensor:
        pass

    # Use `torch.jit.ignore` to disable JIT and leave this function as Python callback
    @torch.jit.ignore
    def func2(self, x: torch.Tensor) -> torch.Tensor:
        # you can implement pure Python logic here
        pass

    # JIT functions can invoke other JIT functions as well as non-JIT functions
    def func3(self, x: torch.Tensor) -> torch.Tensor:
        y = self.func1(x)
        z = self.func2(x)
        pass

Initialization

Initialize the ModuleBase.

Parameters:
  • *args – Variable length argument list, passed to the parent class initializer.

  • **kwargs – Arbitrary keyword arguments, passed to the parent class initializer.

Attributes: static_names (list): A list to store static member names.

eval()[source]#
setup(*args, **kwargs)[source]#

Set up the module. Module initialization lines should be written in an overridden setup method rather than in __init__.

Returns:

This module.

Notice

Static initialization can still be written in __init__, while mutable initialization cannot. Therefore, setup can be called multiple times to perform multiple initializations. A minimal sketch of this convention is shown below.
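A minimal illustrative sketch (class and attribute names are hypothetical):

import torch
from evox.core.module import ModuleBase, Mutable

class MyAlgorithm(ModuleBase):
    def __init__(self, pop_size: int):
        super().__init__()
        # static member: a plain Python value, initialized in __init__
        self.pop_size = pop_size

    def setup(self, dim: int):
        # mutable (non-static) member: initialized in setup, not in __init__
        self.pop = Mutable(torch.zeros(self.pop_size, dim))
        return self

algo = MyAlgorithm(100).setup(dim=10)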

prepare_control_flow(*target_functions: Callable, keep_vars: bool = True) Tuple[Dict[str, torch.Tensor], Tuple[List[str], List[str]]][source]#

Prepares the control flow state of the module by collecting and merging the state and non-local variables from the specified target functions.

This function is used alongside after_control_flow() to enable your control flow operations (utils.control_flow.*) to deal with side effects correctly. If the control flow operations have NO side effects, you can safely ignore this function and after_control_flow().

Parameters:
  • target_functions – Functions whose non-local variables are to be collected.

  • keep_vars – See torch.nn.Module.state_dict(..., keep_vars). Defaults to True.

Returns:

A tuple containing the merged state dictionary, a list of state keys, and a list of non-local variable names.

Raises:

AssertionError – If not all target functions are local functions, global functions, or member functions of this class

Warning

The non-local variables collected here can ONLY be used as read-only ones. In-place modifications to these variables may not raise any error and silently produce incorrect results.

Usage

@jit_class
class ExampleModule(ModuleBase):
    # define the normal `__init__` and `test` functions, etc.

    @trace_impl(test)
    def trace_test(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        self.q = self.q + 1
        local_q = self.q * 2

        def false_branch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            # nonlocal local_q ## These two lines may silently produce incorrect results
            # local_q *= 1.5
            return x * y * local_q # However, using read-only nonlocals is allowed

        state, keys = self.prepare_control_flow(self.true_branch, false_branch)
        if not hasattr(self, "_switch_"):
            self._switch_ = TracingSwitch(self.true_branch, false_branch)
        state, ret = self._switch_.switch(state, (x.flatten()[0] > 0).to(dtype=torch.int), x, y)
        self.after_control_flow(state, *keys)
        return ret

after_control_flow(state: Dict[str, torch.Tensor], state_keys: List[str], nonlocal_keys: List[str]) Dict[str, torch.Tensor][source]#

Restores the module state to the one before prepare_control_flow from the given state and returns the non-local variables collected in prepare_control_flow.

This function is used alongside prepare_control_flow() to enable your control flow operations (utils.control_flow.*) to deal with side effects correctly. If the control flow operations have NO side effects, you can safely ignore this function and prepare_control_flow().

Parameters:
  • state – The state dictionary to restore the module state from.

  • state_keys – The keys of the state dictionary that represent the module state.

  • nonlocal_keys – The keys of the state dictionary that represent the non-local variables.

Returns:

The non-local variables dictionary collected in prepare_control_flow.

Usage

See prepare_control_flow().

load_state_dict(state_dict: Mapping[str, torch.Tensor], copy: bool = False, **kwargs)[source]#

Copy parameters and buffers from state_dict into this module and its descendants. Overwrites torch.nn.Module.load_state_dict.

Parameters:
  • state_dict – A dict containing parameters and buffers used to update this module. See torch.nn.Module.load_state_dict.

  • copy – If True, use the original torch.nn.Module.load_state_dict to copy the state_dict into the current state; if False, use this implementation, which directly assigns the tensors in state_dict to this module. Defaults to False.

  • **kwargs – The original arguments of torch.nn.Module.load_state_dict. Ignored if copy=False.

Returns:

If copy=True, returns the return of torch.nn.Module.load_state_dict; otherwise, no return.
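A minimal sketch of a state round-trip (the Counter module here is hypothetical, defined only for illustration):

import torch
from evox.core.module import ModuleBase, Mutable

class Counter(ModuleBase):
    def setup(self):
        self.count = Mutable(torch.zeros(()))
        return self

m = Counter().setup()
snapshot = {k: v.clone() for k, v in m.state_dict().items()}  # clone so the snapshot owns its storage
m.count.add_(1)                                               # mutate the module state in-place
m.load_state_dict(snapshot)                                   # assign the snapshot back (copy=False by default)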

add_mutable(name: str, value: Union[torch.Tensor | torch.nn.Module, Sequence[torch.Tensor | torch.nn.Module], Dict[str, torch.Tensor | torch.nn.Module]]) None[source]#

Define a mutable value in this module that can be accessed via self.[name] and modified in-place.

Parameters:
  • name – The mutable value’s name.

  • value – The mutable value; it can be a single tensor (or module), or a tuple, list, or dictionary of tensors (or modules).

Raises:
  • NotImplementedError – If the mutable value’s type is not supported yet.

  • AssertionError – If the name is invalid.
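A minimal usage sketch (class and attribute names are hypothetical):

import torch
from evox.core.module import ModuleBase

class Pop(ModuleBase):
    def setup(self, pop_size: int, dim: int):
        # register a mutable tensor, later accessible as self.pop
        self.add_mutable("pop", torch.zeros(pop_size, dim))
        return self

m = Pop().setup(8, 2)
m.pop[0] = torch.ones(2)   # in-place modification of the mutable value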

to(*args, **kwargs) evox.core.module.ModuleBase[source]#
__getattribute__(name)[source]#
__getattr_inner__(name)[source]#
__delattr__(name)[source]#
__delattr_inner__(name)[source]#
__setattr__(name, value)[source]#
__setattr_inner__(name, value)[source]#
__getitem__(key: Union[int, slice, str]) Union[torch.Tensor, List[torch.Tensor]][source]#

Get the mutable value(s) stored in this list-like module.

Parameters:

key – The key used to index mutable value(s).

Raises:
  • IndexError – If key is out of range.

  • TypeError – If key is of wrong type.

Returns:

The indexed mutable value(s).

__setitem__(value: Union[torch.Tensor, List[torch.Tensor]], key: Union[slice, int]) None[source]#

Set the mutable value(s) stored in this list-like module.

Parameters:
  • value – The new mutable value(s).

  • key – The key used to index mutable value(s).

iter() Tuple[torch.Tensor][source]#
__sync_with__(jit_module: torch.jit.ScriptModule | None)[source]#
evox.core.module._using_state: contextvars.ContextVar[bool]#

‘ContextVar(…)’

evox.core.module._trace_caching_state: contextvars.ContextVar[bool]#

‘ContextVar(…)’

evox.core.module.use_state_context(new_use_state: bool = True)[source]#

A context manager to set the value of using_state temporarily.

When entering the context, the value of using_state is set to new_use_state and a token is obtained. When exiting the context, the value of using_state is reset to its previous value.

Parameters:

new_use_state – The new value of using_state. Defaults to True.

Examples:

>>> with use_state_context(True):
...     assert is_using_state()
>>> assert not is_using_state()

evox.core.module.trace_caching_state_context(new_trace_caching_state: bool = True)[source]#

A context manager to set the value of trace_caching_state temporarily.

When entering the context, the value of trace_caching_state is set to new_trace_caching_state and a token is obtained. When exiting the context, the value of trace_caching_state is reset to its previous value.

Parameters:

new_trace_caching_state – The new value of trace_caching_state. Defaults to True.

Examples:

>>> with trace_caching_state_context(True):
...     assert is_trace_caching_state()
>>> assert not is_trace_caching_state()

evox.core.module.is_using_state() bool[source]#

Get the current value of using_state.

Returns:

The current value of using_state.

evox.core.module.is_trace_caching_state() bool[source]#

Get the current value of trace_caching_state.

Returns:

The current value of trace_caching_state.

evox.core.module.tracing_or_using_state()[source]#

Check if we are currently JIT tracing (inside a torch.jit.trace), in a use_state_context, or in a trace_caching_state.

Returns:

True if either condition is true, False otherwise.

evox.core.module._SUBMODULE_PREFIX#

‘_submodule’

class evox.core.module._WrapClassBase(inner: evox.core.module.ModuleBase)[source]#

Initialization

__str__() str[source]#
__repr__() str[source]#
__hash__() int[source]#
__format__(format_spec: str) str[source]#
__getitem__(key)[source]#
__setitem__(value, key)[source]#
__setattr__(name, value)[source]#
__delattr__(name)[source]#
__sync__()[source]#
evox.core.module._USE_STATE_NAME#

use_state

evox.core.module._STATE_ARG_NAME#

‘state’

class evox.core.module.UseStateFunc[source]#

Bases: typing.Protocol

is_empty_state: bool#

None

init_state(clone: bool = True) Dict[str, torch.Tensor][source]#

Get the cloned state of the closures of the function when it is wrapped by use_state.

Parameters:

clone – Whether to clone the original state or not. Defaults to True.

Returns:

The cloned state of the closures.

set_state(state: Optional[Dict[str, torch.Tensor]] = None) None[source]#

Set the closures of the function to the given state.

Parameters:

state – The new state to set. If state=None, the state is reset to the original state captured when the function was wrapped by use_state. Defaults to None.

__call__(state: Dict[str, torch.Tensor], *args, **kwargs) Dict[str, torch.Tensor] | Tuple[Dict[str, torch.Tensor], Any][source]#
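A minimal sketch of how these members fit together (fn is assumed to be the result of use_state, and x is a hypothetical input):

state = fn.init_state()           # Dict[str, torch.Tensor]; example input for the state argument
new_state, output = fn(state, x)  # pure-functional call: state in, altered state (plus result) out
fn.set_state(new_state)           # write the altered state back into the wrapped module(s)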
evox.core.module._EMPTY_NAME#

empty

evox.core.module._get_vars(func: Callable, *exclude, is_generator: bool = True)[source]#
evox.core.module.use_state(func: Callable[[], Callable] | Callable, is_generator: bool = True) evox.core.module.UseStateFunc[source]#

Transform the given stateful function (which in-place alters nn.Modules) to a pure-functional version that receives an additional state parameter (of type Dict[str, torch.Tensor]) and returns the altered state additionally.

Parameters:
  • func – The stateful function to be transformed or its generator function.

  • is_generator – Whether func is a function or a function generator (e.g. a lambda that returns the stateful function). Defaults to True.

Returns:

The transformed pure-functional version of func. It contains an init_state() -> state attribute that returns a copy of the current state that func uses, which can be used as an example input for the additional state parameter. It also contains a set_state(state) attribute that sets the global state to the given one (not JIT-compatible, of course).

Notice

Since PyTorch cannot JIT or vectorized-map a function with an empty dictionary, list, or tuple as its input, if the given function uses no state, this function transforms it to a function WITHOUT the additional state parameter (of type Dict[str, torch.Tensor]) that does NOT return the altered state additionally (in this case, is_empty_state is True).

Usage:

@jit_class
class Example(ModuleBase):

    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold
        self.sub_mod = nn.Module()
        self.sub_mod.buf = nn.Buffer(torch.zeros(()))

    def h(self, q: torch.Tensor) -> torch.Tensor:
        if q.flatten()[0] > self.threshold:
            x = torch.sin(q)
        else:
            x = torch.tan(q)
        x += self.g(x).abs()
        x *= x.shape[1]
        self.sub_mod.buf = x.sum()
        return x

    @trace_impl(h)
    def th(self, q: torch.Tensor) -> torch.Tensor:
        # trace-friendly replacement for the Python `if` in `h`
        x = torch.where(q.flatten()[0] > self.threshold, torch.sin(q), torch.tan(q))
        x += self.g(x).abs()
        x *= x.shape[1]
        self.sub_mod.buf = x.sum()
        return x

    def g(self, p: torch.Tensor) -> torch.Tensor:
        x = torch.cos(p)
        return x * p.shape[0]

t = Example()
fn = use_state(lambda: t.h, is_generator=True)
jit_fn = jit(fn, trace=True, lazy=True)
results = jit_fn(fn.init_state(), torch.rand(10, 1))
# print the results; the output may look like:
print(results) # ({"self.sub_mod.buf": torch.Tensor(5.6)}, torch.Tensor([[0.56], ...]))
# IN-PLACE update all relevant variables using the given state, which is the variable `t` here.
fn.set_state(results[0])

evox.core.module._TORCHSCRIPT_MODIFIER#

‘_torchscript_modifier’

evox.core.module._TRACE_WRAP_NAME#

trace_wrapped

evox.core.module.T#

‘TypeVar(…)’

evox.core.module.trace_impl(target: Callable)[source]#

A helper function used to annotate that the wrapped method shall be treated as a trace-JIT-time proxy of the given target method.

Can ONLY be used inside a jit_class for a member method.

Parameters:

target – The target method invoked when not tracing JIT.

Returns:

The wrapping function to annotate the member method.

Notice

  1. The target function and the annotated function MUST have the same input/output signatures (e.g. number of arguments and types); otherwise, the resulting behavior is UNDEFINED.

  2. If the annotated function is to be vmapped, it cannot contain any in-place operations on self since such operations are not well-defined and cannot be compiled.

Usage:

See use_state.
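For a quicker reference, a minimal hedged sketch of the same pattern (class and method names here are illustrative):

import torch
from evox.core.module import ModuleBase, jit_class, trace_impl

@jit_class
class Example(ModuleBase):
    def __init__(self, threshold: float = 0.5):
        super().__init__()
        self.threshold = threshold

    def h(self, q: torch.Tensor) -> torch.Tensor:
        # normal version with Python control flow
        if q.flatten()[0] > self.threshold:
            return torch.sin(q)
        else:
            return torch.tan(q)

    @trace_impl(h)
    def trace_h(self, q: torch.Tensor) -> torch.Tensor:
        # trace-JIT-time proxy: same signature, data-dependent branch expressed with torch.where
        return torch.where(q.flatten()[0] > self.threshold, torch.sin(q), torch.tan(q))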

evox.core.module._VMAP_WRAP_NAME#

vmap_wrapped

evox.core.module.vmap_impl(target: Callable)[source]#

A helper function used to annotate that the wrapped method shall be treated as a vmap-JIT-time proxy of the given target method.

Can ONLY be used inside a jit_class for a member method.

Parameters:

target – The target method invoked when not tracing JIT.

Returns:

The wrapping function to annotate the member method.

Notice

  1. The target function and the annotated function MUST have the same input/output signatures (e.g. number of arguments and types); otherwise, the resulting behavior is UNDEFINED.

  2. If the annotated function is to be vmapped, it cannot contain any in-place operations on self since such operations are not well-defined and cannot be compiled.

Usage:

See use_state.
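Analogously to trace_impl, a minimal hedged sketch (class and method names are illustrative):

import torch
from evox.core.module import ModuleBase, jit_class, vmap_impl

@jit_class
class Example(ModuleBase):
    def evaluate(self, x: torch.Tensor) -> torch.Tensor:
        # normal version; may contain Python control flow
        if x.flatten()[0] > 0:
            return torch.sin(x)
        else:
            return torch.tan(x)

    @vmap_impl(evaluate)
    def vmap_evaluate(self, x: torch.Tensor) -> torch.Tensor:
        # vmap-JIT-time proxy: same signature, no data-dependent Python branching
        # and no in-place operations on `self`
        return torch.where(x.flatten()[0] > 0, torch.sin(x), torch.tan(x))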

evox.core.module.ClassT#

‘TypeVar(…)’

evox.core.module._BASE_NAME#

‘base’

evox.core.module.jit_class(cls: evox.core.module.ClassT, trace: bool = False) evox.core.module.ClassT[source]#

A helper function used to JIT script (torch.jit.script) or trace (torch.jit.trace_module) all member methods of class cls.

Parameters:
  • cls – The original class whose member methods are to be lazy JIT.

  • trace – Whether to trace the module or to script the module. Defaults to False.

Returns:

The wrapped class.

Notice

  1. In many cases, it is not necessary to wrap your custom algorithms or problems with jit_class; the workflow(s) will do the trick for you.

  2. With trace=True, all the member functions are effectively modified to return self additionally since side-effects cannot be traced. If you want to preserve the side effects, please set trace=False and use the use_state function to wrap the member method into a pure-functional version.

  3. Similarly, all module-wide operations like self.to(...) can only return the unwrapped module, which may not be desired. Since most of them are in-place operations, a simple module.to(...) can be used instead of module = module.to(...).

Usage:

@jit_class
class Example(ModuleBase):
    # magic methods are ignored in JIT
    def __init__(self, threshold = 0.5):
        super().__init__()
        self.threshold = threshold

    # `torch.jit.export` is automatically added to this member method
    def h(self, q: torch.Tensor) -> torch.Tensor:
        if q.flatten()[0] > self.threshold:
            x = torch.sin(q)
        else:
            x = torch.tan(q)
        return x * x.shape[1]

exp = Example(0.75)
print(exp.h(torch.rand(10, 2)))
# roughly equivalent to
exp = torch.jit.trace_module(Example(0.75), {"h": (torch.rand(10, 2),)})
print(exp.h(torch.rand(10, 2)))