evox.utils.jit_fix_operator#

模块内容#

函数#

switch

基于标签张量从张量列表中生成张量的逐元素切换选择运算符。

clamp

将输入张量 a 的值限制在给定的下界 (lb) 和上界 (ub) 之间。

clamp_float

将输入张量 a 的浮点值限制在给定的下界 (lb) 和上界 (ub) 之间。

clamp_int

将输入张量 a 的整数值限制在给定的下限 (lb) 和上限 (ub) 之间。

clip

将输入张量 a 的值裁剪到 [0, 1] 范围内。

maximum

两个输入张量 ab 的逐元素最大值。

minimum

两个输入张量 ab 的逐元素最小值。

maximum_float

输入张量 a 和浮点数 b 的逐元素最大值。

minimum_float

输入张量 a 和浮点数 b 的逐元素最小值。

maximum_int

输入张量 a 和整数 b 的逐元素最大值。

minimum_int

输入张量 a 和整数 b 的逐元素最小值。

lexsort

对多个张量进行字典序排序,将每个张量视为一个键。

nanmin

计算张量在指定维度上的最小值,忽略NaN值。

nanmax

计算张量在指定维度上的最大值,忽略NaN值。

API#

evox.utils.jit_fix_operator.switch(label: torch.Tensor, values: List[torch.Tensor]) torch.Tensor[源代码]#

基于标签张量从张量列表中生成张量的逐元素切换选择运算符。

参数:
  • label -- 一个张量,包含用于从张量列表中选择的标签。必须可以广播到其余参数的形状。

  • values -- 根据标签从中选择的张量列表。列表中的所有张量必须可以广播到相同的形状。

返回:

一个张量,其中每个元素都是根据标签张量中的相应元素从张量列表中选择的。

evox.utils.jit_fix_operator.clamp(a: torch.Tensor, lb: torch.Tensor, ub: torch.Tensor) torch.Tensor[源代码]#

将输入张量 a 的值限制在给定的下界 (lb) 和上界 (ub) 之间。

此函数确保张量 a 的每个元素不小于 lb 的相应元素且不大于 ub 的相应元素。

Notice

  1. 这是一个用于torch.clamp的修复函数,因为它在JIT操作符融合中不被支持。

  2. 这不是torch.clamp的精确复制,如果albub是浮点张量,可能会遭受数值精度损失。如果需要精确的限制,请使用`torch.clamp

参数:
  • a -- 要被限制的输入张量。

  • lb -- 下界张量。必须可以广播到 a 的形状。

  • ub -- 上界张量。必须可以广播到a的形状。

返回:

张量,其中每个元素都被限制在指定的界限内。

evox.utils.jit_fix_operator.clamp_float(a: torch.Tensor, lb: float, ub: float) torch.Tensor[源代码]#

将输入张量 a 的浮点值限制在给定的下界 (lb) 和上界 (ub) 之间。

此函数确保张量 a 的每个元素不小于 lb 且不大于 `ub

Notice

  1. 这是一个用于torch.clamp的修复函数,因为它在JIT操作符融合中不被支持。

  2. 这不是torch.clamp的精确复制,如果a是浮点张量,可能会遭受数值精度损失。如果需要精确的限制,请使用`torch.clamp

参数:
  • a -- 要被限制的输入张量。

  • lb -- 下界值。a 的每个元素将被限制为不小于 lb。

  • ub -- 上限值。a 的每个元素将被限制为不大于 ub。

返回:

张量,其中每个元素都被限制在指定的界限内。

evox.utils.jit_fix_operator.clamp_int(a: torch.Tensor, lb: int, ub: int) torch.Tensor[源代码]#

将输入张量 a 的整数值限制在给定的下限 (lb) 和上限 (ub) 之间。

此函数确保张量 a 的每个元素不小于 lb 且不大于 `ub

Notice

  1. 这是一个用于torch.clamp的修复函数,因为它在JIT操作符融合中不被支持。

  2. 这不是torch.clamp的精确复制,如果a是一个整数张量,可能会遭受数值精度损失。如果需要精确的限制,请使用`torch.clamp

参数:
  • a -- 要被限制的输入张量。

  • lb -- 下界值。a 的每个元素将被限制为不小于 lb。

  • ub -- 上限值。a 的每个元素将被限制为不大于 ub。

返回:

张量,其中每个元素都被限制在指定的界限内。

evox.utils.jit_fix_operator.clip(a: torch.Tensor) torch.Tensor[源代码]#

将输入张量 a 的值裁剪到 [0, 1] 范围内。

注意:此函数调用 `clamp(a, 0, 1)

参数:

a -- 要裁剪的输入张量。

返回:

一个张量,其中每个元素都被限制在 [0, 1] 范围内。

evox.utils.jit_fix_operator.maximum(a: torch.Tensor, b: torch.Tensor) torch.Tensor[源代码]#

两个输入张量 ab 的逐元素最大值。

注意:这是一个用于torch.maximum的修复函数,因为它在JIT操作符融合中不支持。

参数:
  • a -- 第一个输入张量。

  • b -- 第二个输入张量。

返回:

a 和 b 的逐元素最大值。

evox.utils.jit_fix_operator.minimum(a: torch.Tensor, b: torch.Tensor) torch.Tensor[源代码]#

两个输入张量 ab 的逐元素最小值。

注意:这是一个用于torch.minimum的修复函数,因为它在JIT操作符融合中不支持。

参数:
  • a -- 第一个输入张量。

  • b -- 第二个输入张量。

返回:

a 和 b 的元素级最小值。

evox.utils.jit_fix_operator.maximum_float(a: torch.Tensor, b: float) torch.Tensor[源代码]#

