evox.utils.jit_fix_operator
¶
模块内容¶
函数¶
基于标签张量从张量列表中生成张量的逐元素切换选择运算符。 |
|
将输入张量 |
|
将输入张量 |
|
将输入张量 |
|
将输入张量 |
|
两个输入张量 |
|
两个输入张量 |
|
输入张量 |
|
输入张量 |
|
输入张量 |
|
输入张量 |
|
对多个张量进行字典序排序,将每个张量视为一个键。 |
|
计算张量在指定维度上的最小值,忽略NaN值。 |
|
计算张量在指定维度上的最大值,忽略NaN值。 |
|
随机生成一个整数张量,范围由指定的低值和高值张量决定,类似于 |
API¶
- evox.utils.jit_fix_operator.switch(label: torch.Tensor, values: List[torch.Tensor]) torch.Tensor [源代码]¶
基于标签张量从张量列表中生成张量的逐元素切换选择运算符。
- 参数:
label -- 一个张量,包含用于从张量列表中选择的标签。必须可以广播到其余参数的形状。
values -- 根据标签从中选择的张量列表。列表中的所有张量必须可以广播到相同的形状。
- 返回:
一个张量,其中每个元素都是根据标签张量中的相应元素从张量列表中选择的。
- evox.utils.jit_fix_operator.clamp(a: torch.Tensor, lb: torch.Tensor, ub: torch.Tensor) torch.Tensor [源代码]¶
将输入张量
a
的值限制在给定的下界 (lb
) 和上界 (ub
) 之间。此函数确保张量
a
的每个元素不小于lb
的相应元素且不大于ub
的相应元素。Notice
这是一个修复函数,适用于
torch.clamp
,因为它尚未在JIT操作符融合中支持。Warning
这不是
torch.clamp
的精确复制,如果a
、lb
或ub
是浮点张量,可能会遭受数值精度损失。如果需要精确的限制,请使用`torch.clamp- 参数:
a -- 要被限制的输入张量。
lb -- 下界张量。必须可以广播到 a 的形状。
ub -- 上界张量。必须可以广播到a的形状。
- 返回:
张量,其中每个元素都被限制在指定的界限内。
- evox.utils.jit_fix_operator.clamp_float(a: torch.Tensor, lb: float, ub: float) torch.Tensor [源代码]¶
将输入张量
a
的浮点值限制在给定的下界 (lb
) 和上界 (ub
) 之间。此函数确保张量
a
的每个元素不小于lb
且不大于 `ubNotice
这是一个修复函数,适用于
torch.clamp
,因为它尚未在JIT操作符融合中支持。Warning
这不是
torch.clamp
的精确复制,如果a
是浮点张量,可能会遭受数值精度损失。如果需要精确的限制,请使用`torch.clamp- 参数:
a -- 要被限制的输入张量。
lb -- 下界值。a 的每个元素将被限制为不小于 lb。
ub -- 上限值。a 的每个元素将被限制为不大于 ub。
- 返回:
张量,其中每个元素都被限制在指定的界限内。
- evox.utils.jit_fix_operator.clamp_int(a: torch.Tensor, lb: int, ub: int) torch.Tensor [源代码]¶
将输入张量
a
的整数值限制在给定的下限 (lb
) 和上限 (ub
) 之间。此函数确保张量
a
的每个元素不小于lb
且不大于 `ubNotice
这是一个修复函数,适用于
torch.clamp
,因为它尚未在JIT操作符融合中支持。Warning
这不是
torch.clamp
的精确复制,如果a
是浮点张量,可能会遭受数值精度损失。如果需要精确的限制,请使用`torch.clamp- 参数:
a -- 要被限制的输入张量。
lb -- 下界值。a 的每个元素将被限制为不小于 lb。
ub -- 上限值。a 的每个元素将被限制为不大于 ub。
- 返回:
张量,其中每个元素都被限制在指定的界限内。
- evox.utils.jit_fix_operator.clip(a: torch.Tensor) torch.Tensor [源代码]¶
将输入张量
a
的值裁剪到 [0, 1] 范围内。注意:此函数调用 `clamp(a, 0, 1)
- 参数:
a -- 要裁剪的输入张量。
- 返回:
一个张量,其中每个元素都被限制在 [0, 1] 范围内。
- evox.utils.jit_fix_operator.maximum(a: torch.Tensor, b: torch.Tensor) torch.Tensor [源代码]¶
两个输入张量
a
和b
的逐元素最大值。Notice
这是一个修复函数,用于
torch.maximum
,因为它在JIT操作符融合中尚不支持。Warning
这并不是对
torch.maximum
的精确复现,如果a
或b
是浮点张量,可能会存在数值精度损失。如果需要精确的最大值,请使用torch.maximum
。- 参数:
a -- 第一个输入张量。
b -- 第二个输入张量。
- 返回:
a 和 b 的逐元素最大值。
- evox.utils.jit_fix_operator.minimum(a: torch.Tensor, b: torch.Tensor) torch.Tensor [源代码]¶
两个输入张量
a
和b
的逐元素最小值。Notice
这是一个用于
torch.minimum
的修正功能,因为它目前尚未在 JIT 操作符融合中得到支持。Warning
这并不是
torch.minimum
的精确复制,如果a
或b
是浮点张量,可能会受到数值精度损失的影响。如果需要精确的最小值,请使用torch.minimum
。- 参数:
a -- 第一个输入张量。
b -- 第二个输入张量。
- 返回:
a 和 b 的元素级最小值。
- evox.utils.jit_fix_operator.maximum_float(a: torch.Tensor, b: float) torch.Tensor [源代码]¶
输入张量
a
和浮点数b
的逐元素最大值。Notice
这是一个修复函数,用于
torch.maximum
,因为它在JIT操作符融合中尚不支持。Warning
这不是
torch.maximum
的精确复制,如果a
是一个浮点张量,可能会受到数值精度损失的影响。如果需要精确的最大值,请使用torch.maximum
。- 参数:
a -- 第一个输入张量。
b -- 第二个输入是一个浮点数,它是一个标量值。
- 返回:
a 和 b 的逐元素最大值。
- evox.utils.jit_fix_operator.minimum_float(a: torch.Tensor, b: float) torch.Tensor [源代码]¶
输入张量
a
和浮点数b
的逐元素最小值。Notice
这是一个用于
torch.minimum
的修正功能,因为它目前尚未在 JIT 操作符融合中得到支持。Warning
这并不是对
torch.minimum
的精确复制,如果a
是一个浮点张量,可能会出现数值精度损失。如果需要精确的最小值,请使用torch.minimum
。- 参数:
a -- 第一个输入张量。
b -- 第二个输入是一个浮点数,它是一个标量值。
- 返回:
a 和 b 的元素级最小值。
- evox.utils.jit_fix_operator.maximum_int(a: torch.Tensor, b: int) torch.Tensor [源代码]¶
输入张量
a
和整数b
的逐元素最大值。Notice
这是一个修复函数,用于
torch.maximum
,因为它在JIT操作符融合中尚不支持。Warning
这不是
torch.maximum
的精确复制,如果a
是一个浮点张量,可能会受到数值精度损失的影响。如果需要精确的最大值,请使用torch.maximum
。- 参数:
a -- 第一个输入张量。
b -- 第二个输入是一个整数,它是一个标量值。
- 返回:
a 和 b 的逐元素最大值。
- evox.utils.jit_fix_operator.minimum_int(a: torch.Tensor, b: int) torch.Tensor [源代码]¶
输入张量
a
和整数b
的逐元素最小值。Notice
这是一个用于
torch.minimum
的修正功能,因为它目前尚未在 JIT 操作符融合中得到支持。Warning
这并不是对
torch.minimum
的精确复制,如果a
是一个浮点张量,可能会出现数值精度损失。如果需要精确的最小值,请使用torch.minimum
。- 参数:
a -- 第一个输入张量。
b -- 第二个输入是一个整数,它是一个标量值。
- 返回:
a 和 b 的元素级最小值。
- evox.utils.jit_fix_operator.lexsort(keys: List[torch.Tensor], dim: int = -1) torch.Tensor [源代码]¶
对多个张量进行字典序排序,将每个张量视为一个键。
此函数按字典顺序对给定的张量进行排序,首先按第一个键排序,如果第一个键相同,则按第二个键排序,依此类推。它的工作方式类似于NumPy的
lexsort
,但专为PyTorch张量设计。- 参数:
keys -- 要排序的张量列表,其中每个张量代表一个排序键。所有张量在指定维度(dim)上必须具有相同的长度。
dim -- 执行排序的维度。默认为 -1(最后一个维度)。
- 返回:
一个张量,包含按字典顺序对输入张量进行排序的索引。这些索引指示排序后张量中元素的顺序。
Example
key1 = torch.tensor([1, 3, 2]) key2 = torch.tensor([9, 7, 8]) sorted_indices = lexsort([key1, key2]) # sorted_indices will contain the indices that sort first by key2, # and then by key1 in case of ties.
小技巧
您可以使用
torch.unbind
将张量拆分为列表。
- evox.utils.jit_fix_operator.nanmin(input_tensor: torch.Tensor, dim: int = -1, keepdim: bool = False)[源代码]¶
计算张量在指定维度上的最小值,忽略NaN值。
此函数将输入张量中的
NaN
值替换为infinity
,然后在指定维度上计算最小值,有效地忽略NaN
值。- 参数:
input_tensor -- 输入张量,可能包含NaN值。它可以是任何形状。
dim -- 计算最小值的维度。默认值为-1,对应于最后一个维度。
keepdim -- 是否在结果中保留减少的维度。默认值为 False。如果为 True,输出张量将与输入具有相同数量的维度,减少的维度大小将设置为 1。
- 返回:
一个命名元组,包含两个字段:values (torch.Tensor):一个张量,包含沿指定维度计算的最小值,忽略NaN值。indices (torch.Tensor):一个张量,包含沿指定维度的最小值的索引。返回的张量values和indices将与输入张量具有相同的形状,除了执行操作的维度。
Example
x = torch.tensor([[1.0, 2.0], [float('nan'), 4.0]]) result = nanmin(x, dim=0) print(result.values) # Output: tensor([1.0, 2.0]) print(result.indices) # Output: tensor([0, 0])
备注
在计算最小值之前,通过将
NaN
值替换为infinity
来忽略它们。如果一个维度上的所有值都是
NaN
,那么该维度的结果将是infinity
,并且索引将返回为第一个有效索引。
- evox.utils.jit_fix_operator.nanmax(input_tensor: torch.Tensor, dim: int = -1, keepdim: bool = False)[源代码]¶
计算张量在指定维度上的最大值,忽略NaN值。
此函数将输入张量中的
NaN
值替换为-infinity
,然后在指定维度上计算最大值,从而有效地忽略NaN
值。- 参数:
input_tensor -- 输入张量,可能包含NaN值。它可以是任何形状。
dim -- 计算最大值的维度。默认值是 -1,对应于最后一个维度。
keepdim -- 是否在结果中保留减少的维度。默认值为 False。如果为 True,输出张量将与输入具有相同数量的维度,减少的维度大小将设置为 1。
- 返回:
一个命名元组,包含两个字段:values (torch.Tensor):一个张量,包含沿指定维度计算的最大值,忽略 NaN 值。indices (torch.Tensor):一个张量,包含沿指定维度的最大值的索引。返回的张量 values 和 indices 将与输入张量具有相同的形状,除了执行操作的维度。
Example
x = torch.tensor([[1.0, 2.0], [float('nan'), 4.0]]) result = nanmax(x, dim=0) print(result.values) # Output: tensor([1.0, 4.0]) print(result.indices) # Output: tensor([0, 1])
备注
在计算最大值之前,通过将
NaN
值替换为-infinity
来忽略它们。如果一个维度上的所有值都是
NaN
,那么该维度的结果将是-infinity
,并且索引将返回为第一个有效索引。
- evox.utils.jit_fix_operator.randint(low: torch.Tensor | int | torch.SymInt, high: torch.Tensor | int | torch.SymInt, size: Sequence[int | torch.SymInt] | torch.Size, dtype: torch.dtype | None = None, device: torch.device | None = None, generator: torch.Generator | None = None)[源代码]¶
随机生成一个整数张量,范围由指定的低值和高值张量决定,类似于
torch.randint
。然而,现在低值和高值是张量。此函数首先生成一个范围为 [0, 1) 的均匀随机浮点张量,然后使用提供的低值和高值调整范围。
- 参数:
low -- 输入的下界张量(包含)或整数。必须是标量。
high -- 输入为上界张量(不包含上界)或整数。必须是标量。
size -- 输出张量的目标大小。
dtype -- 输出张量所需的数据类型。默认值为 None,表示与 low 或 high 相同。
device -- 输出张量的目标设备。默认值 None 表示与 low 或 high 相同。
generator -- 要使用的随机数生成器。默认值为 None。
- 返回:
在指定范围内并具有给定大小的随机整数张量。
Example
high = torch.tensor(8) randint(0, high, (2, 2))
备注
当与
torch.compile
一起使用,并且low
、high
或size
可以是动态整数(例如另一个张量的大小)时,此函数必须与torch.compile(..., dynamic=False)
一起使用。