evox.utils.jit_fix_operator

模块内容

函数

switch

基于标签张量从张量列表中生成张量的逐元素切换选择运算符。

clamp

将输入张量 a 的值限制在给定的下界 (lb) 和上界 (ub) 之间。

clamp_float

将输入张量 a 的浮点值限制在给定的下界 (lb) 和上界 (ub) 之间。

clamp_int

将输入张量 a 的整数值限制在给定的下限 (lb) 和上限 (ub) 之间。

clip

将输入张量 a 的值裁剪到 [0, 1] 范围内。

maximum

两个输入张量 ab 的逐元素最大值。

minimum

两个输入张量 ab 的逐元素最小值。

maximum_float

输入张量 a 和浮点数 b 的逐元素最大值。

minimum_float

输入张量 a 和浮点数 b 的逐元素最小值。

maximum_int

输入张量 a 和整数 b 的逐元素最大值。

minimum_int

输入张量 a 和整数 b 的逐元素最小值。

lexsort

对多个张量进行字典序排序,将每个张量视为一个键。

nanmin

计算张量在指定维度上的最小值,忽略NaN值。

nanmax

计算张量在指定维度上的最大值,忽略NaN值。

randint

随机生成一个整数张量,范围由指定的低值和高值张量决定,类似于 torch.randint。然而,现在低值和高值是张量。

API

evox.utils.jit_fix_operator.switch(label: torch.Tensor, values: List[torch.Tensor]) torch.Tensor[源代码]

基于标签张量从张量列表中生成张量的逐元素切换选择运算符。

参数:
  • label -- 一个张量,包含用于从张量列表中选择的标签。必须可以广播到其余参数的形状。

  • values -- 根据标签从中选择的张量列表。列表中的所有张量必须可以广播到相同的形状。

返回:

一个张量,其中每个元素都是根据标签张量中的相应元素从张量列表中选择的。

evox.utils.jit_fix_operator.clamp(a: torch.Tensor, lb: torch.Tensor, ub: torch.Tensor) torch.Tensor[源代码]

将输入张量 a 的值限制在给定的下界 (lb) 和上界 (ub) 之间。

此函数确保张量 a 的每个元素不小于 lb 的相应元素且不大于 ub 的相应元素。

Notice

这是一个修复函数,适用于torch.clamp,因为它尚未在JIT操作符融合中支持。

Warning

这不是torch.clamp的精确复制,如果albub是浮点张量,可能会遭受数值精度损失。如果需要精确的限制,请使用`torch.clamp

参数:
  • a -- 要被限制的输入张量。

  • lb -- 下界张量。必须可以广播到 a 的形状。

  • ub -- 上界张量。必须可以广播到a的形状。

返回:

张量,其中每个元素都被限制在指定的界限内。

evox.utils.jit_fix_operator.clamp_float(a: torch.Tensor, lb: float, ub: float) torch.Tensor[源代码]

将输入张量 a 的浮点值限制在给定的下界 (lb) 和上界 (ub) 之间。

此函数确保张量 a 的每个元素不小于 lb 且不大于 `ub

Notice

这是一个修复函数,适用于torch.clamp,因为它尚未在JIT操作符融合中支持。

Warning

这不是torch.clamp的精确复制,如果a是浮点张量,可能会遭受数值精度损失。如果需要精确的限制,请使用`torch.clamp

参数:
  • a -- 要被限制的输入张量。

  • lb -- 下界值。a 的每个元素将被限制为不小于 lb。

  • ub -- 上限值。a 的每个元素将被限制为不大于 ub。

返回:

张量,其中每个元素都被限制在指定的界限内。

evox.utils.jit_fix_operator.clamp_int(a: torch.Tensor, lb: int, ub: int) torch.Tensor[源代码]

将输入张量 a 的整数值限制在给定的下限 (lb) 和上限 (ub) 之间。

此函数确保张量 a 的每个元素不小于 lb 且不大于 `ub

Notice

这是一个修复函数,适用于torch.clamp,因为它尚未在JIT操作符融合中支持。

Warning

这不是torch.clamp的精确复制,如果a是浮点张量,可能会遭受数值精度损失。如果需要精确的限制,请使用`torch.clamp

参数:
  • a -- 要被限制的输入张量。

  • lb -- 下界值。a 的每个元素将被限制为不小于 lb。

  • ub -- 上限值。a 的每个元素将被限制为不大于 ub。

返回:

张量,其中每个元素都被限制在指定的界限内。

evox.utils.jit_fix_operator.clip(a: torch.Tensor) torch.Tensor[源代码]

将输入张量 a 的值裁剪到 [0, 1] 范围内。

注意:此函数调用 `clamp(a, 0, 1)