输入张量 a 和浮点数 b 的逐元素最大值。

注意:这是一个用于torch.maximum的修复函数,因为它在JIT操作符融合中不支持。

参数:
  • a -- 第一个输入张量。

  • b -- 第二个输入是一个浮点数,它是一个标量值。

返回:

a 和 b 的逐元素最大值。

evox.utils.jit_fix_operator.minimum_float(a: torch.Tensor, b: float) torch.Tensor[源代码]#

输入张量 a 和浮点数 b 的逐元素最小值。

注意:这是一个用于torch.minimum的修复函数,因为它在JIT操作符融合中不支持。

参数:
  • a -- 第一个输入张量。

  • b -- 第二个输入是一个浮点数,它是一个标量值。

返回:

a 和 b 的元素级最小值。

evox.utils.jit_fix_operator.maximum_int(a: torch.Tensor, b: int) torch.Tensor[源代码]#

输入张量 a 和整数 b 的逐元素最大值。

注意:这是一个用于torch.maximum的修复函数,因为它在JIT操作符融合中不支持。

参数:
  • a -- 第一个输入张量。

  • b -- 第二个输入是一个整数,它是一个标量值。

返回:

a 和 b 的逐元素最大值。

evox.utils.jit_fix_operator.minimum_int(a: torch.Tensor, b: int) torch.Tensor[源代码]#

输入张量 a 和整数 b 的逐元素最小值。

注意:这是一个用于torch.minimum的修复函数,因为它在JIT操作符融合中不支持。

参数:
  • a -- 第一个输入张量。

  • b -- 第二个输入是一个整数,它是一个标量值。

返回:

a 和 b 的元素级最小值。

evox.utils.jit_fix_operator.lexsort(keys: List[torch.Tensor], dim: int = -1) torch.Tensor[源代码]#

对多个张量进行字典序排序,将每个张量视为一个键。

此函数按字典顺序对给定的张量进行排序,首先按第一个键排序,如果第一个键相同,则按第二个键排序,依此类推。它的工作方式类似于NumPy的lexsort,但专为PyTorch张量设计。

参数:
  • keys -- 要排序的张量列表,其中每个张量代表一个排序键。所有张量在指定维度(dim)上必须具有相同的长度。

  • dim -- 执行排序的维度。默认为 -1(最后一个维度)。

返回:

一个张量,包含按字典顺序对输入张量进行排序的索引。这些索引指示排序后张量中元素的顺序。

Example

key1 = torch.tensor([1, 3, 2])
key2 = torch.tensor([9, 7, 8])
sorted_indices = lexsort([key1, key2])
# sorted_indices will contain the indices that sort first by key2,
# and then by key1 in case of ties.
注意:

您可以使用 torch.unbind 将张量拆分为列表。

evox.utils.jit_fix_operator.nanmin(input_tensor: torch.Tensor, dim: int = -1, keepdim: bool = False)[源代码]#

计算张量在指定维度上的最小值,忽略NaN值。

此函数将输入张量中的 NaN 值替换为 infinity,然后在指定维度上计算最小值,有效地忽略 NaN 值。

参数:
  • input_tensor -- 输入张量,可能包含NaN值。它可以是任何形状。

  • dim -- 计算最小值的维度。默认值为-1,对应于最后一个维度。

  • keepdim -- 是否在结果中保留减少的维度。默认值为 False。如果为 True,输出张量将与输入具有相同数量的维度,减少的维度大小将设置为 1。

返回:

一个命名元组,包含两个字段:values (torch.Tensor):一个张量,包含沿指定维度计算的最小值,忽略NaN值。indices (torch.Tensor):一个张量,包含沿指定维度的最小值的索引。返回的张量values和indices将与输入张量具有相同的形状,除了执行操作的维度。

Example

x = torch.tensor([[1.0, 2.0], [float('nan'), 4.0]])
result = nanmin(x, dim=0)
print(result.values)  # Output: tensor([1.0, 2.0])
print(result.indices)  # Output: tensor([0, 0])

备注

  • 在计算最小值之前,通过将 NaN 值替换为 infinity 来忽略它们。

  • 如果一个维度上的所有值都是NaN,那么该维度的结果将是infinity,并且索引将返回为第一个有效索引。

evox.utils.jit_fix_operator.nanmax(input_tensor: torch.Tensor, dim: int = -1, keepdim: bool = False)[源代码]#

计算张量在指定维度上的最大值,忽略NaN值。

此函数将输入张量中的 NaN 值替换为 -infinity,然后在指定维度上计算最大值,从而有效地忽略 NaN 值。

参数:
  • input_tensor -- 输入张量,可能包含NaN值。它可以是任何形状。

  • dim -- 计算最大值的维度。默认值是 -1,对应于最后一个维度。

  • keepdim -- 是否在结果中保留减少的维度。默认值为 False。如果为 True,输出张量将与输入具有相同数量的维度,减少的维度大小将设置为 1。

返回:

一个命名元组,包含两个字段:values (torch.Tensor):一个张量,包含沿指定维度计算的最大值,忽略 NaN 值。indices (torch.Tensor):一个张量,包含沿指定维度的最大值的索引。返回的张量 values 和 indices 将与输入张量具有相同的形状,除了执行操作的维度。

Example

x = torch.tensor([[1.0, 2.0], [float('nan'), 4.0]])
result = nanmax(x, dim=0)
print(result.values)  # Output: tensor([1.0, 4.0])
print(result.indices)  # Output: tensor([0, 1])

备注

  • 在计算最大值之前,通过将 NaN 值替换为 -infinity 来忽略它们。

  • 如果一个维度上的所有值都是NaN,那么该维度的结果将是-infinity,并且索引将返回为第一个有效索引。