evox.core.jit_util
#
模块内容#
类#
函数#
向量化映射将给定的函数映射到其映射版本,更多信息请参见 |
|
通过 |
数据#
API#
- class evox.core.jit_util.MappedUseStateFunc[源代码]#
Bases:
typing.Protocol
- evox.core.jit_util.T#
TypeVar(...)
- evox.core.jit_util.vmap(func: evox.core.jit_util.T | evox.core.module.UseStateFunc, in_dims: Optional[int | Tuple[int, ...]] = 0, out_dims: Optional[int | Tuple[int, ...]] = 0, trace: bool = True, example_ndim: Tuple[int | None] | int = 1, example_shapes: Optional[Tuple[Tuple[int] | Any] | Tuple[int | Any]] = None, example_inputs: Optional[Tuple[torch.Tensor | Any]] = None, strict: bool = False, check_trace: bool = False, batched_state: Dict[str, torch.Tensor] | None = None, VMAP_DIM_CONST: int = 13) evox.core.jit_util.T | evox.core.jit_util.MappedUseStateFunc [源代码]#
向量化映射将给定的函数映射到其映射版本,更多信息请参见
torch.vmap
。- 参数:
func -- 要映射的函数。请参见 torch.vmap。
in_dims -- 输入的批次维度。请参阅 torch.vmap。默认为 0。
out_dims -- 输出的批次维度。请参见 torch.vmap。默认为 0。
trace -- 是否使用
torch.jit.trace
追踪已映射的函数,还是简单地使用torch.vmap
。注意:如果trace=False
,所有与追踪相关的输入将被忽略。example_ndim -- 批处理函数期望输入的维度(ndim);因此,它必须至少为 1。给定一个整数意味着所有输入具有相同的维度。默认值为 1。
example_shapes -- 默认为 None。
example_inputs -- 描述。默认为 None。
strict -- 严格检查输入与否。参见 torch.jit.trace。默认为 False。
check_trace -- 检查是否跟踪了该函数。参见 torch.jit.trace。默认为 False。
batched_state -- 可选的批处理当前状态,用于被 use_state 包裹的函数。如果为 None,则每次调用 init_state(None) 时将返回一个新的批处理状态。当 func 没有被 use_state 包裹时,该选项被忽略。默认值为 None。
VMAP_DIM_CONST -- 依赖于 trace 时,示例输入可能会添加一个或多个大小为 VMAP_DIM_CONST 的维度。默认值:13。
- 抛出:
NotImplementedError -- 如果函数参数类型不支持
- 返回:
“批处理”(向量化映射)版本的 func。如果给定的 func 被 use_state 包裹,返回的函数将具有 init_state(batch_size: int) -> batched_state 或 init_state(None) -> batched_state。
- evox.core.jit_util.jit(func: evox.core.jit_util.T | evox.core.module.UseStateFunc | evox.core.jit_util.MappedUseStateFunc, trace: bool = False, lazy: bool = False, example_inputs: Optional[Tuple | Dict | Tuple[Tuple, Dict]] = None, strict: bool = False, check_trace: bool = False, is_generator: bool = False, no_cache: bool = False, return_dummy_output: bool = False, debug_manual_seed: int | None = None) evox.core.jit_util.T | evox.core.module.UseStateFunc | evox.core.jit_util.MappedUseStateFunc [源代码]#
通过
torch.jit.trace
(trace=True
) 或torch.jit.script
(trace=False
) 对给定的func
进行即时编译 (JIT)。该函数包装器有效处理嵌套的 JIT 和向量映射 (
vmap
) 表达式,如jit(func1)
->vmap
->jit(func2)
,从而防止可能出现的错误。Notice
1. With `trace=True`, `torch.jit.trace` cannot use SAME example input arguments for function of DIFFERENT parameters, e.g., you cannot pass `tensor_a, tensor_a` to `torch.jit.trace`d version of `f(x: torch.Tensor, y: torch.Tensor)`. 2. With `trace=False`, `torch.jit.script` cannot contain `vmap` expressions directly, please wrap them with `jit(..., trace=True)` or `torch.jit.trace`.
- 参数:
func -- 要进行 JIT 的目标函数
trace -- 是否使用 torch.jit.trace 或 torch.jit.script 进行 JIT。默认为 False。
lazy -- 是选择懒加载还是立即执行 JIT。默认值为 False。
example_inputs -- 当 lazy=False 时,示例输入必须立即提供,否则将被忽略。可以是仅位置参数(一个 tuple),仅关键字参数(一个 dict),或位置参数和关键字参数的组合(一个 tuple 和 dict 的 tuple)。Defaults to None。
strict -- 严格检查输入与否。参见 torch.jit.trace。默认为 False。
check_trace -- 检查是否跟踪了该函数。参见 torch.jit.trace。默认为 False。
is_generator -- 无论 func 是否是生成器。默认为 False。
no_cache -- 是否直接使用 torch.jit.trace (no_cache=True)或在 lazy=False 时运行函数以使其缓存内部内容。默认为 False。当 trace=False 时没有效果。如果函数包含对 torch.jit.trace 的即时调用,并将在 torch.jit.script 内部使用,则该值必须设置为 False,以便 JIT 跟踪的结果将被缓存。
return_dummy_output -- 是否将
func
的 dummy 输出作为第二个输出返回默认值为False。如果trace=False
或lazy=True
或no_cache=True
则无效。debug_manual_seed -- 在每次运行该函数之前要设置的手动种子。默认为 None。当 trace=False 时没有效果。None 表示不设置手动种子。注意,任何其他值都将改变全局随机种子。
- 返回:
func 的 JIT 版本