evox.core.jit_util#

模块内容#

#

函数#

vmap

向量化映射将给定的函数映射到其映射版本,更多信息请参见 torch.vmap

_clone_inputs

_form_positional_inputs

jit

通过 torch.jit.trace (trace=True) 或 torch.jit.script (trace=False) 对给定的 func 进行即时编译 (JIT)。

数据#

T

API#

class evox.core.jit_util.MappedUseStateFunc[源代码]#

Bases: typing.Protocol

init_state(batch_size: int | None = None, expand: bool = True) Dict[str, torch.Tensor][源代码]#

初始化映射函数的状态。

参数:
  • batch_size -- 状态的批处理大小。如果为 None,则状态的批处理大小由 VMAP_DIM_CONST 指示。默认为 None。

  • expand -- 是否将状态张量使用 torch.expand 或 torch.repeat 扩展到给定的批量大小。

返回:

初始化状态,具有与原始函数状态相同的键。

__call__(state: Dict[str, torch.Tensor], *args, **kwargs) Dict[str, torch.Tensor] | Tuple[Dict[str, torch.Tensor], Any][源代码]#
evox.core.jit_util.T#

TypeVar(...)

evox.core.jit_util.vmap(func: evox.core.jit_util.T | evox.core.module.UseStateFunc, in_dims: Optional[int | Tuple[int, ...]] = 0, out_dims: Optional[int | Tuple[int, ...]] = 0, trace: bool = True, example_ndim: Tuple[int | None] | int = 1, example_shapes: Optional[Tuple[Tuple[int] | Any] | Tuple[int | Any]] = None, example_inputs: Optional[Tuple[torch.Tensor | Any]] = None, strict: bool = False, check_trace: bool = False, batched_state: Dict[str, torch.Tensor] | None = None, VMAP_DIM_CONST: int = 13) evox.core.jit_util.T | evox.core.jit_util.MappedUseStateFunc[源代码]#

向量化映射将给定的函数映射到其映射版本,更多信息请参见 torch.vmap

参数:
  • func -- 要映射的函数。请参见 torch.vmap。

  • in_dims -- 输入的批次维度。请参阅 torch.vmap。默认为 0。

  • out_dims -- 输出的批次维度。请参见 torch.vmap。默认为 0。

  • trace -- 是否使用 torch.jit.trace 追踪已映射的函数,还是简单地使用 torch.vmap。注意:如果 trace=False,所有与追踪相关的输入将被忽略。

  • example_ndim -- 批处理函数期望输入的维度(ndim);因此,它必须至少为 1。给定一个整数意味着所有输入具有相同的维度。默认值为 1。

  • example_shapes -- 默认为 None。

  • example_inputs -- 描述。默认为 None。

  • strict -- 严格检查输入与否。参见 torch.jit.trace。默认为 False。

  • check_trace -- 检查是否跟踪了该函数。参见 torch.jit.trace。默认为 False。

  • batched_state -- 可选的批处理当前状态,用于被 use_state 包裹的函数。如果为 None,则每次调用 init_state(None) 时将返回一个新的批处理状态。当 func 没有被 use_state 包裹时,该选项被忽略。默认值为 None。

  • VMAP_DIM_CONST -- 依赖于 trace 时,示例输入可能会添加一个或多个大小为 VMAP_DIM_CONST 的维度。默认值:13。

抛出:

NotImplementedError -- 如果函数参数类型不支持

返回:

“批处理”(向量化映射)版本的 func。如果给定的 func 被 use_state 包裹,返回的函数将具有 init_state(batch_size: int) -> batched_state 或 init_state(None) -> batched_state。

evox.core.jit_util._clone_inputs(inputs)[源代码]#
evox.core.jit_util._form_positional_inputs(func_args, args, kwargs, is_empty_state=False)[源代码]#
evox.core.jit_util.jit(func: evox.core.jit_util.T | evox.core.module.UseStateFunc | evox.core.jit_util.MappedUseStateFunc, trace: bool = False, lazy: bool = False, example_inputs: Optional[Tuple | Dict | Tuple[Tuple, Dict]] = None, strict: bool = False, check_trace: bool = False, is_generator: bool = False, no_cache: bool = False, return_dummy_output: bool = False, debug_manual_seed: int | None = None) evox.core.jit_util.T | evox.core.module.UseStateFunc | evox.core.jit_util.MappedUseStateFunc[源代码]#

通过 torch.jit.trace (trace=True) 或 torch.jit.script (trace=False) 对给定的 func 进行即时编译 (JIT)。

该函数包装器有效处理嵌套的 JIT 和向量映射 (vmap) 表达式,如 jit(func1) -> vmap -> jit(func2),从而防止可能出现的错误。

Notice

1. With `trace=True`, `torch.jit.trace` cannot use SAME example input arguments for function of DIFFERENT parameters,
e.g., you cannot pass `tensor_a, tensor_a` to `torch.jit.trace`d version of `f(x: torch.Tensor, y: torch.Tensor)`.
2. With `trace=False`, `torch.jit.script` cannot contain `vmap` expressions directly, please wrap them with `jit(..., trace=True)` or `torch.jit.trace`.
参数:
  • func -- 要进行 JIT 的目标函数

  • trace -- 是否使用 torch.jit.trace 或 torch.jit.script 进行 JIT。默认为 False。

  • lazy -- 是选择懒加载还是立即执行 JIT。默认值为 False。

  • example_inputs -- 当 lazy=False 时,示例输入必须立即提供,否则将被忽略。可以是仅位置参数(一个 tuple),仅关键字参数(一个 dict),或位置参数和关键字参数的组合(一个 tuple 和 dict 的 tuple)。Defaults to None。

  • strict -- 严格检查输入与否。参见 torch.jit.trace。默认为 False。

  • check_trace -- 检查是否跟踪了该函数。参见 torch.jit.trace。默认为 False。

  • is_generator -- 无论 func 是否是生成器。默认为 False。

  • no_cache -- 是否直接使用 torch.jit.trace (no_cache=True)或在 lazy=False 时运行函数以使其缓存内部内容。默认为 False。当 trace=False 时没有效果。如果函数包含对 torch.jit.trace 的即时调用,并将在 torch.jit.script 内部使用,则该值必须设置为 False,以便 JIT 跟踪的结果将被缓存。

  • return_dummy_output -- 是否将 func 的 dummy 输出作为第二个输出返回默认值为False。如果 trace=Falselazy=Trueno_cache=True 则无效。

  • debug_manual_seed -- 在每次运行该函数之前要设置的手动种子。默认为 None。当 trace=False 时没有效果。None 表示不设置手动种子。注意,任何其他值都将改变全局随机种子。

返回:

func 的 JIT 版本