evox.utils.jit_fix_operator

Module Contents

Functions

switch

Element-wise switch select operator that generates a tensor from a list of tensors based on the label tensor.

clamp

Clamp the values of the input tensor a to be within the given lower (lb) and upper (ub) bounds.

clamp_float

Clamp the float values of the input tensor a to be within the given lower (lb) and upper (ub) bounds.

clamp_int

Clamp the int values of the input tensor a to be within the given lower (lb) and upper (ub) bounds.

clip

Clip the values of the input tensor a to be within the range [0, 1].

maximum

Element-wise maximum of two input tensors a and b.

minimum

Element-wise minimum of two input tensors a and b.

maximum_float

Element-wise maximum of input tensor a and float b.

minimum_float

Element-wise minimum of input tensor a and float b.

maximum_int

Element-wise maximum of input tensor a and int b.

minimum_int

Element-wise minimum of input tensor a and int b.

lexsort

Perform lexicographical sorting of multiple tensors, considering each tensor as a key.

nanmin

Compute the minimum of a tensor along a specified dimension, ignoring NaN values.

nanmax

Compute the maximum of a tensor along a specified dimension, ignoring NaN values.

randint

Generate a tensor of random integers in a specified range, like torch.randint, except that low and high may be given as tensors.

API

evox.utils.jit_fix_operator.switch(label: torch.Tensor, values: List[torch.Tensor]) → torch.Tensor

Element-wise switch select operator that generates a tensor from a list of tensors based on the label tensor.

Parameters:
  • label – A tensor containing labels used to select from the list of tensors. Must be broadcastable to the common shape of the tensors in values.

  • values – A list of tensors from which one is selected based on the label. All tensors in the list must be broadcastable to the same shape.

Returns:

A tensor where each element is selected from the list of tensors based on the corresponding element in the label tensor.
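The selection rule can be illustrated with a plain eager-mode reference. The sketch below is hypothetical (switch_reference is not part of evox): it broadcasts label and values to one shape, stacks the candidates, and gathers along the new leading dimension, assuming every label is a valid index into values. It reproduces the documented semantics only, not evox's fusion-friendly implementation.

```python
import torch
from typing import List


def switch_reference(label: torch.Tensor, values: List[torch.Tensor]) -> torch.Tensor:
    # Hypothetical eager reference: out[...] = values[label[...]][...]
    shape = torch.broadcast_shapes(label.shape, *(v.shape for v in values))
    # Stack all candidates along a new leading dimension: (K, *shape).
    stacked = torch.stack([v.expand(shape) for v in values], dim=0)
    # Labels become gather indices of shape (1, *shape).
    idx = label.long().expand(shape).unsqueeze(0)
    # For every position, pick the candidate selected by its label.
    return torch.gather(stacked, 0, idx).squeeze(0)


label = torch.tensor([0, 2, 1, 0])
a = torch.tensor([10, 11, 12, 13])
b = torch.tensor([20, 21, 22, 23])
c = torch.tensor([30, 31, 32, 33])
out = switch_reference(label, [a, b, c])
print(out.tolist())  # [10, 31, 22, 13]
```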

evox.utils.jit_fix_operator.clamp(a: torch.Tensor, lb: torch.Tensor, ub: torch.Tensor) → torch.Tensor

Clamp the values of the input tensor a to be within the given lower (lb) and upper (ub) bounds.

This function ensures that each element of the tensor a is not less than the corresponding element of lb and not greater than the corresponding element of ub.

Notice

This is a workaround for torch.clamp, which is not yet supported in JIT operator fusion.

Warning

This is NOT a precise replication of torch.clamp if a, lb or ub is a float tensor and may suffer from numerical precision losses. Please use torch.clamp instead if a precise clamp is required.

Parameters:
  • a – The input tensor to be clamped.

  • lb – The lower bound tensor. Must be broadcastable to the shape of a.

  • ub – The upper bound tensor. Must be broadcastable to the shape of a.

Returns:

A tensor where each element is clamped to be within the specified bounds.
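Because torch.clamp cannot be fused, a clamp can instead be written with element-wise arithmetic only. The sketch below assumes a relu-based formulation, a + relu(lb - a) - relu(a - ub), as one possible way to do this; it is not necessarily evox's exact source, and clamp_sketch is a hypothetical name. It also shows per-dimension bounds broadcasting over a population matrix, and assumes lb <= ub element-wise.

```python
import torch


def clamp_sketch(a: torch.Tensor, lb: torch.Tensor, ub: torch.Tensor) -> torch.Tensor:
    # Hypothetical fusion-friendly clamp (assumption, not evox's verified code).
    # relu(lb - a) is positive only where a < lb and lifts those elements to lb;
    # relu(a - ub) is positive only where a > ub and lowers those elements to ub.
    # Assumes lb <= ub element-wise.
    return a + torch.relu(lb - a) - torch.relu(a - ub)


pop = torch.tensor([[-0.5, 2.0], [0.3, -3.0], [1.5, 0.0]])  # 3 individuals, 2 dims
lb = torch.tensor([0.0, -1.0])  # per-dimension lower bounds, broadcast over rows
ub = torch.tensor([1.0, 1.0])   # per-dimension upper bounds
print(clamp_sketch(pop, lb, ub).tolist())
```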

evox.utils.jit_fix_operator.clamp_float(a: torch.Tensor, lb: float, ub: float) → torch.Tensor

Clamp the float values of the input tensor a to be within the given lower (lb) and upper (ub) bounds.

This function ensures that each element of the tensor a is not less than lb and not greater than ub.

Notice

This is a workaround for torch.clamp, which is not yet supported in JIT operator fusion.

Warning

This is NOT a precise replication of torch.clamp if a is a float tensor and may suffer from numerical precision losses. Please use torch.clamp instead if a precise clamp is required.

Parameters:
  • a – The input tensor to be clamped.

  • lb – The lower bound value. Each element of a will be clamped to be not less than lb.

  • ub – The upper bound value. Each element of a will be clamped to be not greater than ub.

Returns:

A tensor where each element is clamped to be within the specified bounds.
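The precision warning can be reproduced with a relu-based formulation of the clamp, which is assumed here for illustration and is not evox's verified source (clamp_float_sketch is hypothetical). In float32, lb - a can round away a small bound when a is large in magnitude, so the result differs from torch.clamp:

```python
import torch


def clamp_float_sketch(a: torch.Tensor, lb: float, ub: float) -> torch.Tensor:
    # Hypothetical relu-based clamp with Python float bounds broadcast as scalars.
    return a + torch.relu(lb - a) - torch.relu(a - ub)


a = torch.tensor([-1e8, 0.5, 2.0])  # float32 by default
approx = clamp_float_sketch(a, 0.1, 1.0)
exact = torch.clamp(a, 0.1, 1.0)
# For the first element, 0.1 - (-1e8) rounds to 1e8 in float32,
# so a + relu(lb - a) gives 0.0 instead of 0.1.
print(approx.tolist())
print(exact.tolist())
```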

evox.utils.jit_fix_operator.clamp_int(a: torch.Tensor, lb: int, ub: int) → torch.Tensor

Clamp the int values of the input tensor a to be within the given lower (lb) and upper (ub) bounds.

This function ensures that each element of the tensor a is not less than lb and not greater than ub.

Notice

This is a workaround for torch.clamp, which is not yet supported in JIT operator fusion.

Warning

This is NOT a precise replication of torch.clamp if a is a float tensor and may suffer from numerical precision losses. Please use torch.clamp instead if a precise clamp is required.

Parameters:
  • a – The input tensor to be clamped.

  • lb – The lower bound value. Each element of a will be clamped to be not less than lb.

  • ub – The upper bound value. Each element of a will be clamped to be not greater than ub.

Returns:

A tensor where each element is clamped to be within the specified bounds.
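Assuming a relu-based formulation such as a + relu(lb - a) - relu(a - ub) (a hypothetical sketch, not evox's source), every intermediate value stays an exact integer when a is an integer tensor, so the result matches torch.clamp barring overflow. A typical use is keeping indices in range:

```python
import torch


def clamp_int_sketch(a: torch.Tensor, lb: int, ub: int) -> torch.Tensor:
    # Hypothetical relu-based clamp with int bounds; exact on integer tensors.
    return a + torch.relu(lb - a) - torch.relu(a - ub)


idx = torch.tensor([-2, 0, 4, 9])   # e.g. candidate indices
safe = clamp_int_sketch(idx, 0, 5)  # keep indices inside [0, 5]
print(safe.tolist())  # [0, 0, 4, 5]
```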

evox.utils.jit_fix_operator.clip(a: torch.Tensor) → torch.Tensor

Clip the values of the input tensor a to be within the range [0, 1].

Notice: This function invokes clamp(a, 0, 1).

Parameters:
  • a – The input tensor to be clipped.

Returns:

A tensor where each element is clipped to be within [0, 1].
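A minimal sketch of the [0, 1] clip, assuming it reduces to a relu-based clamp with bounds 0 and 1 (clip_sketch is hypothetical and not evox's source):

```python
import torch


def clip_sketch(a: torch.Tensor) -> torch.Tensor:
    # Hypothetical: clamp into [0, 1] using only relu, add and sub.
    return a + torch.relu(0.0 - a) - torch.relu(a - 1.0)


x = torch.tensor([-0.5, 0.25, 1.5])
print(clip_sketch(x).tolist())  # [0.0, 0.25, 1.0]
```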

evox.utils.jit_fix_operator.maximum(a: torch.Tensor, b: torch.Tensor) → torch.Tensor

Element-wise maximum of two input tensors a and b.

Notice

This is a workaround for [torch.maximum](https://pytorch.org/docs/stable/generated/torch.maximum.html), which is not yet supported in JIT operator fusion.

Warning

This is NOT a precise replication of torch.maximum if a or b is a float tensor and may suffer from numerical precision losses. Please use torch.maximum instead if a precise maximum is required.

Parameters:
  • a – The first input tensor.

  • b – The second input tensor.

Returns:

The element-wise maximum of a and b.
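An element-wise maximum can be expressed without torch.maximum through the identity max(a, b) = a + relu(b - a). The sketch below is a hypothetical illustration of that identity (maximum_sketch is not evox's implementation) and shows that the usual broadcasting rules apply:

```python
import torch


def maximum_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical: add back the amount by which b exceeds a (zero elsewhere).
    return a + torch.relu(b - a)


a = torch.tensor([[1.0], [4.0]])   # shape (2, 1)
b = torch.tensor([0.0, 2.0, 5.0])  # shape (3,), result broadcasts to (2, 3)
print(maximum_sketch(a, b).tolist())  # [[1.0, 2.0, 5.0], [4.0, 4.0, 5.0]]
```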

evox.utils.jit_fix_operator.minimum(a: torch.Tensor, b: torch.Tensor) → torch.Tensor

Element-wise minimum of two input tensors a and b.

Notice

This is a workaround for [torch.minimum](https://pytorch.org/docs/stable/generated/torch.minimum.html), which is not yet supported in JIT operator fusion.

Warning

This is NOT a precise replication of torch.minimum if a or b is a float tensor and may suffer from numerical precision losses. Please use torch.minimum instead if a precise minimum is required.

Parameters:
  • a – The first input tensor.

  • b – The second input tensor.

Returns:

The element-wise minimum of a and b.
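Symmetrically, min(a, b) = a - relu(a - b). The following hypothetical sketch (minimum_sketch is not evox's implementation) checks the identity against a few values, including a tie:

```python
import torch


def minimum_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical: subtract the amount by which a exceeds b (zero elsewhere).
    return a - torch.relu(a - b)


a = torch.tensor([3.0, -1.0, 7.5])
b = torch.tensor([2.0, 0.0, 7.5])
print(minimum_sketch(a, b).tolist())  # [2.0, -1.0, 7.5]
```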

evox.utils.jit_fix_operator.maximum_float(a: torch.Tensor, b: float) → torch.Tensor

Element-wise maximum of input tensor a and float b.

Notice

This is a workaround for [torch.maximum](https://pytorch.org/docs/stable/generated/torch.maximum.html), which is not yet supported in JIT operator fusion.

Warning

This is NOT a precise replication of torch.maximum if a is a float tensor and may suffer from numerical precision losses. Please use torch.maximum instead if a precise maximum is required.

Parameters:
  • a – The first input tensor.

  • b – The scalar float value to compare against.

Returns:

The element-wise maximum of a and b.
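A common use of a scalar maximum is guarding a denominator. The sketch assumes the identity max(a, b) = a + relu(b - a) with a Python float b; maximum_float_sketch and the 1e-3 floor are illustrative choices, not taken from evox:

```python
import torch


def maximum_float_sketch(a: torch.Tensor, b: float) -> torch.Tensor:
    # Hypothetical stand-in: equals max(a, b) up to float rounding.
    return a + torch.relu(b - a)


norms = torch.tensor([0.0, 0.5, 2.0])
safe_norms = maximum_float_sketch(norms, 1e-3)      # no entry below 1e-3
unit = torch.tensor([3.0, 1.0, 4.0]) / safe_norms   # division never hits zero
print(safe_norms.tolist(), unit.tolist())
```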

evox.utils.jit_fix_operator.minimum_float(a: torch.Tensor, b: float) → torch.Tensor

Element-wise minimum of input tensor a and float b.

Notice

This is a workaround for [torch.minimum](https://pytorch.org/docs/stable/generated/torch.minimum.html), which is not yet supported in JIT operator fusion.

Warning

This is NOT a precise replication of torch.minimum if a is a float tensor and may suffer from numerical precision losses. Please use torch.minimum instead if a precise minimum is required.

Parameters:
  • a – The first input tensor.

  • b – The scalar float value to compare against.

Returns:

The element-wise minimum of a and b.
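As an illustration, a scalar minimum can cap step sizes from above. The sketch assumes the identity min(a, b) = a - relu(a - b) with a Python float b (minimum_float_sketch is hypothetical, not evox's source); the sample values are exactly representable in binary, so the printed result is exact:

```python
import torch


def minimum_float_sketch(a: torch.Tensor, b: float) -> torch.Tensor:
    # Hypothetical stand-in: equals min(a, b) up to float rounding.
    return a - torch.relu(a - b)


steps = torch.tensor([0.25, 1.75, -3.0])
capped = minimum_float_sketch(steps, 1.0)  # no step larger than 1.0
print(capped.tolist())  # [0.25, 1.0, -3.0]
```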

evox.utils.jit_fix_operator.maximum_int(a: torch.Tensor, b: int) → torch.Tensor

Element-wise maximum of input tensor a and int b.

Notice

This is a workaround for [torch.maximum](https://pytorch.org/docs/stable/generated/torch.maximum.html), which is not yet supported in JIT operator fusion.

Warning

This is NOT a precise replication of torch.maximum if a is a float tensor and may suffer from numerical precision losses. Please use torch.maximum instead if a precise maximum is required.

Parameters:
  • a – The first input tensor.

  • b – The scalar int value to compare against.

Returns:

The element-wise maximum of a and b.
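With an integer tensor and an int b, the relu identity max(a, b) = a + relu(b - a) is exact and keeps the integer dtype. The sketch below is hypothetical (maximum_int_sketch is not evox's source) and enforces a minimum count of 1:

```python
import torch


def maximum_int_sketch(a: torch.Tensor, b: int) -> torch.Tensor:
    # Hypothetical: int64 in, int64 out; no rounding on integer tensors.
    return a + torch.relu(b - a)


counts = torch.tensor([0, 3, -2])
print(maximum_int_sketch(counts, 1).tolist())  # [1, 3, 1]
```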

evox.utils.jit_fix_operator.minimum_int(a: torch.Tensor, b: int) → torch.Tensor

Element-wise minimum of input tensor a and int b.

Notice

This is a workaround for [torch.minimum](https://pytorch.org/docs/stable/generated/torch.minimum.html), which is not yet supported in JIT operator fusion.

Warning

This is NOT a precise replication of torch.minimum if a is a float tensor and may suffer from numerical precision losses. Please use torch.minimum instead if a precise minimum is required.

Parameters:
  • a – The first input tensor.

  • b – The scalar int value to compare against.

Returns:

The element-wise minimum of a and b.
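An integer minimum is handy for capping indices at the last valid position. This hypothetical sketch (minimum_int_sketch is not evox's source) uses min(a, b) = a - relu(a - b), which is exact on integer tensors:

```python
import torch


def minimum_int_sketch(a: torch.Tensor, b: int) -> torch.Tensor:
    # Hypothetical: subtract the excess over b; exact for integer tensors.
    return a - torch.relu(a - b)


n = 5                            # length of some hypothetical array
idx = torch.tensor([0, 4, 9])
print(minimum_int_sketch(idx, n - 1).tolist())  # [0, 4, 4]
```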

evox.utils.jit_fix_operator.lexsort(keys: List[torch.Tensor], dim: int = -1) → torch.Tensor

Perform lexicographical sorting of multiple tensors, considering each tensor as a key.

This function sorts the given keys lexicographically using the same convention as NumPy’s lexsort: the last key in keys is the primary sort key, ties in it are broken by the second-to-last key, and so on down to the first key. It works like NumPy’s lexsort, but operates on PyTorch tensors.

Parameters:
  • keys – A list of tensors to be sorted, where each tensor represents a sorting key. All tensors must have the same length along the specified dimension (dim).

  • dim – The dimension along which to perform the sorting. Defaults to -1 (the last dimension).

Returns:

A tensor containing indices that will sort the input tensors lexicographically. These indices indicate the order of elements in the sorted tensors.

Example

import torch
from evox.utils.jit_fix_operator import lexsort

key1 = torch.tensor([1, 3, 2])
key2 = torch.tensor([9, 7, 8])
sorted_indices = lexsort([key1, key2])
# sorted_indices sorts first by key2 (the last key),
# and then by key1 in case of ties.

Tip

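You can use torch.unbind to split a tensor into a list of tensors; for example, torch.unbind(x, dim=0) turns the rows of a 2-D tensor x into separate keys.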
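The ordering convention can be reproduced with repeated stable sorts. The sketch below is an assumed implementation strategy rather than evox's verified source (lexsort_sketch is a hypothetical name): it stable-sorts by the first key, then re-sorts stably by each following key, so the last key ends up as the primary key, consistent with the example above.

```python
import torch
from typing import List


def lexsort_sketch(keys: List[torch.Tensor], dim: int = -1) -> torch.Tensor:
    # Stable sort by the least significant key (the first one in the list).
    idx = torch.argsort(keys[0], dim=dim, stable=True)
    for key in keys[1:]:
        # Reorder the next key by the current permutation, then stable-sort it.
        # Stability preserves the order set by earlier keys among ties,
        # so each later key takes precedence over the earlier ones.
        order = torch.argsort(key.gather(dim, idx), dim=dim, stable=True)
        idx = idx.gather(dim, order)
    return idx


secondary = torch.tensor([3, 1, 2, 0])
primary = torch.tensor([1, 0, 1, 0])
print(lexsort_sketch([secondary, primary]).tolist())  # [3, 1, 2, 0]
```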

evox.utils.jit_fix_operator.nanmin(input_tensor: torch.Tensor, dim: int = -1, keepdim: bool = False)

Compute the minimum of a tensor along a specified dimension, ignoring NaN values.

This function replaces NaN values in the input tensor with infinity, and then computes the minimum over the specified dimension, effectively ignoring NaN values.

Parameters:
  • input_tensor – The input tensor, which may contain NaN values. It can be of any shape.

  • dim – The dimension along which to compute the minimum. Default is -1, which corresponds to the last dimension.

  • keepdim – Whether to retain the reduced dimension in the result. Default is False. If True, the output tensor will have the same number of dimensions as the input, with the size of the reduced dimension set to 1.

Returns:

A named tuple with two fields:

  • values (torch.Tensor): A tensor containing the minimum values computed along the specified dimension, ignoring NaN values.

  • indices (torch.Tensor): A tensor containing the indices of the minimum values along the specified dimension.

The returned values and indices tensors have the same shape as input_tensor, except that dimension dim is removed (or kept with size 1 if keepdim is True).

Example

import torch
from evox.utils.jit_fix_operator import nanmin

x = torch.tensor([[1.0, 2.0], [float('nan'), 4.0]])
result = nanmin(x, dim=0)
print(result.values)  # Output: tensor([1., 2.])
print(result.indices)  # Output: tensor([0, 0])

Note

  • NaN values are ignored by replacing them with infinity before computing the minimum.

  • If all values along a dimension are NaN, the result for that slice will be infinity, and the returned index will be that of the first element along dim, since all replaced values tie.
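The replace-then-reduce approach described above can be sketched in a few lines. nanmin_sketch is hypothetical (not evox's source); it relies on torch.where and on torch.min returning a (values, indices) named tuple when a dim is given:

```python
import math

import torch


def nanmin_sketch(x: torch.Tensor, dim: int = -1, keepdim: bool = False):
    # Replace NaNs with +inf so they can never be selected as the minimum.
    filled = torch.where(torch.isnan(x), torch.full_like(x, math.inf), x)
    # torch.min with a dim returns a (values, indices) named tuple.
    return torch.min(filled, dim=dim, keepdim=keepdim)


nan = float("nan")
x = torch.tensor([[2.0, nan, 0.5], [nan, nan, nan]])
res = nanmin_sketch(x, dim=1)
print(res.values.tolist())  # [0.5, inf]
```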

evox.utils.jit_fix_operator.nanmax(input_tensor: torch.Tensor, dim: int = -1, keepdim: bool = False)

Compute the maximum of a tensor along a specified dimension, ignoring NaN values.

This function replaces NaN values in the input tensor with -infinity, and then computes the maximum over the specified dimension, effectively ignoring NaN values.

Parameters:
  • input_tensor – The input tensor, which may contain NaN values. It can be of any shape.

  • dim – The dimension along which to compute the maximum. Default is -1, which corresponds to the last dimension.

  • keepdim – Whether to retain the reduced dimension in the result. Default is False. If True, the output tensor will have the same number of dimensions as the input, with the size of the reduced dimension set to 1.

Returns:

A named tuple with two fields:

  • values (torch.Tensor): A tensor containing the maximum values computed along the specified dimension, ignoring NaN values.

  • indices (torch.Tensor): A tensor containing the indices of the maximum values along the specified dimension.

The returned values and indices tensors have the same shape as input_tensor, except that dimension dim is removed (or kept with size 1 if keepdim is True).

Example

import torch
from evox.utils.jit_fix_operator import nanmax

x = torch.tensor([[1.0, 2.0], [float('nan'), 4.0]])
result = nanmax(x, dim=0)
print(result.values)  # Output: tensor([1., 4.])
print(result.indices)  # Output: tensor([0, 1])

Note

  • NaN values are ignored by replacing them with -infinity before computing the maximum.

  • If all values along a dimension are NaN, the result for that slice will be -infinity, and the returned index will be that of the first element along dim, since all replaced values tie.
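A matching hypothetical sketch for the maximum (nanmax_sketch is not evox's source) fills NaNs with -inf and reduces with torch.max; it also shows the effect of keepdim on the output shape:

```python
import torch


def nanmax_sketch(x: torch.Tensor, dim: int = -1, keepdim: bool = False):
    # Replace NaNs with -inf so they can never be selected as the maximum.
    filled = torch.where(torch.isnan(x), torch.full_like(x, float("-inf")), x)
    return torch.max(filled, dim=dim, keepdim=keepdim)


nan = float("nan")
x = torch.tensor([[1.0, nan, 3.0], [nan, 5.0, 2.0]])
res = nanmax_sketch(x, dim=1, keepdim=True)
print(res.values.shape)  # torch.Size([2, 1])
print(res.values.flatten().tolist(), res.indices.flatten().tolist())  # [3.0, 5.0] [2, 1]
```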

evox.utils.jit_fix_operator.randint(low: torch.Tensor | int | torch.SymInt, high: torch.Tensor | int | torch.SymInt, size: Sequence[int | torch.SymInt] | torch.Size, dtype: torch.dtype | None = None, device: torch.device | None = None, generator: torch.Generator | None = None)

Generate a tensor of random integers in a specified range, like torch.randint, except that low and high may be given as tensors.

This function first generates a uniform random tensor of floats in the range [0, 1), and then maps it to the integer range [low, high) using the given low and high values.

Parameters:
  • low – The lower bound (inclusive), given as an int or a scalar tensor.

  • high – The upper bound (exclusive), given as an int or a scalar tensor.

  • size – The desired size of the output tensor.

  • dtype – The desired data type of the output tensor. Defaults to None, meaning the data type of low or high is used.

  • device – The desired device of the output tensor. Defaults to None, meaning the device of low or high is used.

  • generator – The random number generator to use. Default is None.

Returns:

A random tensor of integers within the specified range of given size.

Example

import torch
from evox.utils.jit_fix_operator import randint

high = torch.tensor(8)
randint(0, high, (2, 2))

Note

When this function is used under torch.compile and low, high or size may be dynamic integers (e.g. the size of another tensor), the enclosing function MUST be compiled with torch.compile(..., dynamic=False).
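The float-to-integer mapping described above can be sketched as follows. randint_sketch is hypothetical and simplified: it accepts int or 0-dim tensor bounds, omits device handling and the torch.compile considerations of the real function, and returns int64 unless another dtype is requested.

```python
import torch
from typing import Optional, Sequence, Union

Bound = Union[int, torch.Tensor]


def randint_sketch(
    low: Bound,
    high: Bound,
    size: Sequence[int],
    dtype: torch.dtype = torch.int64,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    # Uniform floats in [0, 1).
    u = torch.rand(size, generator=generator)
    # Scale to [0, high - low), floor to integers, then shift by low.
    # In float32, very large ranges could round u * (high - low) up to high - low;
    # this sketch does not guard against that.
    return (low + torch.floor(u * (high - low))).to(dtype)


gen = torch.Generator().manual_seed(0)
sample = randint_sketch(0, torch.tensor(8), (2, 2), generator=gen)
print(sample)
```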