evox.utils.jit_fix_operator
#
模块内容#
函数#
基于标签张量从张量列表中生成张量的逐元素切换选择运算符。 |
|
将输入张量 |
|
将输入张量 |
|
将输入张量 |
|
将输入张量 |
|
两个输入张量 |
|
两个输入张量 |
|
输入张量 |
|
输入张量 |
|
输入张量 |
|
输入张量 |
|
对多个张量进行字典序排序,将每个张量视为一个键。 |
|
计算张量在指定维度上的最小值,忽略NaN值。 |
|
计算张量在指定维度上的最大值,忽略NaN值。 |
API#
- evox.utils.jit_fix_operator.switch(label: torch.Tensor, values: List[torch.Tensor]) torch.Tensor [源代码]#
基于标签张量从张量列表中生成张量的逐元素切换选择运算符。
- 参数:
label -- 一个张量,包含用于从张量列表中选择的标签。必须可以广播到其余参数的形状。
values -- 根据标签从中选择的张量列表。列表中的所有张量必须可以广播到相同的形状。
- 返回:
一个张量,其中每个元素都是根据标签张量中的相应元素从张量列表中选择的。
- evox.utils.jit_fix_operator.clamp(a: torch.Tensor, lb: torch.Tensor, ub: torch.Tensor) torch.Tensor [源代码]#
将输入张量
a
的值限制在给定的下界 (lb
) 和上界 (ub
) 之间。此函数确保张量
a
的每个元素不小于lb
的相应元素且不大于ub
的相应元素。Notice
这是一个用于
torch.clamp
的修复函数,因为它在JIT操作符融合中不被支持。这不是
torch.clamp
的精确复制,如果a
、lb
或ub
是浮点张量,可能会遭受数值精度损失。如果需要精确的限制,请使用`torch.clamp
- 参数:
a -- 要被限制的输入张量。
lb -- 下界张量。必须可以广播到 a 的形状。
ub -- 上界张量。必须可以广播到a的形状。
- 返回:
张量,其中每个元素都被限制在指定的界限内。
- evox.utils.jit_fix_operator.clamp_float(a: torch.Tensor, lb: float, ub: float) torch.Tensor [源代码]#
将输入张量
a
的浮点值限制在给定的下界 (lb
) 和上界 (ub
) 之间。此函数确保张量
a
的每个元素不小于lb
且不大于 `ubNotice
这是一个用于
torch.clamp
的修复函数,因为它在JIT操作符融合中不被支持。这不是
torch.clamp
的精确复制,如果a
是浮点张量,可能会遭受数值精度损失。如果需要精确的限制,请使用`torch.clamp
- 参数:
a -- 要被限制的输入张量。
lb -- 下界值。a 的每个元素将被限制为不小于 lb。
ub -- 上限值。a 的每个元素将被限制为不大于 ub。
- 返回:
张量,其中每个元素都被限制在指定的界限内。
- evox.utils.jit_fix_operator.clamp_int(a: torch.Tensor, lb: int, ub: int) torch.Tensor [源代码]#
将输入张量
a
的整数值限制在给定的下限 (lb
) 和上限 (ub
) 之间。此函数确保张量
a
的每个元素不小于lb
且不大于 `ubNotice
这是一个用于
torch.clamp
的修复函数,因为它在JIT操作符融合中不被支持。这不是
torch.clamp
的精确复制,如果a
是一个整数张量,可能会遭受数值精度损失。如果需要精确的限制,请使用`torch.clamp
- 参数:
a -- 要被限制的输入张量。
lb -- 下界值。a 的每个元素将被限制为不小于 lb。
ub -- 上限值。a 的每个元素将被限制为不大于 ub。
- 返回:
张量,其中每个元素都被限制在指定的界限内。
- evox.utils.jit_fix_operator.clip(a: torch.Tensor) torch.Tensor [源代码]#
将输入张量
a
的值裁剪到 [0, 1] 范围内。注意:此函数调用 `clamp(a, 0, 1)
- 参数:
a -- 要裁剪的输入张量。
- 返回:
一个张量,其中每个元素都被限制在 [0, 1] 范围内。
- evox.utils.jit_fix_operator.maximum(a: torch.Tensor, b: torch.Tensor) torch.Tensor [源代码]#
两个输入张量
a
和b
的逐元素最大值。注意:这是一个用于
torch.maximum
的修复函数,因为它在JIT操作符融合中不支持。- 参数:
a -- 第一个输入张量。
b -- 第二个输入张量。
- 返回:
a 和 b 的逐元素最大值。
- evox.utils.jit_fix_operator.minimum(a: torch.Tensor, b: torch.Tensor) torch.Tensor [源代码]#
两个输入张量
a
和b
的逐元素最小值。注意:这是一个用于
torch.minimum
的修复函数,因为它在JIT操作符融合中不支持。- 参数:
a -- 第一个输入张量。
b -- 第二个输入张量。
- 返回:
a 和 b 的元素级最小值。
- evox.utils.jit_fix_operator.maximum_float(a: torch.Tensor, b: float) torch.Tensor [源代码]#
输入张量
a
和浮点数b
的逐元素最大值。注意:这是一个用于
torch.maximum
的修复函数,因为它在JIT操作符融合中不支持。- 参数:
a -- 第一个输入张量。
b -- 第二个输入是一个浮点数,它是一个标量值。
- 返回:
a 和 b 的逐元素最大值。
- evox.utils.jit_fix_operator.minimum_float(a: torch.Tensor, b: float) torch.Tensor [源代码]#
输入张量
a
和浮点数b
的逐元素最小值。注意:这是一个用于
torch.minimum
的修复函数,因为它在JIT操作符融合中不支持。- 参数:
a -- 第一个输入张量。
b -- 第二个输入是一个浮点数,它是一个标量值。
- 返回:
a 和 b 的元素级最小值。
- evox.utils.jit_fix_operator.maximum_int(a: torch.Tensor, b: int) torch.Tensor [源代码]#
输入张量
a
和整数b
的逐元素最大值。注意:这是一个用于
torch.maximum
的修复函数,因为它在JIT操作符融合中不支持。- 参数:
a -- 第一个输入张量。
b -- 第二个输入是一个整数,它是一个标量值。
- 返回:
a 和 b 的逐元素最大值。
- evox.utils.jit_fix_operator.minimum_int(a: torch.Tensor, b: int) torch.Tensor [源代码]#
输入张量
a
和整数b
的逐元素最小值。注意:这是一个用于
torch.minimum
的修复函数,因为它在JIT操作符融合中不支持。- 参数:
a -- 第一个输入张量。
b -- 第二个输入是一个整数,它是一个标量值。
- 返回:
a 和 b 的元素级最小值。
- evox.utils.jit_fix_operator.lexsort(keys: List[torch.Tensor], dim: int = -1) torch.Tensor [源代码]#
对多个张量进行字典序排序,将每个张量视为一个键。
此函数按字典顺序对给定的张量进行排序,首先按第一个键排序,如果第一个键相同,则按第二个键排序,依此类推。它的工作方式类似于NumPy的
lexsort
,但专为PyTorch张量设计。- 参数:
keys -- 要排序的张量列表,其中每个张量代表一个排序键。所有张量在指定维度(dim)上必须具有相同的长度。
dim -- 执行排序的维度。默认为 -1(最后一个维度)。
- 返回:
一个张量,包含按字典顺序对输入张量进行排序的索引。这些索引指示排序后张量中元素的顺序。
Example
key1 = torch.tensor([1, 3, 2]) key2 = torch.tensor([9, 7, 8]) sorted_indices = lexsort([key1, key2]) # sorted_indices will contain the indices that sort first by key2, # and then by key1 in case of ties.
- 注意:
您可以使用
torch.unbind
将张量拆分为列表。
- evox.utils.jit_fix_operator.nanmin(input_tensor: torch.Tensor, dim: int = -1, keepdim: bool = False)[源代码]#
计算张量在指定维度上的最小值,忽略NaN值。
此函数将输入张量中的
NaN
值替换为infinity
,然后在指定维度上计算最小值,有效地忽略NaN
值。- 参数:
input_tensor -- 输入张量,可能包含NaN值。它可以是任何形状。
dim -- 计算最小值的维度。默认值为-1,对应于最后一个维度。
keepdim -- 是否在结果中保留减少的维度。默认值为 False。如果为 True,输出张量将与输入具有相同数量的维度,减少的维度大小将设置为 1。
- 返回:
一个命名元组,包含两个字段:values (torch.Tensor):一个张量,包含沿指定维度计算的最小值,忽略NaN值。indices (torch.Tensor):一个张量,包含沿指定维度的最小值的索引。返回的张量values和indices将与输入张量具有相同的形状,除了执行操作的维度。
Example
x = torch.tensor([[1.0, 2.0], [float('nan'), 4.0]]) result = nanmin(x, dim=0) print(result.values) # Output: tensor([1.0, 2.0]) print(result.indices) # Output: tensor([0, 0])
备注
在计算最小值之前,通过将
NaN
值替换为infinity
来忽略它们。如果一个维度上的所有值都是
NaN
,那么该维度的结果将是infinity
,并且索引将返回为第一个有效索引。
- evox.utils.jit_fix_operator.nanmax(input_tensor: torch.Tensor, dim: int = -1, keepdim: bool = False)[源代码]#
计算张量在指定维度上的最大值,忽略NaN值。
此函数将输入张量中的
NaN
值替换为-infinity
,然后在指定维度上计算最大值,从而有效地忽略NaN
值。- 参数:
input_tensor -- 输入张量,可能包含NaN值。它可以是任何形状。
dim -- 计算最大值的维度。默认值是 -1,对应于最后一个维度。
keepdim -- 是否在结果中保留减少的维度。默认值为 False。如果为 True,输出张量将与输入具有相同数量的维度,减少的维度大小将设置为 1。
- 返回:
一个命名元组,包含两个字段:values (torch.Tensor):一个张量,包含沿指定维度计算的最大值,忽略 NaN 值。indices (torch.Tensor):一个张量,包含沿指定维度的最大值的索引。返回的张量 values 和 indices 将与输入张量具有相同的形状,除了执行操作的维度。
Example
x = torch.tensor([[1.0, 2.0], [float('nan'), 4.0]]) result = nanmax(x, dim=0) print(result.values) # Output: tensor([1.0, 4.0]) print(result.indices) # Output: tensor([0, 1])
备注
在计算最大值之前,通过将
NaN
值替换为-infinity
来忽略它们。如果一个维度上的所有值都是
NaN
,那么该维度的结果将是-infinity
,并且索引将返回为第一个有效索引。