evox.core.module#
Module Contents#
Classes#
| Class | Description |
|---|---|
| `ModuleBase` | The base module for all algorithms and problems in the library. |
Functions#
| Function | Description |
|---|---|
| `Parameter` | Wraps a value as parameter with `requires_grad=False`. |
| `Mutable` | Wraps a value as a mutable tensor. |
| `assign_load_state_dict` | Copy parameters and buffers from state_dict into this module and its descendants. |
| `use_state_context` | A context manager to set the value of `using_state` temporarily. |
| `trace_caching_state_context` | A context manager to set the value of `trace_caching_state` temporarily. |
| `is_using_state` | Get the current state of `using_state`. |
| `is_trace_caching_state` | Get the current state of `trace_caching_state`. |
| `tracing_or_using_state` | Check if we are currently JIT tracing (inside a `torch.jit.trace`), in a `use_state_context`, or in a `trace_caching_state`. |
| `use_state` | Transform the given stateful function (which in-place alters `nn.Module`s) into a pure-functional version. |
| `trace_impl` | A helper function used to annotate that the wrapped method shall be treated as a trace-JIT-time proxy of the given `target` method. |
| `vmap_impl` | A helper function used to annotate that the wrapped method shall be treated as a vmap-JIT-time proxy of the given `target` method. |
| `jit_class` | A helper function used to JIT script (`torch.jit.script`) or trace (`torch.jit.trace_module`) all member methods of class `cls`. |
Data#
API#
- evox.core.module._WRAPPING_MODULE_NAME#
‘wrapping_module’
- evox.core.module.ParameterT#
‘TypeVar(…)’
- evox.core.module.Parameter(value: evox.core.module.ParameterT, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = False) → evox.core.module.ParameterT [source]#
Wraps a value as parameter with `requires_grad=False`.
- Parameters:
value – The parameter value.
dtype – The dtype of the parameter. Defaults to None.
device – The device of the parameter. Defaults to None.
requires_grad – Whether the parameter requires gradient. Defaults to False.
- Returns:
The parameter.
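Example (an illustrative sketch; `MyAlgorithm` and its `lr` attribute are not part of the library):
```python
import torch
from evox.core.module import ModuleBase, Parameter

class MyAlgorithm(ModuleBase):
    def __init__(self, lr: float = 0.1):
        super().__init__()
        # wrapped with requires_grad=False, so it is registered with the module
        # as a fixed hyper-parameter rather than a trainable weight
        self.lr = Parameter(lr, dtype=torch.float32)

algo = MyAlgorithm()
print(algo.state_dict())  # the wrapped value appears in the module's state
```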
- evox.core.module.Mutable(value: torch.Tensor, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) → torch.Tensor [source]#
Wraps a value as a mutable tensor.
- Parameters:
value – The value to be wrapped.
dtype – The dtype of the tensor. Defaults to None.
device – The device of the tensor. Defaults to None.
- Returns:
The wrapped tensor.
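Example (an illustrative sketch; the `MyAlgorithm` class and `population` attribute are not part of the library):
```python
import torch
from evox.core.module import ModuleBase, Mutable

class MyAlgorithm(ModuleBase):
    def setup(self, pop_size: int, dim: int):
        # a tensor registered as mutable state; member methods may later
        # overwrite it, e.g. self.population = new_population
        self.population = Mutable(torch.zeros(pop_size, dim))
        return self

algo = MyAlgorithm().setup(pop_size=8, dim=3)
```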
- evox.core.module.assign_load_state_dict(self: torch.nn.Module, state_dict: Mapping[str, torch.Tensor])[source]#
Copy parameters and buffers from state_dict into this module and its descendants.
This method is used to mimic the behavior of `ModuleBase.load_state_dict` so that a regular `nn.Module` can be used with `vmap`.
Usage:
```python
import types

# ...
model = ...  # define your model
model.load_state_dict = types.MethodType(assign_load_state_dict, model)
vmap_forward = vmap(use_state(model.forward))
# JIT trace the forward pass of the model
jit_forward = jit(vmap_forward, trace=True, example_inputs=(vmap_forward.init_state(), ...))
```
- class evox.core.module.ModuleBase(*args, **kwargs)[source]#
Bases:
torch.nn.Module
The base module for all algorithms and problems in the library.
Notice
This module is an object-oriented one that can contain mutable values.
A functional programming model is supported via `self.state_dict(...)` and `self.load_state_dict(...)`.
The initialization of non-static members is recommended to be written in the overridden `setup` method (or any other member method) rather than in `__init__`.
Basically, predefined submodule(s) which will be ADDED to this module and accessed later in member method(s) should be treated as "non-static members", while any other member(s) should be treated as "static members".
Usage
Static methods to be JIT compiled shall be defined as is, e.g.:
```python
@jit
def func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    pass
```
If a class member method containing Python dynamic control flow (such as `if`) is to be JIT compiled, a separate static method with `jit(..., trace=False)` or `torch.jit.script_if_tracing` shall be used:
```python
class ExampleModule(ModuleBase):
    def setup(self, mut: torch.Tensor):
        self.add_mutable("mut", mut)
        # or self.mut = Mutable(mut)
        return self

    @partial(jit, trace=False)
    def static_func(x: torch.Tensor, threshold: float) -> torch.Tensor:
        if x.flatten()[0] > threshold:
            return torch.sin(x)
        else:
            return torch.tan(x)

    @jit
    def jit_func(self, p: torch.Tensor) -> torch.Tensor:
        x = ExampleModule.static_func(p, self.threshold)
        ...
```
`ModuleBase` is usually used with `jit_class` to automatically JIT all non-magic member methods:
```python
@jit_class
class ExampleModule(ModuleBase):
    # This function will be automatically JIT compiled
    def func1(self, x: torch.Tensor) -> torch.Tensor:
        pass

    # Use `torch.jit.ignore` to disable JIT and leave this function as a Python callback
    @torch.jit.ignore
    def func2(self, x: torch.Tensor) -> torch.Tensor:
        # you can implement pure Python logic here
        pass

    # JIT functions can invoke other JIT functions as well as non-JIT functions
    def func3(self, x: torch.Tensor) -> torch.Tensor:
        y = self.func1(x)
        z = self.func2(x)
        pass
```
Initialization
Initialize the ModuleBase.
- Parameters:
*args – Variable length argument list, passed to the parent class initializer.
**kwargs – Arbitrary keyword arguments, passed to the parent class initializer.
Attributes: static_names (list): A list to store static member names.
- setup(*args, **kwargs)[source]#
Set up the module. Module initialization lines should be written in the overridden `setup` method rather than in `__init__`.
- Returns:
This module.
Notice
Static initialization can still be written in `__init__`, while mutable initialization cannot. Therefore, multiple calls of `setup` for multiple initializations are possible.
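Example (an illustrative sketch of the recommended split; the `Counter` class is not part of the library):
```python
import torch
from evox.core.module import ModuleBase, Mutable

class Counter(ModuleBase):
    def __init__(self, limit: int = 10):
        super().__init__()
        self.limit = limit  # static member: a plain Python value

    def setup(self):
        self.count = Mutable(torch.zeros(()))  # non-static member: mutable tensor state
        return self

counter = Counter().setup()
```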
- prepare_control_flow(*target_functions: Callable, keep_vars: bool = True) → Tuple[Dict[str, torch.Tensor], Tuple[List[str], List[str]]] [source]#
Prepares the control flow state of the module by collecting and merging the state and non-local variables from the specified target functions.
This function is used alongside `after_control_flow()` to enable your control flow operations (`utils.control_flow.*`) to deal with side effects correctly. If the control flow operations have NO side effects, you can safely ignore this function and `after_control_flow()`.
- Parameters:
target_functions – Functions whose non-local variables are to be collected.
keep_vars – See `torch.nn.Module.state_dict(..., keep_vars)`. Defaults to True.
- Returns:
A tuple containing the merged state dictionary, a list of state keys, and a list of non-local variable names.
- Raises:
AssertionError – If not all target functions are local functions, global functions, or member functions of this class.
Warning
The non-local variables collected here can ONLY be used as read-only ones. In-place modifications to these variables may not raise any error and silently produce incorrect results.
Usage
```python
@jit_class
class ExampleModule(ModuleBase):
    # define the normal `__init__` and `test` functions, etc.

    @trace_impl(test)
    def trace_test(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        self.q = self.q + 1
        local_q = self.q * 2

        def false_branch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            # nonlocal local_q  ## These two lines may silently produce incorrect results
            # local_q *= 1.5
            return x * y * local_q  # However, using read-only nonlocals is allowed

        state, keys = self.prepare_control_flow(self.true_branch, false_branch)
        if not hasattr(self, "_switch_"):
            self._switch_ = TracingSwitch(self.true_branch, false_branch)
        state, ret = self._switch_.switch(state, (x.flatten()[0] > 0).to(dtype=torch.int), x, y)
        self.after_control_flow(state, *keys)
        return ret
```
- after_control_flow(state: Dict[str, torch.Tensor], state_keys: List[str], nonlocal_keys: List[str]) → Dict[str, torch.Tensor] [source]#
Restores the module state to the one before `prepare_control_flow` from the given `state`, and returns the non-local variables collected in `prepare_control_flow`.
This function is used alongside `prepare_control_flow()` to enable your control flow operations (`utils.control_flow.*`) to deal with side effects correctly. If the control flow operations have NO side effects, you can safely ignore this function and `prepare_control_flow()`.
- Parameters:
state – The state dictionary to restore the module state from.
state_keys – The keys of the state dictionary that represent the module state.
nonlocal_keys – The keys of the state dictionary that represent the non-local variables.
- Returns:
The non-local variables dictionary collected in `prepare_control_flow`.
Usage
See `prepare_control_flow()`.
- load_state_dict(state_dict: Mapping[str, torch.Tensor], copy: bool = False, **kwargs)[source]#
Copy parameters and buffers from state_dict into this module and its descendants. Overrides `torch.nn.Module.load_state_dict`.
- Parameters:
state_dict – A dict containing parameters and buffers used to update this module. See `torch.nn.Module.load_state_dict`.
copy – Use the original `torch.nn.Module.load_state_dict` to copy the `state_dict` into the current state (`copy=True`), or use this implementation that assigns the values in the `state_dict` to this module (`copy=False`). Defaults to False.
**kwargs – The original arguments of `torch.nn.Module.load_state_dict`. Ignored if `copy=False`.
- Returns:
If `copy=True`, returns the return value of `torch.nn.Module.load_state_dict`; otherwise, returns nothing.
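Example (an illustrative sketch of the assignment behavior; the `Counter` module and the update are not part of the library):
```python
import torch
from evox.core.module import ModuleBase, Mutable

class Counter(ModuleBase):
    def setup(self):
        self.count = Mutable(torch.zeros(()))
        return self

algo = Counter().setup()
state = algo.state_dict()                          # snapshot the current tensors
new_state = {k: v + 1 for k, v in state.items()}   # e.g. tensors returned by a functional step
algo.load_state_dict(new_state, copy=False)        # assign the new tensors onto the module
```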
- add_mutable(name: str, value: Union[torch.Tensor | torch.nn.Module, Sequence[torch.Tensor | torch.nn.Module], Dict[str, torch.Tensor | torch.nn.Module]]) → None [source]#
Define a mutable value in this module that can be accessed via `self.[name]` and modified in-place.
- Parameters:
name – The mutable value’s name.
value – The mutable value; can be a tuple, list, or dictionary of `torch.Tensor`.
- Raises:
NotImplementedError – If the mutable value’s type is not supported yet.
AssertionError – If the `name` is invalid.
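Example (an illustrative sketch; the `MyAlgorithm` class and the `stats` entries are not part of the library):
```python
import torch
from evox.core.module import ModuleBase

class MyAlgorithm(ModuleBase):
    def setup(self, dim: int):
        # register a dictionary of tensors as mutable state under the name "stats"
        self.add_mutable(
            "stats",
            {"best": torch.full((dim,), float("inf")), "mean": torch.zeros(dim)},
        )
        return self

algo = MyAlgorithm().setup(dim=4)
```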
- to(*args, **kwargs) → evox.core.module.ModuleBase [source]#
- __getitem__(key: Union[int, slice, str]) → Union[torch.Tensor, List[torch.Tensor]] [source]#
Get the mutable value(s) stored in this list-like module.
- Parameters:
key – The key used to index mutable value(s).
- Raises:
IndexError – If `key` is out of range.
TypeError – If `key` is of the wrong type.
- Returns:
The indexed mutable value(s).
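Example (an illustrative sketch, under the assumption that a sequence registered via `add_mutable` is stored as such a list-like module whose elements can then be indexed):
```python
import torch
from evox.core.module import ModuleBase

class Archive(ModuleBase):
    def setup(self):
        # a sequence of tensors registered as mutable state; the resulting
        # container is assumed to behave as a list-like module
        self.add_mutable("history", [torch.zeros(3), torch.ones(3)])
        return self

archive = Archive().setup()
first = archive.history[0]  # dispatches to __getitem__ of the list-like module
```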
- evox.core.module._using_state: contextvars.ContextVar[bool]#
‘ContextVar(…)’
- evox.core.module._trace_caching_state: contextvars.ContextVar[bool]#
‘ContextVar(…)’
- evox.core.module.use_state_context(new_use_state: bool = True)[source]#
A context manager to set the value of `using_state` temporarily.
When entering the context, the value of `using_state` is set to `new_use_state` and a token is obtained. When exiting the context, the value of `using_state` is reset to its previous value.
- Parameters:
new_use_state – The new value of `using_state`. Defaults to True.
Examples:
```python
>>> with use_state_context(True):
...     assert is_using_state()
>>> assert not is_using_state()
```
- evox.core.module.trace_caching_state_context(new_trace_caching_state: bool = True)[source]#
A context manager to set the value of `trace_caching_state` temporarily.
When entering the context, the value of `trace_caching_state` is set to `new_trace_caching_state` and a token is obtained. When exiting the context, the value of `trace_caching_state` is reset to its previous value.
- Parameters:
new_trace_caching_state – The new value of `trace_caching_state`. Defaults to True.
Examples:
```python
>>> with trace_caching_state_context(True):
...     assert is_trace_caching_state()
>>> assert not is_trace_caching_state()
```
- evox.core.module.is_using_state() → bool [source]#
Get the current state of `using_state`.
- Returns:
The current state of `using_state`.
- evox.core.module.is_trace_caching_state() → bool [source]#
Get the current state of `trace_caching_state`.
- Returns:
The current state of `trace_caching_state`.
- evox.core.module.tracing_or_using_state()[source]#
Check if we are currently JIT tracing (inside a `torch.jit.trace`), in a `use_state_context`, or in a `trace_caching_state`.
- Returns:
True if any of these conditions holds, False otherwise.
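Example (an illustrative sketch of guarding eager-only logic; the helper below is not part of the library):
```python
import torch
from evox.core.module import tracing_or_using_state

def checked_normalize(x: torch.Tensor) -> torch.Tensor:
    # eager-only sanity check; skipped during torch.jit.trace,
    # a use_state_context, or a trace-caching state
    if not tracing_or_using_state():
        assert torch.isfinite(x).all(), "non-finite input"
    return x / x.norm().clamp(min=1e-12)
```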
- evox.core.module._SUBMODULE_PREFIX#
‘_submodule’
- class evox.core.module._WrapClassBase(inner: evox.core.module.ModuleBase)[source]#
Initialization
- evox.core.module._USE_STATE_NAME#
‘use_state’
- evox.core.module._STATE_ARG_NAME#
‘state’
- class evox.core.module.UseStateFunc[source]#
Bases:
typing.Protocol
- is_empty_state: bool#
None
- init_state(clone: bool = True) → Dict[str, torch.Tensor] [source]#
Get the cloned state of the closures of the function when it is wrapped by `use_state`.
- Parameters:
clone – Whether to clone the original state or not. Defaults to True.
- Returns:
The cloned state of the closures.
- evox.core.module._EMPTY_NAME#
‘empty’
- evox.core.module.use_state(func: Callable[[], Callable] | Callable, is_generator: bool = True) → evox.core.module.UseStateFunc [source]#
Transform the given stateful function (which in-place alters `nn.Module`s) into a pure-functional version that receives an additional `state` parameter (of type `Dict[str, torch.Tensor]`) and additionally returns the altered state.
- Parameters:
func – The stateful function to be transformed or its generator function.
is_generator – Whether `func` is a function or a function generator (e.g. a lambda that returns the stateful function). Defaults to `True`.
- Returns:
The transformed pure-functional version of `func`. It contains an `init_state() -> state` attribute that returns a copy of the current state that `func` uses, which can be used as the example input of the additional `state` parameter. It also contains a `set_state(state)` attribute to set the global state to the given one (of course not JIT-compatible).
Notice
Since PyTorch cannot JIT or vectorized-map a function with an empty dictionary, list, or tuple as its input, when the collected state is empty this function transforms the given function into one WITHOUT the additional `state` parameter (of type `Dict[str, torch.Tensor]`) and WITHOUT additionally returning the altered state.
Usage:
```python
@jit_class
class Example(ModuleBase):
    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold
        self.sub_mod = nn.Module()
        self.sub_mod.buf = nn.Buffer(torch.zeros(()))

    def h(self, q: torch.Tensor) -> torch.Tensor:
        if q.flatten()[0] > self.threshold:
            x = torch.sin(q)
        else:
            x = torch.tan(q)
        x += self.g(x).abs()
        x *= x.shape[1]
        self.sub_mod.buf = x.sum()
        return x

    @trace_impl(h)
    def th(self, q: torch.Tensor) -> torch.Tensor:
        # trace-friendly replacement for the dynamic `if` in `h`
        x = torch.where(q.flatten()[0] > self.threshold, torch.sin(q), torch.tan(q))
        x += self.g(x).abs()
        x *= x.shape[1]
        self.sub_mod.buf = x.sum()
        return x

    def g(self, p: torch.Tensor) -> torch.Tensor:
        x = torch.cos(p)
        return x * p.shape[0]


t = Example()
fn = use_state(lambda: t.h, is_generator=True)
jit_fn = jit(fn, trace=True, lazy=True)
results = jit_fn(fn.init_state(), torch.rand(10, 1))
# print results, may be
# ({"self.sub_mod.buf": torch.Tensor(5.6)}, torch.Tensor([[0.56], ...]))
print(results)
# IN-PLACE update all relevant variables using the given state, which is the variable `t` here.
fn.set_state(results[0])
```
- evox.core.module._TORCHSCRIPT_MODIFIER#
‘_torchscript_modifier’
- evox.core.module._TRACE_WRAP_NAME#
‘trace_wrapped’
- evox.core.module.T#
‘TypeVar(…)’
- evox.core.module.trace_impl(target: Callable)[source]#
A helper function used to annotate that the wrapped method shall be treated as a trace-JIT-time proxy of the given `target` method.
Can ONLY be used inside a `jit_class` for a member method.
- Parameters:
target – The target method invoked when not tracing JIT.
- Returns:
The wrapping function to annotate the member method.
Notice
The target function and the annotated function MUST have the same input/output signatures (e.g. number of arguments and types); otherwise, the resulting behavior is UNDEFINED.
If the annotated function is to be `vmap`ped, it cannot contain any in-place operations on `self`, since such operations are not well-defined and cannot be compiled.
Usage:
See `use_state`.
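Example (beyond the `use_state` example, a minimal illustrative sketch; the class and method names below are not part of the library):
```python
import torch
from evox.core.module import ModuleBase, jit_class, trace_impl

@jit_class
class Clip(ModuleBase):
    def __init__(self, threshold: float = 1.0):
        super().__init__()
        self.threshold = threshold

    def step(self, x: torch.Tensor) -> torch.Tensor:
        # eager version: data-dependent Python branching is fine here
        if x.norm() > self.threshold:
            return x * (self.threshold / x.norm())
        return x

    @trace_impl(step)
    def trace_step(self, x: torch.Tensor) -> torch.Tensor:
        # trace-time proxy: same signature, no data-dependent branching
        return x * (self.threshold / x.norm().clamp(min=self.threshold))
```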
- evox.core.module._VMAP_WRAP_NAME#
‘vmap_wrapped’
- evox.core.module.vmap_impl(target: Callable)[source]#
A helper function used to annotate that the wrapped method shall be treated as a vmap-JIT-time proxy of the given `target` method.
Can ONLY be used inside a `jit_class` for a member method.
- Parameters:
target – The target method invoked when not tracing JIT.
- Returns:
The wrapping function to annotate the member method.
Notice
The target function and the annotated function MUST have the same input/output signatures (e.g. number of arguments and types); otherwise, the resulting behavior is UNDEFINED.
If the annotated function is to be `vmap`ped, it cannot contain any in-place operations on `self`, since such operations are not well-defined and cannot be compiled.
Usage:
See `use_state`.
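Example (by analogy with `trace_impl`, an illustrative sketch of a vmap-time proxy that avoids in-place writes to `self`; the class and method names are not part of the library):
```python
import torch
from evox.core.module import ModuleBase, Mutable, jit_class, vmap_impl

@jit_class
class Recorder(ModuleBase):
    def setup(self):
        self.best = Mutable(torch.tensor(float("inf")))
        return self

    def record(self, fitness: torch.Tensor) -> torch.Tensor:
        # eager version: updates module state
        self.best = torch.minimum(self.best, fitness.min())
        return self.best

    @vmap_impl(record)
    def vmap_record(self, fitness: torch.Tensor) -> torch.Tensor:
        # vmap-time proxy: same signature, but no in-place writes to `self`,
        # since such side effects cannot be compiled under vmap
        return torch.minimum(self.best, fitness.min())
```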
- evox.core.module.ClassT#
‘TypeVar(…)’
- evox.core.module._BASE_NAME#
‘base’
- evox.core.module.jit_class(cls: evox.core.module.ClassT, trace: bool = False) → evox.core.module.ClassT [source]#
A helper function used to JIT script (`torch.jit.script`) or trace (`torch.jit.trace_module`) all member methods of class `cls`.
- Parameters:
cls – The original class whose member methods are to be lazily JIT compiled.
trace – Whether to trace the module or to script the module. Defaults to False.
- Returns:
The wrapped class.
Notice
In many cases, it is not necessary to wrap your custom algorithms or problems with `jit_class`; the workflow(s) will do the trick for you.
With `trace=True`, all the member functions are effectively modified to additionally return `self`, since side effects cannot be traced. If you want to preserve the side effects, please set `trace=False` and use the `use_state` function to wrap the member method to generate a pure-functional version.
Similarly, all module-wide operations like `self.to(...)` can only return the unwrapped module, which may not be desired. Since most of them are in-place operations, a simple `module.to(...)` can be used instead of `module = module.to(...)`.
Usage:
```python
@jit_class
class Example(ModuleBase):
    # magic methods are ignored in JIT
    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold

    # `torch.jit.export` is automatically added to this member method
    def h(self, q: torch.Tensor) -> torch.Tensor:
        if q.flatten()[0] > self.threshold:
            x = torch.sin(q)
        else:
            x = torch.tan(q)
        return x * x.shape[1]


exp = Example(0.75)
print(exp.h(torch.rand(10, 2)))
# equivalent to
exp = torch.jit.trace_module(Example(0.75))
print(exp.h(torch.rand(10, 2)))
```