evox.core.jit_util#
Module Contents#
Classes#
- MappedUseStateFunc
Functions#
- vmap – Vectorized map the given function to its mapped version; see `torch.vmap` for more information.
- jit – Just-In-Time (JIT) compile the given `func` via `torch.jit.trace` or `torch.jit.script`.
Data#
- T
API#
- class evox.core.jit_util.MappedUseStateFunc[source]#
Bases: typing.Protocol
- init_state(batch_size: int | None = None, expand: bool = True) → Dict[str, torch.Tensor] [source]#
Initialize the state of the mapped function.
- Parameters:
  - batch_size – The batch size of the state. If `None`, the batch size of the state is indicated by `VMAP_DIM_CONST`. Defaults to `None`.
  - expand – Whether to `torch.expand` or `torch.repeat` the state tensors to the given batch size. Defaults to `True`.
- Returns:
The initialized state, with the same keys as the state of the original function.
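A minimal sketch of the protocol in use, assuming `stateful_step` is a hypothetical `use_state`-wrapped function (see `evox.core.module.use_state`); `vmap` (documented below) turns such a function into a `MappedUseStateFunc`:

```python
# Hypothetical sketch: `stateful_step` stands for any use_state-wrapped function;
# it is a placeholder for illustration, not something defined in this module.
batched_step = vmap(stateful_step)

# Batched state whose tensors are torch.expand-ed views (expand=True, the default).
state = batched_step.init_state(batch_size=8)

# Batched state whose tensors are torch.repeat-ed, materialized copies.
state_copies = batched_step.init_state(batch_size=8, expand=False)

# With batch_size=None, the batch size falls back to VMAP_DIM_CONST (see vmap below).
default_state = batched_step.init_state()
```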
- evox.core.jit_util.T#
‘TypeVar(…)’
- evox.core.jit_util.vmap(func: evox.core.jit_util.T | evox.core.module.UseStateFunc, in_dims: Optional[int | Tuple[int, ...]] = 0, out_dims: Optional[int | Tuple[int, ...]] = 0, trace: bool = True, example_ndim: Tuple[int | None] | int = 1, example_shapes: Optional[Tuple[Tuple[int] | Any] | Tuple[int | Any]] = None, example_inputs: Optional[Tuple[torch.Tensor | Any]] = None, strict: bool = False, check_trace: bool = False, batched_state: Dict[str, torch.Tensor] | None = None, VMAP_DIM_CONST: int = 13) → evox.core.jit_util.T | evox.core.jit_util.MappedUseStateFunc [source]#
Vectorized map the given function to its mapped version; see `torch.vmap` for more information.
- Parameters:
  - func – The function to be mapped. See `torch.vmap`.
  - in_dims – The inputs' batch dimensions. See `torch.vmap`. Defaults to 0.
  - out_dims – The outputs' batch dimensions. See `torch.vmap`. Defaults to 0.
  - trace – Whether to trace the mapped function with `torch.jit.trace` or simply use `torch.vmap`. NOTICE: if `trace=False`, all of the following inputs related to tracing will be ignored.
  - example_ndim – The `ndim` of the expected inputs of the batched function; thus, it must be at least 1. Giving a single integer means the same `ndim` for all inputs. Defaults to 1.
  - example_shapes – The example shapes of the expected inputs, used for tracing. Defaults to None.
  - example_inputs – The example inputs used for tracing. Defaults to None.
  - strict – Strictly check the inputs or not. See `torch.jit.trace`. Defaults to False.
  - check_trace – Check the traced function or not. See `torch.jit.trace`. Defaults to False.
  - batched_state – The optional batched current state for a `use_state`-wrapped function. If None, a new batched state will be returned for each call of `init_state(None)`. Ignored when `func` is not wrapped by `use_state`. Defaults to None.
  - VMAP_DIM_CONST – When tracing, the example inputs may be broadcast with additional dimension(s) of size `VMAP_DIM_CONST`. Defaults to 13.
- Raises:
NotImplementedError – If the function argument types are not supported.
- Returns:
The “batched” (vectorized-mapped) version of `func`. If the given `func` is wrapped by `use_state`, the returned function will additionally have an `init_state(batch_size: int) -> batched_state` or `init_state(None) -> batched_state` method.
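For a quick illustration, here is a minimal sketch (assuming the default tracing path supports this simple per-sample function) that maps an ordinary function over a leading batch dimension:

```python
import torch

from evox.core.jit_util import vmap

def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Per-sample computation: inner product of two 1-D tensors.
    return (x * y).sum()

# Map `dot` over the leading dimension (dim 0) of both inputs.
batched_dot = vmap(dot, in_dims=0, out_dims=0)

xs = torch.randn(8, 4)
ys = torch.randn(8, 4)
out = batched_dot(xs, ys)  # expected shape: (8,)
```

If `func` is instead a `use_state`-wrapped function, the returned callable additionally exposes the `init_state` method described above.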
- evox.core.jit_util.jit(func: evox.core.jit_util.T | evox.core.module.UseStateFunc | evox.core.jit_util.MappedUseStateFunc, trace: bool = False, lazy: bool = False, example_inputs: Optional[Tuple | Dict | Tuple[Tuple, Dict]] = None, strict: bool = False, check_trace: bool = False, is_generator: bool = False, no_cache: bool = False, return_dummy_output: bool = False, debug_manual_seed: int | None = None) → evox.core.jit_util.T | evox.core.module.UseStateFunc | evox.core.jit_util.MappedUseStateFunc [source]#
Just-In-Time (JIT) compile the given `func` via `torch.jit.trace` (`trace=True`) or `torch.jit.script` (`trace=False`).
This function wrapper effectively deals with nested JIT and vector-map (`vmap`) expressions like `jit(func1)` -> `vmap` -> `jit(func2)`, preventing possible errors.
Notice
1. With `trace=True`, `torch.jit.trace` cannot use the SAME example input arguments for functions with DIFFERENT parameters, e.g., you cannot pass `tensor_a, tensor_a` to the `torch.jit.trace`d version of `f(x: torch.Tensor, y: torch.Tensor)`.
2. With `trace=False`, `torch.jit.script` cannot contain `vmap` expressions directly; please wrap them with `jit(..., trace=True)` or `torch.jit.trace`.
- Parameters:
  - func – The target function to be JIT compiled.
  - trace – Whether to use `torch.jit.trace` or `torch.jit.script` for the JIT compilation. Defaults to False.
  - lazy – Whether to JIT lazily or immediately. Defaults to False.
  - example_inputs – When `lazy=False`, the example inputs must be provided immediately; otherwise, they are ignored. Can be only positional arguments (a tuple), only keyword arguments (a dict), or a tuple of positional arguments and keyword arguments (a tuple of a tuple and a dict). Defaults to None.
  - strict – Strictly check the inputs or not. See `torch.jit.trace`. Defaults to False.
  - check_trace – Check the traced function or not. See `torch.jit.trace`. Defaults to False.
  - is_generator – Whether `func` is a generator or not. Defaults to False.
  - no_cache – Whether to use `torch.jit.trace` directly (`no_cache=True`) or to run the function once so that it caches internals when `lazy=False`. Defaults to False. Has no effect when `trace=False`. This value must be set to `False` if the function contains an immediate call to `torch.jit.trace` that will be used inside a `torch.jit.script`, so that the JIT-traced result is cached.
  - return_dummy_output – Whether to return the dummy output of `func` as the second output or not. Defaults to False. Has no effect when `trace=False`, `lazy=True`, or `no_cache=True`.
  - debug_manual_seed – The manual seed to be set before each run of the function. Defaults to None. Has no effect when `trace=False`. None means no manual seed will be set. Notice that any value other than None changes the GLOBAL random seed.
- Returns:
The JIT version of `func`.
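A minimal usage sketch covering both compilation paths; the simple function below is assumed to be supported by `torch.jit.script` as well as `torch.jit.trace`:

```python
import torch

from evox.core.jit_util import jit

def affine(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # A small, side-effect-free function that both script and trace can handle.
    return x @ w + b

# Scripted path: trace=False (the default) compiles via torch.jit.script
# (no example inputs are assumed to be needed on this path).
scripted_affine = jit(affine)

# Traced path: trace=True compiles via torch.jit.trace, so example inputs are
# provided since lazy=False; per the notice above, the example tensors must be
# distinct objects rather than the same tensor passed twice.
x, w, b = torch.randn(4, 3), torch.randn(3, 2), torch.randn(2)
traced_affine = jit(affine, trace=True, example_inputs=(x, w, b))

out = traced_affine(x, w, b)  # matches affine(x, w, b)
```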