evox.core._vmap_fix#

模块内容#

函数#

_set_func_id

unwrap_batch_tensor

将批量张量展开为其原始张量及批量维度/大小。

wrap_batch_tensor

将原始张量封装到其批量形式中,给定批量维度。

_get_batched_size

get_vmap_batch_sizes

vmap_increment_nesting

_flat_vmap

batched_random

生成一个包含随机值的批量张量。

batched_random_like

生成一个与给定张量形状相同的随机值批处理张量。

_batch_size

_batch_rand

_batch_randn

_batch_randint

_batch_randperm

_batch_rand_like

_batch_randn_like

_batch_randint_like

_batch_getitem

_batch_setitem

use_batch_fixing

align_vmap_tensor

将张量与当前批处理张量的批处理维度对齐。

_debug_print

debug_print

打印一个格式化字符串,包含一个用于调试的定位张量,在 JIT 跟踪的函数中实时使用。

数据#

API#

evox.core._vmap_fix._set_func_id(new_func, old_func)[源代码]#
evox.core._vmap_fix.unwrap_batch_tensor(tensor: torch.Tensor) Tuple[torch.Tensor, Tuple[int, ...], Tuple[int, ...]][源代码]#

将批量张量展开为其原始张量及批量维度/大小。

参数:

tensor -- 要解包的批处理张量。

返回:

原始张量、批次维度和批次大小的元组。

evox.core._vmap_fix.wrap_batch_tensor(tensor: torch.Tensor, in_dims: int | Tuple[int, ...]) torch.Tensor[源代码]#

将原始张量封装到其批量形式中,给定批量维度。

参数:
  • tensor -- 要包装的原始张量。

  • in_dims -- 批量维度(s)。

返回:

批量张量。

evox.core._vmap_fix._get_batched_size(in_dim: int | Tuple[int, ...], original: torch.Tensor) int | Tuple[int, ...] | None[源代码]#
evox.core._vmap_fix._vmap_batch_sizes: contextvars.ContextVar[List[int]]#

'ContextVar(...)'

evox.core._vmap_fix.get_vmap_batch_sizes()[源代码]#
evox.core._vmap_fix.vmap_increment_nesting(batch_size, randomness)[源代码]#
evox.core._vmap_fix._flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)[源代码]#
evox.core._vmap_fix.batched_random(rand_func: Callable, *size: Tuple[int | torch.SymInt], **kwargs) torch.Tensor[源代码]#

生成一个包含随机值的批量张量。

给定一个随机函数(例如 torch.randntorch.rand 等)及其大小参数,该函数通过将给定函数应用于扩展当前 vmap 批大小的大小,生成一个随机值的批量张量。

参数:
  • rand_func -- 生成随机值的张量的函数。

  • *size -- 给定函数的大小参数。

  • **kwargs -- 给定函数的关键字参数。

返回:

随机值的批量张量。

Usage:

rand1 = batched_random(torch.rand, 2, 3, device=device)
rand2 = batched_random(torch.randn, 4, device=device, dtype=torch.float32)
rand3 = batched_random(torch.randint, 5, 6, low=0, high=10, device=device, dtype=torch.float32)
evox.core._vmap_fix.batched_random_like(rand_func: Callable, like_tensor: torch.Tensor, **kwargs) torch.Tensor[源代码]#

生成一个与给定张量形状相同的随机值批处理张量。

给定一个随机函数(例如 torch.randn_liketorch.rand_like 等)和一个张量, 此函数通过将给定函数应用于当前 vmap 批次大小扩展的张量,生成一个批量的随机值张量。

参数:
  • rand_func -- 生成随机值的张量的函数。

  • like_tensor -- 生成随机值的张量。

  • **kwargs -- 给定函数的关键字参数。

返回:

随机值的批量张量。

evox.core._vmap_fix._original_size#

没有可翻译的文本。

evox.core._vmap_fix._original_rand#

没有可翻译的文本。

evox.core._vmap_fix._original_randn#

没有可翻译的文本。

evox.core._vmap_fix._original_randint#

没有可翻译的文本。

evox.core._vmap_fix._original_randperm#

没有可翻译的文本。

evox.core._vmap_fix._original_rand_like#

没有可翻译的文本。

evox.core._vmap_fix._original_randn_like#

没有可翻译的文本。

evox.core._vmap_fix._original_randint_like#

没有可翻译的文本。

evox.core._vmap_fix._original_get_item#

没有可翻译的文本。

evox.core._vmap_fix._original_set_item#

没有可翻译的文本。

evox.core._vmap_fix._batch_size(tensor: torch.Tensor, dim: int | None = None)[源代码]#
evox.core._vmap_fix._batch_rand(*size, **kwargs)[源代码]#
evox.core._vmap_fix._batch_randn(*size, **kwargs)[源代码]#
evox.core._vmap_fix._batch_randint(low=None, high=None, size=None, **kwargs)[源代码]#
evox.core._vmap_fix._batch_randperm(n, **kwargs)[源代码]#
evox.core._vmap_fix._batch_rand_like(like_tensor, **kwargs)[源代码]#
evox.core._vmap_fix._batch_randn_like(like_tensor, **kwargs)[源代码]#
evox.core._vmap_fix._batch_randint_like(like_tensor, **kwargs)[源代码]#
evox.core._vmap_fix._batch_getitem(tensor: torch.Tensor, indices, dim=0)[源代码]#
evox.core._vmap_fix._batch_setitem(tensor: torch.Tensor, indices, values, dim=0)[源代码]#
evox.core._vmap_fix._batch_fixing: contextvars.ContextVar[bool]#

'ContextVar(...)'

evox.core._vmap_fix.use_batch_fixing(new_batch_fixing: bool = True)[源代码]#
evox.core._vmap_fix.align_vmap_tensor(value: Any, current_value: Any | None) torch.Tensor[源代码]#

将张量与当前批处理张量的批处理维度对齐。

该函数将输入张量 value 调整为匹配 current_value 的批次维度,后者被假设为批次张量。如果 value 已经是批次张量或者 current_value 不是批次张量,它会返回 value 不变。

参数:
  • value -- 要对齐的张量。如果不是一个 torch.Tensor,则保持不变返回。

  • current_value -- 引用批量张量。如果是 None 或者不是批量张量,则返回的值保持不变。

返回:

输入值与当前值的批次维度对齐(如果适用)。

evox.core._vmap_fix._debug_print(format: str, arg: torch.Tensor) torch.Tensor[源代码]#
evox.core._vmap_fix.debug_print(format: str, arg: torch.Tensor) torch.Tensor[源代码]#

打印一个格式化字符串,包含一个用于调试的定位张量,在 JIT 跟踪的函数中实时使用。

在向量化映射时,它会展开批量张量以打印底层值。否则,该函数的行为类似于 format.format(*args, **kwargs)

参数:
  • format -- 字符串格式。

  • arg -- 位置张量。

返回:

未改变的张量。