evox.core._vmap_fix
#
模块内容#
函数#
将批量张量展开为其原始张量及批量维度/大小。 |
|
将原始张量封装到其批量形式中,给定批量维度。 |
|
生成一个包含随机值的批量张量。 |
|
生成一个与给定张量形状相同的随机值批处理张量。 |
|
将张量与当前批处理张量的批处理维度对齐。 |
|
打印一个格式化字符串,包含一个用于调试的定位张量,在 JIT 跟踪的函数中实时使用。 |
数据#
API#
- evox.core._vmap_fix.unwrap_batch_tensor(tensor: torch.Tensor) Tuple[torch.Tensor, Tuple[int, ...], Tuple[int, ...]] [源代码]#
将批量张量展开为其原始张量及批量维度/大小。
- 参数:
tensor -- 要解包的批处理张量。
- 返回:
原始张量、批次维度和批次大小的元组。
- evox.core._vmap_fix.wrap_batch_tensor(tensor: torch.Tensor, in_dims: int | Tuple[int, ...]) torch.Tensor [源代码]#
将原始张量封装到其批量形式中,给定批量维度。
- 参数:
tensor -- 要包装的原始张量。
in_dims -- 批量维度(s)。
- 返回:
批量张量。
- evox.core._vmap_fix._get_batched_size(in_dim: int | Tuple[int, ...], original: torch.Tensor) int | Tuple[int, ...] | None [源代码]#
- evox.core._vmap_fix._vmap_batch_sizes: contextvars.ContextVar[List[int]]#
'ContextVar(...)'
- evox.core._vmap_fix._flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)[源代码]#
- evox.core._vmap_fix.batched_random(rand_func: Callable, *size: Tuple[int | torch.SymInt], **kwargs) torch.Tensor [源代码]#
生成一个包含随机值的批量张量。
给定一个随机函数(例如
torch.randn
、torch.rand
等)及其大小参数,该函数通过将给定函数应用于扩展当前 vmap 批大小的大小,生成一个随机值的批量张量。- 参数:
rand_func -- 生成随机值的张量的函数。
*size -- 给定函数的大小参数。
**kwargs -- 给定函数的关键字参数。
- 返回:
随机值的批量张量。
Usage:
rand1 = batched_random(torch.rand, 2, 3, device=device) rand2 = batched_random(torch.randn, 4, device=device, dtype=torch.float32) rand3 = batched_random(torch.randint, 5, 6, low=0, high=10, device=device, dtype=torch.float32)
- evox.core._vmap_fix.batched_random_like(rand_func: Callable, like_tensor: torch.Tensor, **kwargs) torch.Tensor [源代码]#
生成一个与给定张量形状相同的随机值批处理张量。
给定一个随机函数(例如
torch.randn_like
、torch.rand_like
等)和一个张量, 此函数通过将给定函数应用于当前 vmap 批次大小扩展的张量,生成一个批量的随机值张量。- 参数:
rand_func -- 生成随机值的张量的函数。
like_tensor -- 生成随机值的张量。
**kwargs -- 给定函数的关键字参数。
- 返回:
随机值的批量张量。
- evox.core._vmap_fix._original_size#
没有可翻译的文本。
- evox.core._vmap_fix._original_rand#
没有可翻译的文本。
- evox.core._vmap_fix._original_randn#
没有可翻译的文本。
- evox.core._vmap_fix._original_randint#
没有可翻译的文本。
- evox.core._vmap_fix._original_randperm#
没有可翻译的文本。
- evox.core._vmap_fix._original_rand_like#
没有可翻译的文本。
- evox.core._vmap_fix._original_randn_like#
没有可翻译的文本。
- evox.core._vmap_fix._original_randint_like#
没有可翻译的文本。
- evox.core._vmap_fix._original_get_item#
没有可翻译的文本。
- evox.core._vmap_fix._original_set_item#
没有可翻译的文本。
- evox.core._vmap_fix._batch_fixing: contextvars.ContextVar[bool]#
'ContextVar(...)'
- evox.core._vmap_fix.align_vmap_tensor(value: Any, current_value: Any | None) torch.Tensor [源代码]#
将张量与当前批处理张量的批处理维度对齐。
该函数将输入张量
value
调整为匹配current_value
的批次维度,后者被假设为批次张量。如果value
已经是批次张量或者current_value
不是批次张量,它会返回value
不变。- 参数:
value -- 要对齐的张量。如果不是一个 torch.Tensor,则保持不变返回。
current_value -- 引用批量张量。如果是 None 或者不是批量张量,则返回的值保持不变。
- 返回:
输入值与当前值的批次维度对齐(如果适用)。