参数:

a -- 要裁剪的输入张量。

返回:

一个张量,其中每个元素都被限制在 [0, 1] 范围内。

evox.utils.jit_fix_operator.maximum(a: torch.Tensor, b: torch.Tensor) torch.Tensor[源代码]

两个输入张量 ab 的逐元素最大值。

Notice

这是一个修复函数,用于torch.maximum,因为它在JIT操作符融合中尚不支持。

Warning

这并不是对 torch.maximum 的精确复现,如果 ab 是浮点张量,可能会存在数值精度损失。如果需要精确的最大值,请使用 torch.maximum

参数:
  • a -- 第一个输入张量。

  • b -- 第二个输入张量。

返回:

a 和 b 的逐元素最大值。

evox.utils.jit_fix_operator.minimum(a: torch.Tensor, b: torch.Tensor) torch.Tensor[源代码]

两个输入张量 ab 的逐元素最小值。

Notice

这是一个用于 torch.minimum 的修正功能,因为它目前尚未在 JIT 操作符融合中得到支持。

Warning

这并不是 torch.minimum 的精确复制,如果 ab 是浮点张量,可能会受到数值精度损失的影响。如果需要精确的最小值,请使用 torch.minimum

参数:
  • a -- 第一个输入张量。

  • b -- 第二个输入张量。

返回:

a 和 b 的元素级最小值。

evox.utils.jit_fix_operator.maximum_float(a: torch.Tensor, b: float) torch.Tensor[源代码]

输入张量 a 和浮点数 b 的逐元素最大值。

Notice

这是一个修复函数,用于torch.maximum,因为它在JIT操作符融合中尚不支持。

Warning

这不是 torch.maximum 的精确复制,如果 a 是一个浮点张量,可能会受到数值精度损失的影响。如果需要精确的最大值,请使用 torch.maximum

参数:
  • a -- 第一个输入张量。

  • b -- 第二个输入是一个浮点数,它是一个标量值。

返回:

a 和 b 的逐元素最大值。

evox.utils.jit_fix_operator.minimum_float(a: torch.Tensor, b: float) torch.Tensor[源代码]

输入张量 a 和浮点数 b 的逐元素最小值。

Notice

这是一个用于 torch.minimum 的修正功能,因为它目前尚未在 JIT 操作符融合中得到支持。

Warning

这并不是对 torch.minimum 的精确复制,如果 a 是一个浮点张量,可能会出现数值精度损失。如果需要精确的最小值,请使用 torch.minimum

参数:
  • a -- 第一个输入张量。

  • b -- 第二个输入是一个浮点数,它是一个标量值。

返回:

a 和 b 的元素级最小值。

evox.utils.jit_fix_operator.maximum_int(a: torch.Tensor, b: int) torch.Tensor[源代码]

输入张量 a 和整数 b 的逐元素最大值。

Notice

这是一个修复函数,用于torch.maximum,因为它在JIT操作符融合中尚不支持。

Warning

这不是 torch.maximum 的精确复制,如果 a 是一个浮点张量,可能会受到数值精度损失的影响。如果需要精确的最大值,请使用 torch.maximum

参数:
  • a -- 第一个输入张量。

  • b -- 第二个输入是一个整数,它是一个标量值。

返回:

a 和 b 的逐元素最大值。

evox.utils.jit_fix_operator.minimum_int(a: torch.Tensor, b: int) torch.Tensor[源代码]

输入张量 a 和整数 b 的逐元素最小值。

Notice

这是一个用于 torch.minimum 的修正功能,因为它目前尚未在 JIT 操作符融合中得到支持。

Warning

这并不是对 torch.minimum 的精确复制,如果 a 是一个浮点张量,可能会出现数值精度损失。如果需要精确的最小值,请使用 torch.minimum

参数:
  • a -- 第一个输入张量。

  • b -- 第二个输入是一个整数,它是一个标量值。

返回:

a 和 b 的元素级最小值。

evox.utils.jit_fix_operator.lexsort(keys: List[torch.Tensor], dim: int = -1) torch.Tensor[源代码]

对多个张量进行字典序排序,将每个张量视为一个键。

此函数按字典顺序对给定的张量进行排序,首先按第一个键排序,如果第一个键相同,则按第二个键排序,依此类推。它的工作方式类似于NumPy的lexsort,但专为PyTorch张量设计。

参数:
  • keys -- 要排序的张量列表,其中每个张量代表一个排序键。所有张量在指定维度(dim)上必须具有相同的长度。

  • dim -- 执行排序的维度。默认为 -1(最后一个维度)。

返回:

一个张量,包含按字典顺序对输入张量进行排序的索引。这些索引指示排序后张量中元素的顺序。

