evox.utils.jit_fix_operator#

Module Contents#

Functions#

  • switch – Element-wise switch select operator that generates a tensor from a list of tensors based on the label tensor.

  • clamp – Clamp the values of the input tensor a to be within the given lower (lb) and upper (ub) bounds.

  • clamp_float – Clamp the float values of the input tensor a to be within the given lower (lb) and upper (ub) bounds.

  • clamp_int – Clamp the int values of the input tensor a to be within the given lower (lb) and upper (ub) bounds.

  • clip – Clip the values of the input tensor a to be within the range [0, 1].

  • maximum – Element-wise maximum of two input tensors a and b.

  • minimum – Element-wise minimum of two input tensors a and b.

  • maximum_float – Element-wise maximum of input tensor a and float b.

  • minimum_float – Element-wise minimum of input tensor a and float b.

  • maximum_int – Element-wise maximum of input tensor a and int b.

  • minimum_int – Element-wise minimum of input tensor a and int b.

  • lexsort – Perform lexicographical sorting of multiple tensors, considering each tensor as a key.

  • nanmin – Compute the minimum of a tensor along a specified dimension, ignoring NaN values.

  • nanmax – Compute the maximum of a tensor along a specified dimension, ignoring NaN values.

API#

evox.utils.jit_fix_operator.switch(label: torch.Tensor, values: List[torch.Tensor]) → torch.Tensor[source]#

Element-wise switch select operator that generates a tensor from a list of tensors based on the label tensor.

Parameters:
  • label – A tensor containing labels used to select from the list of tensors. Must be broadcastable to the shape of the tensors in values.

  • values – A list of tensors from which one is selected based on the label. All tensors in the list must be broadcastable to the same shape.

Returns:

A tensor where each element is selected from the list of tensors based on the corresponding element in the label tensor.
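
Example

A minimal illustrative sketch; the input tensors and the expected result in the comment are assumptions based on the documented element-wise selection semantics, not taken from the library's tests.

import torch
from evox.utils.jit_fix_operator import switch

label = torch.tensor([0, 1, 0])
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([10.0, 20.0, 30.0])
out = switch(label, [x, y])  # expected: element i is values[label[i]][i], i.e. [1.0, 20.0, 3.0]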

evox.utils.jit_fix_operator.clamp(a: torch.Tensor, lb: torch.Tensor, ub: torch.Tensor) → torch.Tensor[source]#

Clamp the values of the input tensor a to be within the given lower (lb) and upper (ub) bounds.

This function ensures that each element of the tensor a is not less than the corresponding element of lb and not greater than the corresponding element of ub.

Notice

  1. This is a fix function for torch.clamp since it is not supported in JIT operator fusion.

  2. This is NOT a precise replication of torch.clamp: if a, lb, or ub is a float tensor, the result may suffer from numerical precision losses. Please use torch.clamp instead if a precise clamp is required.

Parameters:
  • a – The input tensor to be clamped.

  • lb – The lower bound tensor. Must be broadcastable to the shape of a.

  • ub – The upper bound tensor. Must be broadcastable to the shape of a.

Returns:

A tensor where each element is clamped to be within the specified bounds.
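
Example

A minimal illustrative sketch; the tensors are assumptions, and the expected result holds up to the precision caveat in the notice above.

import torch
from evox.utils.jit_fix_operator import clamp

a = torch.tensor([-1.0, 0.5, 2.0])
lb = torch.tensor([0.0, 0.0, 0.0])
ub = torch.tensor([1.0, 1.0, 1.0])
out = clamp(a, lb, ub)  # expected (up to precision): [0.0, 0.5, 1.0]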

evox.utils.jit_fix_operator.clamp_float(a: torch.Tensor, lb: float, ub: float) → torch.Tensor[source]#

Clamp the float values of the input tensor a to be within the given lower (lb) and upper (ub) bounds.

This function ensures that each element of the tensor a is not less than lb and not greater than ub.

Notice

  1. This is a fix function for torch.clamp since it is not supported in JIT operator fusion.

  2. This is NOT a precise replication of torch.clamp: if a is a float tensor, the result may suffer from numerical precision losses. Please use torch.clamp instead if a precise clamp is required.

Parameters:
  • a – The input tensor to be clamped.

  • lb – The lower bound value. Each element of a will be clamped to be not less than lb.

  • ub – The upper bound value. Each element of a will be clamped to be not greater than ub.

Returns:

A tensor where each element is clamped to be within the specified bounds.
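
Example

An illustrative sketch with scalar float bounds; the values are assumptions and the result is shown up to floating-point precision.

import torch
from evox.utils.jit_fix_operator import clamp_float

a = torch.tensor([-0.5, 0.3, 1.7])
out = clamp_float(a, 0.0, 1.0)  # expected (up to precision): [0.0, 0.3, 1.0]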

evox.utils.jit_fix_operator.clamp_int(a: torch.Tensor, lb: int, ub: int) → torch.Tensor[source]#

Clamp the int values of the input tensor a to be within the given lower (lb) and upper (ub) bounds.

This function ensures that each element of the tensor a is not less than lb and not greater than ub.

Notice

  1. This is a fix function for torch.clamp since it is not supported in JIT operator fusion.

  2. This is NOT a precise replication of torch.clamp: if a is an int tensor, the result may suffer from numerical precision losses. Please use torch.clamp instead if a precise clamp is required.

Parameters:
  • a – The input tensor to be clamped.

  • lb – The lower bound value. Each element of a will be clamped to be not less than lb.

  • ub – The upper bound value. Each element of a will be clamped to be not greater than ub.

Returns:

A tensor where each element is clamped to be within the specified bounds.
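
Example

An illustrative sketch with an integer tensor and scalar int bounds; the values are assumptions based on the documented behavior.

import torch
from evox.utils.jit_fix_operator import clamp_int

a = torch.tensor([-3, 5, 12])
out = clamp_int(a, 0, 10)  # expected: [0, 5, 10]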

evox.utils.jit_fix_operator.clip(a: torch.Tensor) → torch.Tensor[source]#

Clip the values of the input tensor a to be within the range [0, 1].

Notice: This function invokes clamp(a, 0, 1).

Parameters:

a – The input tensor to be clipped.

Returns:

A tensor where each element is clipped to be within [0, 1].
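
Example

An illustrative sketch; the input tensor is an assumption, and the result reflects clamping into [0, 1] up to floating-point precision.

import torch
from evox.utils.jit_fix_operator import clip

a = torch.tensor([-0.5, 0.25, 1.5])
out = clip(a)  # expected (up to precision): [0.0, 0.25, 1.0]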

evox.utils.jit_fix_operator.maximum(a: torch.Tensor, b: torch.Tensor) → torch.Tensor[source]#

Element-wise maximum of two input tensors a and b.

Notice: This is a fix function for [torch.maximum](https://pytorch.org/docs/stable/generated/torch.maximum.html) since it is not supported in JIT operator fusion.

Parameters:
  • a – The first input tensor.

  • b – The second input tensor.

Returns:

The element-wise maximum of a and b.
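
Example

An illustrative sketch with two tensors of the same shape; the values are assumptions.

import torch
from evox.utils.jit_fix_operator import maximum

a = torch.tensor([1.0, 5.0, 3.0])
b = torch.tensor([4.0, 2.0, 3.0])
out = maximum(a, b)  # expected: [4.0, 5.0, 3.0]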

evox.utils.jit_fix_operator.minimum(a: torch.Tensor, b: torch.Tensor) → torch.Tensor[source]#

Element-wise minimum of two input tensors a and b.

Notice: This is a fix function for [torch.minimum](https://pytorch.org/docs/stable/generated/torch.minimum.html) since it is not supported in JIT operator fusion.

Parameters:
  • a – The first input tensor.

  • b – The second input tensor.

Returns:

The element-wise minimum of a and b.
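
Example

An illustrative sketch mirroring the maximum example above; the values are assumptions.

import torch
from evox.utils.jit_fix_operator import minimum

a = torch.tensor([1.0, 5.0, 3.0])
b = torch.tensor([4.0, 2.0, 3.0])
out = minimum(a, b)  # expected: [1.0, 2.0, 3.0]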

evox.utils.jit_fix_operator.maximum_float(a: torch.Tensor, b: float) → torch.Tensor[source]#

Element-wise maximum of input tensor a and float b.

Notice: This is a fix function for [torch.maximum](https://pytorch.org/docs/stable/generated/torch.maximum.html) since it is not supported in JIT operator fusion.

Parameters:
  • a – The first input tensor.

  • b – The second input float, which is a scalar value.

Returns:

The element-wise maximum of a and b.
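
Example

An illustrative sketch; taking the element-wise maximum with a float scalar acts as a lower bound. The values are assumptions.

import torch
from evox.utils.jit_fix_operator import maximum_float

a = torch.tensor([-1.0, 0.5, 2.0])
out = maximum_float(a, 0.0)  # expected: [0.0, 0.5, 2.0]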

evox.utils.jit_fix_operator.minimum_float(a: torch.Tensor, b: float) → torch.Tensor[source]#

Element-wise minimum of input tensor a and float b.

Notice: This is a fix function for torch.minimum since it is not supported in JIT operator fusion.

Parameters:
  • a – The first input tensor.

  • b – The second input float, which is a scalar value.

Returns:

The element-wise minimum of a and b.
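
Example

An illustrative sketch; taking the element-wise minimum with a float scalar acts as an upper bound. The values are assumptions.

import torch
from evox.utils.jit_fix_operator import minimum_float

a = torch.tensor([-1.0, 0.5, 2.0])
out = minimum_float(a, 1.0)  # expected: [-1.0, 0.5, 1.0]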

evox.utils.jit_fix_operator.maximum_int(a: torch.Tensor, b: int) → torch.Tensor[source]#

Element-wise maximum of input tensor a and int b.

Notice: This is a fix function for [torch.maximum](https://pytorch.org/docs/stable/generated/torch.maximum.html) since it is not supported in JIT operator fusion.

Parameters:
  • a – The first input tensor.

  • b – The second input int, which is a scalar value.

Returns:

The element-wise maximum of a and b.
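
Example

An illustrative sketch with an integer tensor and an int scalar; the values are assumptions.

import torch
from evox.utils.jit_fix_operator import maximum_int

a = torch.tensor([-2, 3, 7])
out = maximum_int(a, 0)  # expected: [0, 3, 7]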

evox.utils.jit_fix_operator.minimum_int(a: torch.Tensor, b: int) → torch.Tensor[source]#

Element-wise minimum of input tensor a and int b.

Notice: This is a fix function for torch.minimum since it is not supported in JIT operator fusion.

Parameters:
  • a – The first input tensor.

  • b – The second input int, which is a scalar value.

Returns:

The element-wise minimum of a and b.
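
Example

An illustrative sketch with an integer tensor and an int scalar; the values are assumptions.

import torch
from evox.utils.jit_fix_operator import minimum_int

a = torch.tensor([-2, 3, 7])
out = minimum_int(a, 5)  # expected: [-2, 3, 5]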

evox.utils.jit_fix_operator.lexsort(keys: List[torch.Tensor], dim: int = -1) → torch.Tensor[source]#

Perform lexicographical sorting of multiple tensors, considering each tensor as a key.

This function sorts the given tensors lexicographically. It works similarly to NumPy's lexsort, but is designed for PyTorch tensors: the last key in the list is the primary sort key, and ties are broken by the preceding keys in turn.

Parameters:
  • keys – A list of tensors to be sorted, where each tensor represents a sorting key. All tensors must have the same length along the specified dimension (dim).

  • dim – The dimension along which to perform the sorting. Defaults to -1 (the last dimension).

Returns:

A tensor containing indices that will sort the input tensors lexicographically. These indices indicate the order of elements in the sorted tensors.

Example

import torch
from evox.utils.jit_fix_operator import lexsort

key1 = torch.tensor([1, 3, 2])
key2 = torch.tensor([9, 7, 8])
sorted_indices = lexsort([key1, key2])
# sorted_indices will contain the indices that sort first by key2,
# and then by key1 in case of ties.
Note:

You can use torch.unbind to split a single tensor into a list of key tensors.

evox.utils.jit_fix_operator.nanmin(input_tensor: torch.Tensor, dim: int = -1, keepdim: bool = False)[source]#

Compute the minimum of a tensor along a specified dimension, ignoring NaN values.

This function replaces NaN values in the input tensor with positive infinity, and then computes the minimum over the specified dimension, effectively ignoring NaN values.

Parameters:
  • input_tensor – The input tensor, which may contain NaN values. It can be of any shape.

  • dim – The dimension along which to compute the minimum. Default is -1, which corresponds to the last dimension.

  • keepdim – Whether to retain the reduced dimension in the result. Default is False. If True, the output tensor will have the same number of dimensions as the input, with the size of the reduced dimension set to 1.

Returns:

A named tuple with two fields:

  • values (torch.Tensor): A tensor containing the minimum values computed along the specified dimension, ignoring NaN values.

  • indices (torch.Tensor): A tensor containing the indices of the minimum values along the specified dimension.

The returned tensors values and indices will have the same shape as the input tensor, except for the dimension(s) over which the operation was performed.

Example

import torch
from evox.utils.jit_fix_operator import nanmin

x = torch.tensor([[1.0, 2.0], [float('nan'), 4.0]])
result = nanmin(x, dim=0)
print(result.values)   # Output: tensor([1.0, 2.0])
print(result.indices)  # Output: tensor([0, 0])

Note

  • NaN values are ignored by replacing them with infinity before computing the minimum.

  • If all values along a dimension are NaN, the result will be infinity for that dimension, and the index will be returned as the first valid index.

evox.utils.jit_fix_operator.nanmax(input_tensor: torch.Tensor, dim: int = -1, keepdim: bool = False)[source]#

Compute the maximum of a tensor along a specified dimension, ignoring NaN values.

This function replaces NaN values in the input tensor with -infinity, and then computes the maximum over the specified dimension, effectively ignoring NaN values.

Parameters:
  • input_tensor – The input tensor, which may contain NaN values. It can be of any shape.

  • dim – The dimension along which to compute the maximum. Default is -1, which corresponds to the last dimension.

  • keepdim – Whether to retain the reduced dimension in the result. Default is False. If True, the output tensor will have the same number of dimensions as the input, with the size of the reduced dimension set to 1.

Returns:

A named tuple with two fields:

  • values (torch.Tensor): A tensor containing the maximum values computed along the specified dimension, ignoring NaN values.

  • indices (torch.Tensor): A tensor containing the indices of the maximum values along the specified dimension.

The returned tensors values and indices will have the same shape as the input tensor, except for the dimension(s) over which the operation was performed.

Example

import torch
from evox.utils.jit_fix_operator import nanmax

x = torch.tensor([[1.0, 2.0], [float('nan'), 4.0]])
result = nanmax(x, dim=0)
print(result.values)   # Output: tensor([1.0, 4.0])
print(result.indices)  # Output: tensor([0, 1])

Note

  • NaN values are ignored by replacing them with -infinity before computing the maximum.

  • If all values along a dimension are NaN, the result will be -infinity for that dimension, and the index will be returned as the first valid index.