evox.utils.jit_fix_operator
Module Contents
Functions
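| Function | Description |
|---|---|
| `switch` | Element-wise switch select operator that generates a tensor from a list of tensors based on the label tensor. |
| `clamp` | Clamp the values of the input tensor to lie within tensor lower and upper bounds. |
| `clamp_float` | Clamp the values of the input tensor to lie within float lower and upper bounds. |
| `clamp_int` | Clamp the values of the input tensor to lie within int lower and upper bounds. |
| `clip` | Clip the values of the input tensor to the range [0, 1]. |
| `maximum` | Element-wise maximum of two input tensors. |
| `minimum` | Element-wise minimum of two input tensors. |
| `maximum_float` | Element-wise maximum of an input tensor and a float scalar. |
| `minimum_float` | Element-wise minimum of an input tensor and a float scalar. |
| `maximum_int` | Element-wise maximum of an input tensor and an int scalar. |
| `minimum_int` | Element-wise minimum of an input tensor and an int scalar. |
| `lexsort` | Perform lexicographical sorting of multiple tensors, considering each tensor as a key. |
| `nanmin` | Compute the minimum of a tensor along a specified dimension, ignoring NaN values. |
| `nanmax` | Compute the maximum of a tensor along a specified dimension, ignoring NaN values. |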
API
- evox.utils.jit_fix_operator.switch(label: torch.Tensor, values: List[torch.Tensor]) -> torch.Tensor
Element-wise switch select operator that generates a tensor from a list of tensors based on the label tensor.
- Parameters:
label – A tensor containing labels used to select from the list of tensors. Must be broadcastable to the shape of the remaining arguments.
values – A list of tensors from which one is selected based on the label. All tensors in the list must be broadcastable to the same shape.
- Returns:
A tensor where each element is selected from the list of tensors based on the corresponding element in the label tensor.
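One way to build such a selector from fusion-friendly element-wise ops is to chain `torch.where` calls. The sketch below shows this approach. The name `switch_sketch` is hypothetical. The sketch also assumes that a label value `k` selects `values[k]`, and that labels outside the valid range fall back to `values[0]`. It is not necessarily how evox implements `switch`.

```python
from typing import List

import torch


def switch_sketch(label: torch.Tensor, values: List[torch.Tensor]) -> torch.Tensor:
    """Hypothetical sketch: element-wise pick values[label] via chained torch.where."""
    # Start from the first candidate; elements whose label is 0 keep this value.
    result = values[0]
    # For each further candidate k, overwrite the elements whose label equals k.
    # torch.where broadcasts label, values[k] and result to a common shape.
    for k in range(1, len(values)):
        result = torch.where(label == k, values[k], result)
    return result


label = torch.tensor([0, 2, 1, 0])
a = torch.tensor([10.0, 11.0, 12.0, 13.0])
b = torch.tensor([20.0, 21.0, 22.0, 23.0])
c = torch.tensor([30.0, 31.0, 32.0, 33.0])
out = switch_sketch(label, [a, b, c])
print(out.tolist())  # -> [10.0, 31.0, 22.0, 13.0]
```

Because every step is a plain element-wise `where`, a 0-dim candidate such as `torch.tensor(0.0)` broadcasts against the label exactly like a full tensor.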
- evox.utils.jit_fix_operator.clamp(a: torch.Tensor, lb: torch.Tensor, ub: torch.Tensor) -> torch.Tensor
Clamp the values of the input tensor `a` to be within the given lower (`lb`) and upper (`ub`) bounds. This function ensures that each element of the tensor `a` is not less than the corresponding element of `lb` and not greater than the corresponding element of `ub`.
Notice
This is a fix function for `torch.clamp`, since it is not supported in JIT operator fusion. If `a`, `lb`, or `ub` is a float tensor, this function is NOT a precise replication of `torch.clamp` and may suffer from numerical precision losses. Please use `torch.clamp` instead if a precise clamp is required.
- Parameters:
a – The input tensor to be clamped.
lb – The lower bound tensor. Must be broadcastable to the shape of `a`.
ub – The upper bound tensor. Must be broadcastable to the shape of `a`.
- Returns:
A tensor where each element is clamped to be within the specified bounds.
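The Notice says `clamp` avoids `torch.clamp` and may lose float precision. That points to an arithmetic formulation. One such formulation is assumed here and not taken from the evox source:

clamp(a, lb, ub) = a + relu(lb - a) - relu(a - ub)

The first ReLU term lifts elements that fall below `lb`. The second pulls elements above `ub` down. Both terms are zero inside the bounds. The name `clamp_sketch` is hypothetical, and the sketch assumes `lb <= ub` element-wise.

```python
import torch


def clamp_sketch(a: torch.Tensor, lb: torch.Tensor, ub: torch.Tensor) -> torch.Tensor:
    """Hypothetical fusion-friendly clamp built only from element-wise ops.

    Assumes lb <= ub element-wise; lb and ub broadcast against a.
    """
    # Positive only where a < lb: the amount needed to lift a up to lb.
    below = torch.relu(lb - a)
    # Positive only where a > ub: the amount needed to pull a down to ub.
    above = torch.relu(a - ub)
    return a + below - above


a = torch.tensor([[-2.0, 0.5, 3.0], [1.0, -1.0, 10.0]])
lb = torch.tensor([0.0, 0.0, 0.0])  # shape (3,): broadcast over rows
ub = torch.tensor([[1.0], [5.0]])   # shape (2, 1): one upper bound per row
out = clamp_sketch(a, lb, ub)
print(out.tolist())  # -> [[0.0, 0.5, 1.0], [1.0, 0.0, 5.0]]
```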
- evox.utils.jit_fix_operator.clamp_float(a: torch.Tensor, lb: float, ub: float) -> torch.Tensor
Clamp the values of the input tensor `a` to be within the given float lower (`lb`) and upper (`ub`) bounds. This function ensures that each element of the tensor `a` is not less than `lb` and not greater than `ub`.
Notice
This is a fix function for `torch.clamp`, since it is not supported in JIT operator fusion. If `a` is a float tensor, this function is NOT a precise replication of `torch.clamp` and may suffer from numerical precision losses. Please use `torch.clamp` instead if a precise clamp is required.
- Parameters:
a – The input tensor to be clamped.
lb – The lower bound value. Each element of `a` will be clamped to be not less than `lb`.
ub – The upper bound value. Each element of `a` will be clamped to be not greater than `ub`.
- Returns:
A tensor where each element is clamped to be within the specified bounds.
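The precision caveat in the Notice can be made concrete with a sketch. It assumes the same ReLU-based formulation as above, with scalar float bounds, under the hypothetical name `clamp_float_sketch`. Whether evox's `clamp_float` shows exactly this behaviour depends on its actual implementation. The point is only that arithmetic-based clamping rounds intermediate results.

Here the float32 value `1e8` sits far above `ub = 1.0`. Near `1e8`, float32 values are spaced 8 apart, so `1e8 - 1.0` rounds back to `1e8`. The final subtraction then cancels to `0.0`.

```python
import torch


def clamp_float_sketch(a: torch.Tensor, lb: float, ub: float) -> torch.Tensor:
    # Hypothetical ReLU-based clamp with scalar float bounds (assumed formulation).
    return a + torch.relu(lb - a) - torch.relu(a - ub)


x32 = torch.tensor([0.25, 2.0, 1e8], dtype=torch.float32)
print(clamp_float_sketch(x32, 0.5, 1.0).tolist())  # -> [0.5, 1.0, 0.0]
print(torch.clamp(x32, 0.5, 1.0).tolist())         # -> [0.5, 1.0, 1.0]

# In float64, 1e8 - 1.0 is exactly representable, so the sketch agrees with torch.clamp.
x64 = x32.to(torch.float64)
print(clamp_float_sketch(x64, 0.5, 1.0).tolist())  # -> [0.5, 1.0, 1.0]
```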
- evox.utils.jit_fix_operator.clamp_int(a: torch.Tensor, lb: int, ub: int) -> torch.Tensor
Clamp the values of the input tensor `a` to be within the given int lower (`lb`) and upper (`ub`) bounds. This function ensures that each element of the tensor `a` is not less than `lb` and not greater than `ub`.
Notice
This is a fix function for `torch.clamp`, since it is not supported in JIT operator fusion. If `a` is an int tensor, this function is NOT a precise replication of `torch.clamp` and may suffer from numerical precision losses. Please use `torch.clamp` instead if a precise clamp is required.
- Parameters:
a – The input tensor to be clamped.
lb – The lower bound value. Each element of `a` will be clamped to be not less than `lb`.
ub – The upper bound value. Each element of `a` will be clamped to be not greater than `ub`.
- Returns:
A tensor where each element is clamped to be within the specified bounds.
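A common use of integer bounds is keeping computed indices inside `[0, n - 1]` before indexing. The sketch below shows that pattern. It assumes the same ReLU-based formulation as above, uses the hypothetical name `clamp_int_sketch`, and does not call evox itself.

```python
import torch


def clamp_int_sketch(a: torch.Tensor, lb: int, ub: int) -> torch.Tensor:
    # Hypothetical ReLU-based clamp with Python int bounds (assumed formulation).
    return a + torch.relu(lb - a) - torch.relu(a - ub)


table = torch.tensor([100, 200, 300, 400])
raw_idx = torch.tensor([-3, 0, 2, 7])  # some indices fall outside [0, 3]
safe_idx = clamp_int_sketch(raw_idx, 0, table.numel() - 1)
print(safe_idx.tolist())         # -> [0, 0, 2, 3]
print(table[safe_idx].tolist())  # -> [100, 100, 300, 400]
```

The index tensor stays `int64` throughout, so it can be used for indexing directly.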
- evox.utils.jit_fix_operator.clip(a: torch.Tensor) -> torch.Tensor
Clip the values of the input tensor `a` to be within the range [0, 1].
Notice: This function invokes `clamp(a, 0, 1)`.
- Parameters:
a – The input tensor to be clipped.
- Returns:
A tensor where each element is clipped to be within [0, 1].
- evox.utils.jit_fix_operator.maximum(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor
Element-wise maximum of two input tensors `a` and `b`.
Notice: This is a fix function for [`torch.maximum`](https://pytorch.org/docs/stable/generated/torch.maximum.html), since it is not supported in JIT operator fusion.
- Parameters:
a – The first input tensor.
b – The second input tensor.
- Returns:
The element-wise maximum of `a` and `b`.
- evox.utils.jit_fix_operator.minimum(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor
Element-wise minimum of two input tensors `a` and `b`.
Notice: This is a fix function for [`torch.minimum`](https://pytorch.org/docs/stable/generated/torch.minimum.html), since it is not supported in JIT operator fusion.
- Parameters:
a – The first input tensor.
b – The second input tensor.
- Returns:
The element-wise minimum of `a` and `b`.
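Element-wise maximum and minimum can both be written with a single ReLU, using these identities:

- max(a, b) = a + relu(b - a)
- min(a, b) = a - relu(a - b)

The sketch below applies them. Its names are hypothetical, and the formulation is an assumption about how such fix functions can be built, not a copy of evox's code. Broadcasting follows the usual PyTorch rules.

```python
import torch


def maximum_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # relu(b - a) is how far b exceeds a (0 where a is already larger).
    return a + torch.relu(b - a)


def minimum_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # relu(a - b) is how far a exceeds b (0 where a is already smaller).
    return a - torch.relu(a - b)


a = torch.tensor([1.0, 5.0, -2.0])
b = torch.tensor([3.0, 4.0, -2.0])
print(maximum_sketch(a, b).tolist())  # -> [3.0, 5.0, -2.0]
print(minimum_sketch(a, b).tolist())  # -> [1.0, 4.0, -2.0]
```

This particular formulation has a caveat with infinite inputs. For `a = -inf` and `b = 1.0`, it evaluates `-inf + inf` and returns NaN, whereas `torch.maximum` returns `1.0`.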
- evox.utils.jit_fix_operator.maximum_float(a: torch.Tensor, b: float) -> torch.Tensor
Element-wise maximum of input tensor `a` and float `b`.
Notice: This is a fix function for [`torch.maximum`](https://pytorch.org/docs/stable/generated/torch.maximum.html), since it is not supported in JIT operator fusion.
- Parameters:
a – The first input tensor.
b – The second input float, which is a scalar value.
- Returns:
The element-wise maximum of `a` and `b`.
- evox.utils.jit_fix_operator.minimum_float(a: torch.Tensor, b: float) -> torch.Tensor
Element-wise minimum of input tensor `a` and float `b`.
Notice: This is a fix function for `torch.minimum`, since it is not supported in JIT operator fusion.
- Parameters:
a – The first input tensor.
b – The second input float, which is a scalar value.
- Returns:
The element-wise minimum of `a` and `b`.
- evox.utils.jit_fix_operator.maximum_int(a: torch.Tensor, b: int) -> torch.Tensor
Element-wise maximum of input tensor `a` and int `b`.
Notice: This is a fix function for [`torch.maximum`](https://pytorch.org/docs/stable/generated/torch.maximum.html), since it is not supported in JIT operator fusion.
- Parameters:
a – The first input tensor.
b – The second input int, which is a scalar value.
- Returns:
The element-wise maximum of `a` and `b`.
- evox.utils.jit_fix_operator.minimum_int(a: torch.Tensor, b: int) -> torch.Tensor
Element-wise minimum of input tensor `a` and int `b`.
Notice: This is a fix function for `torch.minimum`, since it is not supported in JIT operator fusion.
- Parameters:
a – The first input tensor.
b – The second input int, which is a scalar value.
- Returns:
The element-wise minimum of `a` and `b`.
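The scalar variants (`maximum_float`, `minimum_float`, `maximum_int`, `minimum_int`) take a Python number as `b`. That number broadcasts against every element of `a`. Two typical uses are shown below: flooring a float quantity away from zero before dividing, and capping integer counts at a budget. The helpers are hypothetical and assume the same ReLU identities as the tensor versions. The other two variants are symmetric.

```python
import torch


def maximum_float_sketch(a: torch.Tensor, b: float) -> torch.Tensor:
    # Hypothetical: max(a, b) = a + relu(b - a), with b a Python float.
    return a + torch.relu(b - a)


def minimum_int_sketch(a: torch.Tensor, b: int) -> torch.Tensor:
    # Hypothetical: min(a, b) = a - relu(a - b), with b a Python int.
    return a - torch.relu(a - b)


# Floor standard deviations at a small epsilon so the division stays finite.
std = torch.tensor([0.0, 0.5, 2.0])
safe_std = maximum_float_sketch(std, 1e-8)
inv = 1.0 / safe_std

# Cap per-item counts at a budget of 4; the result stays an int64 tensor.
counts = torch.tensor([2, 9, 4])
capped = minimum_int_sketch(counts, 4)
print(capped.tolist())  # -> [2, 4, 4]
```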
- evox.utils.jit_fix_operator.lexsort(keys: List[torch.Tensor], dim: int = -1) -> torch.Tensor
Perform lexicographical sorting of multiple tensors, considering each tensor as a key.
This function sorts the given tensors lexicographically. It works similarly to NumPy's `lexsort`, but is designed for PyTorch tensors. As in NumPy, the last key in `keys` is the primary sort key: elements are ordered by the last key, ties are broken by the second-to-last key, and so on.
- Parameters:
keys – A list of tensors to be sorted, where each tensor represents a sorting key. All tensors must have the same length along the specified dimension (`dim`).
dim – The dimension along which to perform the sorting. Defaults to -1 (the last dimension).
- Returns:
A tensor containing the indices that sort the input tensors lexicographically, i.e., the order in which the elements appear in the sorted tensors.
Example
    import torch
    from evox.utils.jit_fix_operator import lexsort

    key1 = torch.tensor([1, 3, 2])
    key2 = torch.tensor([9, 7, 8])
    sorted_indices = lexsort([key1, key2])
    # sorted_indices will contain the indices that sort first by key2,
    # and then by key1 in case of ties.
- Note:
You can use `torch.unbind` to split a stacked tensor into a list of keys.
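A lexicographic argsort can be built from repeated stable sorts. The approach is:

1. Order by the least significant key.
2. Re-sort stably by each more significant key in turn.

Because each pass is stable, ties keep the order set by earlier passes.

The sketch below follows NumPy's convention, which the example above also follows: the last key in the list is the primary key. The name `lexsort_sketch` is hypothetical, and it is not presented as evox's implementation. It relies only on `torch.sort(..., stable=True)` and `torch.gather`.

```python
from typing import List

import torch


def lexsort_sketch(keys: List[torch.Tensor], dim: int = -1) -> torch.Tensor:
    """Hypothetical sketch with NumPy's convention: keys[-1] is the primary key."""
    # Order by the least significant key first.
    order = torch.sort(keys[0], dim=dim, stable=True).indices
    for key in keys[1:]:
        # Look at the next, more significant key in the current order ...
        key_in_order = torch.gather(key, dim, order)
        # ... and re-sort stably, so ties keep the order set by earlier keys.
        perm = torch.sort(key_in_order, dim=dim, stable=True).indices
        order = torch.gather(order, dim, perm)
    return order


group = torch.tensor([2, 1, 2, 1])  # primary key (last in the list)
age = torch.tensor([30, 40, 20, 35])  # secondary key, breaks ties within a group
idx = lexsort_sketch([age, group])
print(idx.tolist())  # -> [3, 1, 2, 0]

# Keys stacked in one tensor can be split with torch.unbind, as the Note suggests.
stacked = torch.stack([age, group])  # shape (2, 4): one row per key
idx2 = lexsort_sketch(list(torch.unbind(stacked, dim=0)))
print(idx2.tolist())  # -> [3, 1, 2, 0]
```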
- evox.utils.jit_fix_operator.nanmin(input_tensor: torch.Tensor, dim: int = -1, keepdim: bool = False)
Compute the minimum of a tensor along a specified dimension, ignoring NaN values.
This function replaces `NaN` values in the input tensor with positive `infinity`, and then computes the minimum over the specified dimension, effectively ignoring `NaN` values.
- Parameters:
input_tensor – The input tensor, which may contain `NaN` values. It can be of any shape.
dim – The dimension along which to compute the minimum. Default is `-1`, which corresponds to the last dimension.
keepdim – Whether to retain the reduced dimension in the result. Default is `False`. If `True`, the output tensor will have the same number of dimensions as the input, with the size of the reduced dimension set to 1.
- Returns:
A named tuple with two fields:
`values` (`torch.Tensor`): A tensor containing the minimum values computed along the specified dimension, ignoring `NaN` values.
`indices` (`torch.Tensor`): A tensor containing the indices of the minimum values along the specified dimension.
The returned tensors `values` and `indices` have the same shape as the input tensor, except that the reduced dimension is removed (or kept with size 1 if `keepdim` is `True`).
Example
    import torch
    from evox.utils.jit_fix_operator import nanmin

    x = torch.tensor([[1.0, 2.0], [float('nan'), 4.0]])
    result = nanmin(x, dim=0)
    print(result.values)   # Output: tensor([1., 2.])
    print(result.indices)  # Output: tensor([0, 0])
Note
`NaN` values are ignored by replacing them with `infinity` before computing the minimum. If all values along a dimension are `NaN`, the result will be `infinity` for that dimension, and the returned index will be the first index along that dimension.
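The replace-then-reduce approach described above fits in a few lines. In this sketch, `masked_fill` writes `+inf` over the NaN positions, and an ordinary `torch.min` then returns the same `(values, indices)` named tuple. The name `nanmin_sketch` is hypothetical, and this is an illustration of the described technique rather than evox's source.

```python
import torch


def nanmin_sketch(x: torch.Tensor, dim: int = -1, keepdim: bool = False):
    """Hypothetical sketch: replace NaN with +inf, then take an ordinary min."""
    filled = x.masked_fill(torch.isnan(x), float("inf"))
    # torch.min with a dim returns a named tuple (values, indices).
    return torch.min(filled, dim=dim, keepdim=keepdim)


x = torch.tensor([[1.0, 2.0], [float("nan"), 4.0]])
res = nanmin_sketch(x, dim=0)
print(res.values.tolist(), res.indices.tolist())  # -> [1.0, 2.0] [0, 0]

# Row-wise with keepdim=True: the reduced dimension is kept with size 1.
row = nanmin_sketch(x, dim=1, keepdim=True)
print(row.values.shape)  # -> torch.Size([2, 1])
```

A slice that is entirely NaN comes back as `+inf`, matching the Note above.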
- evox.utils.jit_fix_operator.nanmax(input_tensor: torch.Tensor, dim: int = -1, keepdim: bool = False)
Compute the maximum of a tensor along a specified dimension, ignoring NaN values.
This function replaces `NaN` values in the input tensor with `-infinity`, and then computes the maximum over the specified dimension, effectively ignoring `NaN` values.
- Parameters:
input_tensor – The input tensor, which may contain `NaN` values. It can be of any shape.
dim – The dimension along which to compute the maximum. Default is `-1`, which corresponds to the last dimension.
keepdim – Whether to retain the reduced dimension in the result. Default is `False`. If `True`, the output tensor will have the same number of dimensions as the input, with the size of the reduced dimension set to 1.
- Returns:
A named tuple with two fields:
`values` (`torch.Tensor`): A tensor containing the maximum values computed along the specified dimension, ignoring `NaN` values.
`indices` (`torch.Tensor`): A tensor containing the indices of the maximum values along the specified dimension.
The returned tensors `values` and `indices` have the same shape as the input tensor, except that the reduced dimension is removed (or kept with size 1 if `keepdim` is `True`).
Example
    import torch
    from evox.utils.jit_fix_operator import nanmax

    x = torch.tensor([[1.0, 2.0], [float('nan'), 4.0]])
    result = nanmax(x, dim=0)
    print(result.values)   # Output: tensor([1., 4.])
    print(result.indices)  # Output: tensor([0, 1])
Note
`NaN` values are ignored by replacing them with `-infinity` before computing the maximum. If all values along a dimension are `NaN`, the result will be `-infinity` for that dimension, and the returned index will be the first index along that dimension.
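The sketch below mirrors the `nanmax` behaviour under the same replace-then-reduce approach. It fills NaN with `-inf` and takes `torch.max`. It also shows an optional post-processing step for callers who would rather report an all-NaN slice as NaN than as `-inf`. The name `nanmax_sketch` is hypothetical, and the post-processing step is an extra illustration, not something the documented function does.

```python
import torch


def nanmax_sketch(x: torch.Tensor, dim: int = -1, keepdim: bool = False):
    """Hypothetical sketch: replace NaN with -inf, then take an ordinary max."""
    filled = x.masked_fill(torch.isnan(x), float("-inf"))
    return torch.max(filled, dim=dim, keepdim=keepdim)


nan = float("nan")
x = torch.tensor([[1.0, nan, 3.0], [nan, nan, nan]])
res = nanmax_sketch(x)  # reduce over the last dimension
print(res.values.tolist())  # -> [3.0, -inf]  (the all-NaN row gives -inf)

# Optional: report all-NaN slices as NaN instead of -inf.
all_nan = torch.isnan(x).all(dim=-1)
values = res.values.masked_fill(all_nan, nan)
```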