Example

key1 = torch.tensor([1, 3, 2])
key2 = torch.tensor([9, 7, 8])
sorted_indices = lexsort([key1, key2])
# sorted_indices will contain the indices that sort first by key2,
# and then by key1 in case of ties.

小技巧

您可以使用 torch.unbind 将张量拆分为列表。

evox.utils.jit_fix_operator.nanmin(input_tensor: torch.Tensor, dim: int = -1, keepdim: bool = False)[源代码]

计算张量在指定维度上的最小值,忽略NaN值。

此函数将输入张量中的 NaN 值替换为 infinity,然后在指定维度上计算最小值,有效地忽略 NaN 值。

参数:
  • input_tensor -- 输入张量,可能包含NaN值。它可以是任何形状。

  • dim -- 计算最小值的维度。默认值为-1,对应于最后一个维度。

  • keepdim -- 是否在结果中保留减少的维度。默认值为 False。如果为 True,输出张量将与输入具有相同数量的维度,减少的维度大小将设置为 1。

返回:

一个命名元组,包含两个字段:values (torch.Tensor):一个张量,包含沿指定维度计算的最小值,忽略NaN值。indices (torch.Tensor):一个张量,包含沿指定维度的最小值的索引。返回的张量values和indices将与输入张量具有相同的形状,除了执行操作的维度。

Example

x = torch.tensor([[1.0, 2.0], [float('nan'), 4.0]])
result = nanmin(x, dim=0)
print(result.values)  # Output: tensor([1.0, 2.0])
print(result.indices)  # Output: tensor([0, 0])

备注

  • 在计算最小值之前,通过将 NaN 值替换为 infinity 来忽略它们。

  • 如果一个维度上的所有值都是NaN,那么该维度的结果将是infinity,并且索引将返回为第一个有效索引。

evox.utils.jit_fix_operator.nanmax(input_tensor: torch.Tensor, dim: int = -1, keepdim: bool = False)[源代码]

计算张量在指定维度上的最大值,忽略NaN值。

此函数将输入张量中的 NaN 值替换为 -infinity,然后在指定维度上计算最大值,从而有效地忽略 NaN 值。

参数:
  • input_tensor -- 输入张量,可能包含NaN值。它可以是任何形状。

  • dim -- 计算最大值的维度。默认值是 -1,对应于最后一个维度。

  • keepdim -- 是否在结果中保留减少的维度。默认值为 False。如果为 True,输出张量将与输入具有相同数量的维度,减少的维度大小将设置为 1。

返回:

一个命名元组,包含两个字段:values (torch.Tensor):一个张量,包含沿指定维度计算的最大值,忽略 NaN 值。indices (torch.Tensor):一个张量,包含沿指定维度的最大值的索引。返回的张量 values 和 indices 将与输入张量具有相同的形状,除了执行操作的维度。

Example

x = torch.tensor([[1.0, 2.0], [float('nan'), 4.0]])
result = nanmax(x, dim=0)
print(result.values)  # Output: tensor([1.0, 4.0])
print(result.indices)  # Output: tensor([0, 1])

备注

  • 在计算最大值之前,通过将 NaN 值替换为 -infinity 来忽略它们。

  • 如果一个维度上的所有值都是NaN,那么该维度的结果将是-infinity,并且索引将返回为第一个有效索引。

evox.utils.jit_fix_operator.randint(low: torch.Tensor | int | torch.SymInt, high: torch.Tensor | int | torch.SymInt, size: Sequence[int | torch.SymInt] | torch.Size, dtype: torch.dtype | None = None, device: torch.device | None = None, generator: torch.Generator | None = None)[源代码]

随机生成一个整数张量,范围由指定的低值和高值张量决定,类似于 torch.randint。然而,现在低值和高值是张量。

此函数首先生成一个范围为 [0, 1) 的均匀随机浮点张量,然后使用提供的低值和高值调整范围。

参数:
  • low -- 输入的下界张量(包含)或整数。必须是标量。

  • high -- 输入为上界张量(不包含上界)或整数。必须是标量。

  • size -- 输出张量的目标大小。

  • dtype -- 输出张量所需的数据类型。默认值为 None,表示与 low 或 high 相同。

  • device -- 输出张量的目标设备。默认值 None 表示与 low 或 high 相同。

  • generator -- 要使用的随机数生成器。默认值为 None。

返回:

在指定范围内并具有给定大小的随机整数张量。

Example

high = torch.tensor(8)
randint(0, high, (2, 2))

备注

当与 torch.compile 一起使用,并且 lowhighsize 可以是动态整数(例如另一个张量的大小)时,此函数必须torch.compile(..., dynamic=False) 一起使用。