evox.core.jit_util

Module Contents

Classes

MappedUseStateFunc

Functions

vmap

Vectorize-map the given function into its batched version; see torch.vmap for more information.

_clone_inputs

_form_positional_inputs

jit

Just-In-Time (JIT) compile the given func via torch.jit.trace (trace=True) or torch.jit.script (trace=False).

Data

T

API

class evox.core.jit_util.MappedUseStateFunc

Bases: typing.Protocol

init_state(batch_size: int | None = None, expand: bool = True) → Dict[str, torch.Tensor]

Initialize the state of the mapped function.

Parameters:
  • batch_size – The batch size of the state. If None, the batch size of the state is indicated by VMAP_DIM_CONST. Defaults to None.

  • expand – Whether to use torch.Tensor.expand (True) or torch.Tensor.repeat (False) to bring the state tensors to the given batch size. Defaults to True.

Returns:

The initialized state, with the same keys as the state of the original function.
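The expand parameter can be illustrated in plain PyTorch. The following is a minimal sketch, not EvoX's implementation, that adds a leading batch dimension to every tensor in a state dict. The helper name batch_state is hypothetical. Tensor.expand returns a broadcast view that shares memory with the original (stride 0 along the new dimension), whereas Tensor.repeat materializes independent copies.

```python
import torch
from typing import Dict


def batch_state(state: Dict[str, torch.Tensor], batch_size: int, expand: bool = True) -> Dict[str, torch.Tensor]:
    # Hypothetical helper: prepend a batch dimension of size batch_size to every state tensor.
    batched = {}
    for key, value in state.items():
        value = value.unsqueeze(0)  # shape (1, *original_shape), still a view
        if expand:
            # Broadcast view: no copy, stride 0 along the new batch dimension.
            batched[key] = value.expand(batch_size, *value.shape[1:])
        else:
            # Materialized copy: batch_size independent rows in new memory.
            batched[key] = value.repeat(batch_size, *([1] * (value.dim() - 1)))
    return batched


state = {"w": torch.zeros(3), "step": torch.tensor(0)}
viewed = batch_state(state, 4, expand=True)
copied = batch_state(state, 4, expand=False)
print(viewed["w"].shape, copied["w"].shape)          # torch.Size([4, 3]) torch.Size([4, 3])
print(viewed["w"].stride(0), copied["w"].stride(0))  # 0 3
```

Because an expanded view shares one storage across the whole batch, in-place writes to it are unsafe, and PyTorch typically rejects them. The repeat path avoids this problem at the cost of extra memory.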

__call__(state: Dict[str, torch.Tensor], *args, **kwargs) → Dict[str, torch.Tensor] | Tuple[Dict[str, torch.Tensor], Any]

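The __call__ contract takes a state and returns either a new state or a (new state, output) pair. Plain torch.vmap can mimic this contract because it maps over dict inputs and tuple outputs as pytrees. This sketch assumes that every state tensor carries the batch dimension at dim 0. The function step is hypothetical and is not wrapped by EvoX's use_state.

```python
import torch
from typing import Dict, Tuple


def step(state: Dict[str, torch.Tensor], x: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
    # Unbatched state-passing function: returns (new_state, output).
    total = state["total"] + x.sum()
    return {"total": total}, total * 2


batch_size = 4
state = {"total": torch.zeros(batch_size)}  # batched state: leading dim is the batch
xs = torch.ones(batch_size, 3)              # one input vector per batch entry

# in_dims defaults to 0 for every leaf, including the tensors inside the dict.
new_state, out = torch.vmap(step)(state, xs)
print(new_state["total"])  # tensor([3., 3., 3., 3.])
print(out)                 # tensor([6., 6., 6., 6.])
```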
evox.core.jit_util.T

‘TypeVar(…)’

evox.core.jit_util.vmap(func: evox.core.jit_util.T | evox.core.module.UseStateFunc, in_dims: Optional[int | Tuple[int, ...]] = 0, out_dims: Optional[int | Tuple[int, ...]] = 0, trace: bool = True, example_ndim: Tuple[int | None] | int = 1, example_shapes: Optional[Tuple[Tuple[int] | Any] | Tuple[int | Any]] = None, example_inputs: Optional[Tuple[torch.Tensor | Any]] = None, strict: bool = False, check_trace: bool = False, batched_state: Dict[str, torch.Tensor] | None = None, VMAP_DIM_CONST: int = 13) → evox.core.jit_util.T | evox.core.jit_util.MappedUseStateFunc

Vectorize-map the given function into its batched version; see torch.vmap for more information.

Parameters:
  • func – The function to be mapped. See torch.vmap.

  • in_dims – The inputs’ batch dimensions. See torch.vmap. Defaults to 0.

  • out_dims – The outputs’ batch dimensions. See torch.vmap. Defaults to 0.

  • trace – Whether to trace the mapped function with torch.jit.trace or simply use torch.vmap. Defaults to True. Note: if trace=False, all of the following tracing-related inputs are ignored.

  • example_ndim – The number of dimensions (ndim) of the expected inputs of the batched function. Because these inputs carry the batch dimension, it must be at least 1. A single integer means the same ndim for all inputs. Defaults to 1.

  • example_shapes – The shapes of the example inputs used for tracing. Defaults to None.

  • example_inputs – The example inputs used for tracing. Defaults to None.

  • strict – Strictly check the inputs or not. See torch.jit.trace. Defaults to False.

  • check_trace – Check the traced function or not. See torch.jit.trace. Defaults to False.

  • batched_state – The optional batched current state for a use_state wrapped function. If None, a new batched state will be returned for each call of init_state(None). Ignored when func is not wrapped by use_state. Defaults to None.

  • VMAP_DIM_CONST – When tracing, the example inputs may be broadcast with additional dimension(s) of size VMAP_DIM_CONST. Defaults to 13.
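The VMAP_DIM_CONST note can be illustrated roughly as follows: an unbatched example input is broadcast to a leading dimension of size 13 before tracing. The helper broadcast_example is hypothetical and assumes a single extra leading dimension. It is not EvoX's internal code. For shape-agnostic operations like the ones below, the resulting trace also accepts other batch sizes.

```python
import torch

VMAP_DIM_CONST = 13  # the documented default


def broadcast_example(x: torch.Tensor, dim_size: int = VMAP_DIM_CONST) -> torch.Tensor:
    # Hypothetical helper: prepend one dimension of size dim_size as a broadcast view (no copy).
    return x.unsqueeze(0).expand(dim_size, *x.shape)


def batched_sq_norm(x: torch.Tensor) -> torch.Tensor:
    # Already-batched function: one squared norm per row.
    return (x * x).sum(dim=-1)


example = broadcast_example(torch.ones(3))             # shape (13, 3)
traced = torch.jit.trace(batched_sq_norm, (example,))  # no sizes are baked in for these ops
print(traced(torch.ones(5, 3)))  # tensor([3., 3., 3., 3., 3.])
```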

Raises:

NotImplementedError – If the function argument types are not supported.

Returns:

The “batched” (vectorized mapped) version of func. If the given func is wrapped by use_state, the returned function will have an init_state method that can be called as init_state(batch_size: int) -> batched_state or init_state(None) -> batched_state.
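For the in_dims and out_dims parameters, whose descriptions above defer to torch.vmap, plain torch.vmap behaves as shown below. Setting in_dims=(0, None) maps over dim 0 of the first argument and passes the second argument unchanged. Setting out_dims=1 places the batch dimension at position 1 of the output. The functions weighted_sum and scale are illustrative only.

```python
import torch


def weighted_sum(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Unbatched: x and w both have shape (3,); returns a scalar.
    return (x * w).sum()


def scale(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Unbatched: returns a vector of shape (3,).
    return x * w


xs = torch.arange(12.0).reshape(4, 3)  # a batch of 4 vectors along dim 0
w = torch.tensor([1.0, 0.0, -1.0])     # shared across the batch

batched = torch.vmap(weighted_sum, in_dims=(0, None), out_dims=0)
print(batched(xs, w))  # tensor([-2., -2., -2., -2.])

# out_dims=1 puts the batch dimension last here: shape (3, 4) instead of (4, 3).
out = torch.vmap(scale, in_dims=(0, None), out_dims=1)(xs, w)
print(out.shape)  # torch.Size([3, 4])
```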

evox.core.jit_util._clone_inputs(inputs)

evox.core.jit_util._form_positional_inputs(func_args, args, kwargs, is_empty_state=False)
evox.core.jit_util.jit(func: evox.core.jit_util.T | evox.core.module.UseStateFunc | evox.core.jit_util.MappedUseStateFunc, trace: bool = False, lazy: bool = False, example_inputs: Optional[Tuple | Dict | Tuple[Tuple, Dict]] = None, strict: bool = False, check_trace: bool = False, is_generator: bool = False, no_cache: bool = False, return_dummy_output: bool = False, debug_manual_seed: int | None = None) → evox.core.jit_util.T | evox.core.module.UseStateFunc | evox.core.jit_util.MappedUseStateFunc

Just-In-Time (JIT) compile the given func via torch.jit.trace (trace=True) or torch.jit.script (trace=False).

This wrapper handles nested JIT and vectorized-map (vmap) expressions such as jit(func1) -> vmap -> jit(func2), preventing errors that could otherwise arise in such compositions.
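The two backends differ in practice when a function has data-dependent control flow. torch.jit.trace records only the branch taken for the example input, whereas torch.jit.script compiles the Python control flow itself. Below is a small sketch in plain PyTorch; the function double_or_negate is illustrative.

```python
import warnings

import torch


def double_or_negate(x: torch.Tensor) -> torch.Tensor:
    # The branch taken depends on tensor values, not on shapes.
    if bool(x.sum() > 0):
        return x * 2
    return x * -1


pos, neg = torch.ones(3), -torch.ones(3)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # tracing emits a TracerWarning for the bool() conversion
    traced = torch.jit.trace(double_or_negate, (pos,))
scripted = torch.jit.script(double_or_negate)

print(traced(neg))    # tensor([-2., -2., -2.]): only the branch taken for `pos` was recorded
print(scripted(neg))  # tensor([1., 1., 1.]): the if/else is preserved
```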

Notice

1. With `trace=True`, `torch.jit.trace` cannot use the SAME example input for DIFFERENT parameters of a function,
e.g., you cannot pass `tensor_a, tensor_a` to the `torch.jit.trace`d version of `f(x: torch.Tensor, y: torch.Tensor)`.
2. With `trace=False`, `torch.jit.script` cannot contain `vmap` expressions directly; wrap them with `jit(..., trace=True)` or `torch.jit.trace`.
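The first note has a simple workaround. When one example tensor would otherwise serve for two parameters, passing a clone gives each parameter a distinct tensor object. Here is a sketch with plain torch.jit.trace; diff is a hypothetical function.

```python
import torch


def diff(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x - y


a = torch.randn(3)
# clone() gives `y` its own tensor, so the tracer sees two distinct inputs.
traced = torch.jit.trace(diff, (a, a.clone()))

x, y = torch.randn(3), torch.randn(3)
print(torch.allclose(traced(x, y), x - y))  # True
```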

Parameters:
  • func – The target function to be JIT-compiled.

  • trace – Whether to use torch.jit.trace (True) or torch.jit.script (False) for JIT compilation. Defaults to False.

  • lazy – Whether to JIT-compile lazily or immediately. Defaults to False.

  • example_inputs – When lazy=False, the example inputs must be provided immediately; otherwise they are ignored. They can be given as positional arguments only (a tuple), keyword arguments only (a dict), or both (a tuple of a tuple and a dict). Defaults to None.

  • strict – Strictly check the inputs or not. See torch.jit.trace. Defaults to False.

  • check_trace – Check the traced function or not. See torch.jit.trace. Defaults to False.

  • is_generator – Whether func is a generator or not. Defaults to False.

  • no_cache – Whether to use torch.jit.trace directly (no_cache=True) or, when lazy=False, to run the function so that it caches its internals (no_cache=False). Defaults to False. Has no effect when trace=False. This value must be False if the function contains an immediate call to torch.jit.trace whose result will be used inside a torch.jit.script, so that the traced result is cached.

  • return_dummy_output – Whether to return the dummy output of func as the second output or not. Defaults to False. Has no effect when trace=False or lazy=True or no_cache=True.

  • debug_manual_seed – The manual seed to set before each run of the function. Defaults to None, meaning no manual seed is set. Has no effect when trace=False. Note that any value other than None changes the GLOBAL random seed.
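The three accepted forms of example_inputs could be normalized into (args, kwargs) roughly as shown below. split_example_inputs is a hypothetical helper and is not part of EvoX. A positional-only tuple of exactly two items, where the first is a tuple and the second is a dict, cannot be told apart from the combined form; this sketch reads it as the combined form.

```python
from typing import Any, Dict, Tuple


def split_example_inputs(example_inputs: Any) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Hypothetical helper: normalize the documented forms into (args, kwargs)."""
    if isinstance(example_inputs, dict):
        # Keyword arguments only.
        return (), dict(example_inputs)
    if (
        isinstance(example_inputs, tuple)
        and len(example_inputs) == 2
        and isinstance(example_inputs[0], tuple)
        and isinstance(example_inputs[1], dict)
    ):
        # (positional tuple, keyword dict).
        return example_inputs[0], dict(example_inputs[1])
    if isinstance(example_inputs, tuple):
        # Positional arguments only.
        return example_inputs, {}
    raise TypeError(f"Unsupported example_inputs type: {type(example_inputs).__name__}")


print(split_example_inputs((1, 2)))            # ((1, 2), {})
print(split_example_inputs({"y": 3}))          # ((), {'y': 3})
print(split_example_inputs(((1,), {"y": 3})))  # ((1,), {'y': 3})
```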

Returns:

The JIT-compiled version of func.
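The debug_manual_seed behavior, which reseeds before each run, can be sketched as a wrapper around torch.manual_seed. with_manual_seed is hypothetical. Like the documented option, it resets the global generator, so it also affects any other code that draws random numbers.

```python
import functools

import torch


def with_manual_seed(func, seed: int):
    """Hypothetical wrapper: reset the GLOBAL torch RNG before every call."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        torch.manual_seed(seed)  # global side effect, not local to func
        return func(*args, **kwargs)
    return wrapper


def noisy(x: torch.Tensor) -> torch.Tensor:
    return x + torch.randn_like(x)


seeded = with_manual_seed(noisy, 42)
x = torch.zeros(3)
print(torch.equal(seeded(x), seeded(x)))  # True: identical noise on every call
```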