evox.core._vmap_fix
Module Contents
Functions
- unwrap_batch_tensor: Unwraps a batched tensor into its original tensor and the batch dimensions/sizes.
- wrap_batch_tensor: Wraps an original tensor into its batched form with the given batch dimension(s).
- batched_random: Generates a batched tensor of random values.
- batched_random_like: Generates a batched tensor of random values with the same shape as the given tensor.
- align_vmap_tensor: Aligns a tensor with the batching dimensions of a current batched tensor.
- debug_print: Prints a formatted string with one positional tensor, for on-the-fly debugging inside JIT-traced functions.
Data
API
- evox.core._vmap_fix.unwrap_batch_tensor(tensor: torch.Tensor) → Tuple[torch.Tensor, Tuple[int, ...], Tuple[int, ...]]
Unwraps a batched tensor into its original tensor and the batch dimensions/sizes.
- Parameters:
tensor – The batched tensor to be unwrapped.
- Returns:
A tuple of the original tensor, the batch dimensions, and the batch sizes.
- evox.core._vmap_fix.wrap_batch_tensor(tensor: torch.Tensor, in_dims: int | Tuple[int, ...]) → torch.Tensor
Wraps an original tensor into its batched form with the given batch dimension(s).
- Parameters:
tensor – The original tensor to be wrapped.
in_dims – The batch dimension(s).
- Returns:
The batched tensor.
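The two functions above are inverses. The sketch below is a stdlib-only toy model that shows the bookkeeping involved and does not touch real PyTorch batched tensors. A batched tensor is modeled as nested wrapper layers, each hiding one batch dimension `bdim` of the tensor it wraps. The names `ToyTensor`, `ToyBatched`, `toy_wrap` and `toy_unwrap` are hypothetical. The sketch assumes that `in_dims[0]` becomes the layer closest to the base tensor and that unwrapping reports dims/sizes in that same order. This page does not state the ordering.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass
class ToyTensor:
    """Stand-in for a plain tensor; only the shape is tracked."""
    shape: Tuple[int, ...]


@dataclass
class ToyBatched:
    """Hypothetical stand-in for one vmap wrapper layer around `inner`."""
    inner: Union["ToyBatched", ToyTensor]
    bdim: int  # position of the hidden batch dimension in inner.shape

    @property
    def shape(self) -> Tuple[int, ...]:
        # Code running under vmap sees the shape without the batch dimension.
        s = self.inner.shape
        return s[: self.bdim] + s[self.bdim + 1 :]


def toy_wrap(tensor, in_dims):
    """Add one layer per entry of in_dims; in_dims[0] is closest to the base."""
    if isinstance(in_dims, int):
        in_dims = (in_dims,)
    for d in in_dims:
        tensor = ToyBatched(tensor, d)
    return tensor


def toy_unwrap(tensor):
    """Peel every layer, returning (base, batch dims, batch sizes)."""
    dims, sizes = [], []
    while isinstance(tensor, ToyBatched):
        dims.append(tensor.bdim)
        sizes.append(tensor.inner.shape[tensor.bdim])
        tensor = tensor.inner
    # Layers were peeled outermost first; reverse to match toy_wrap's order.
    return tensor, tuple(reversed(dims)), tuple(reversed(sizes))


base = ToyTensor((5, 4, 3))
batched = toy_wrap(base, (0, 1))
print(batched.shape)        # (4,)
print(toy_unwrap(batched))  # (ToyTensor(shape=(5, 4, 3)), (0, 1), (5, 3))
```

Each layer removes its batch dimension from the visible shape. That is why two layers over a `(5, 4, 3)` base leave a logical shape of `(4,)`.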
- evox.core._vmap_fix._get_batched_size(in_dim: int | Tuple[int, ...], original: torch.Tensor) → int | Tuple[int, ...] | None
- evox.core._vmap_fix._vmap_batch_sizes: contextvars.ContextVar[List[int]]
‘ContextVar(…)’
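This page records only the type of `_vmap_batch_sizes`. batched_random refers to "the current vmap batch size", so a context variable like this plausibly tracks the batch sizes of the active (possibly nested) vmap levels. That reading is an assumption. The sketch shows how a `contextvars.ContextVar` can hold such a stack and restore it correctly with tokens. The names `toy_vmap_batch_sizes` and `toy_vmap_level` are hypothetical. The sketch stores tuples rather than the documented `List[int]`, so each level installs a fresh value instead of mutating a shared list.

```python
import contextvars
from contextlib import contextmanager
from typing import Iterator

# Hypothetical stand-in: batch sizes of the active (nested) vmap levels,
# outermost level first. Holds a Tuple[int, ...].
toy_vmap_batch_sizes = contextvars.ContextVar("toy_vmap_batch_sizes", default=())


@contextmanager
def toy_vmap_level(batch_size: int) -> Iterator[None]:
    """Push one batch size for the duration of a (pretend) vmap call."""
    token = toy_vmap_batch_sizes.set(toy_vmap_batch_sizes.get() + (batch_size,))
    try:
        yield
    finally:
        # reset() restores the value seen before set(), even if the body raised.
        toy_vmap_batch_sizes.reset(token)


with toy_vmap_level(8):
    with toy_vmap_level(3):
        print(toy_vmap_batch_sizes.get())  # (8, 3)
    print(toy_vmap_batch_sizes.get())  # (8,)
print(toy_vmap_batch_sizes.get())  # ()
```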
- evox.core._vmap_fix._flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
- evox.core._vmap_fix.batched_random(rand_func: Callable, *size: Tuple[int | torch.SymInt], **kwargs) → torch.Tensor
Generate a batched tensor of random values.
Given a random function (e.g. torch.randn, torch.rand, etc.) and its size arguments, this function generates a batched tensor of random values by applying the given function to the size extended with the current vmap batch size.
- Parameters:
rand_func – A function that generates a tensor of random values.
*size – The size arguments to the given function.
**kwargs – The keyword arguments to the given function.
- Returns:
The batched tensor of random values.
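Usage:

    import torch
    from evox.core._vmap_fix import batched_random

    device = torch.device("cpu")  # any torch device; "cpu" is only an example
    rand1 = batched_random(torch.rand, 2, 3, device=device)
    rand2 = batched_random(torch.randn, 4, device=device, dtype=torch.float32)
    rand3 = batched_random(torch.randint, 5, 6, low=0, high=10, device=device, dtype=torch.float32)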
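The size extension described for batched_random can be illustrated without PyTorch. In the minimal sketch below, the stand-in generator `toy_rand` fills nested lists instead of tensors. The sketch assumes that the active batch sizes are *prepended* to the requested size and that, with no vmap active, the size is used as given. The names `toy_batch_sizes`, `toy_rand`, `toy_batched_random` and `shape_of` are hypothetical. The real function also re-wraps its result as a batched tensor, which this sketch does not model.

```python
import contextvars
import random
from typing import Any, Callable, Sequence

# Hypothetical stand-in for the active vmap batch sizes (a tuple of ints).
toy_batch_sizes = contextvars.ContextVar("toy_batch_sizes", default=())


def toy_rand(shape: Sequence[int]) -> Any:
    """Toy torch.rand: nested lists of uniform floats in [0, 1)."""
    if not shape:
        return random.random()
    return [toy_rand(shape[1:]) for _ in range(shape[0])]


def toy_batched_random(rand_func: Callable, *size: int, **kwargs) -> Any:
    # Assumption: active batch sizes are prepended to the requested size, so
    # each batch element receives its own independent slice of the draw.
    full_shape = tuple(toy_batch_sizes.get()) + size
    return rand_func(full_shape, **kwargs)


def shape_of(x: Any) -> tuple:
    """Shape of a non-empty nested list."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0]
    return tuple(shape)


print(shape_of(toy_batched_random(toy_rand, 2, 3)))  # (2, 3): no vmap active
token = toy_batch_sizes.set((4,))  # pretend we are inside a vmap of size 4
print(shape_of(toy_batched_random(toy_rand, 2, 3)))  # (4, 2, 3)
toy_batch_sizes.reset(token)
```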
- evox.core._vmap_fix.batched_random_like(rand_func: Callable, like_tensor: torch.Tensor, **kwargs) → torch.Tensor
Generate a batched tensor of random values with the same shape as the given tensor.
Given a random function (e.g. torch.randn_like, torch.rand_like, etc.) and a tensor, this function generates a batched tensor of random values by applying the given function to the tensor extended with the current vmap batch size.
- Parameters:
rand_func – A function that generates a tensor of random values.
like_tensor – The tensor to generate random values like.
**kwargs – The keyword arguments to the given function.
- Returns:
The batched tensor of random values.
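For the `_like` variant, the extra ingredient is that `**kwargs` are forwarded to the random function, as with `low`/`high` for a randint-style generator. The sketch below is a stdlib-only toy, not the library's code. It assumes that the active batch sizes are prepended to `like_tensor`'s shape before the function is applied. The names `toy_batch_sizes`, `toy_randint_like`, `toy_batched_random_like`, `zeros` and `shape_of` are hypothetical.

```python
import contextvars
import random
from typing import Any, Callable

# Hypothetical stand-in for the active vmap batch sizes (a tuple of ints).
toy_batch_sizes = contextvars.ContextVar("toy_batch_sizes", default=())


def shape_of(x: Any) -> tuple:
    """Shape of a non-empty nested list."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0]
    return tuple(shape)


def zeros(shape: tuple) -> Any:
    """Nested lists of zeros with the given shape."""
    return 0.0 if not shape else [zeros(shape[1:]) for _ in range(shape[0])]


def toy_randint_like(t: Any, low: int, high: int) -> Any:
    """Toy torch.randint_like: integers in [low, high), same shape as t."""
    if not isinstance(t, list):
        return random.randint(low, high - 1)
    return [toy_randint_like(e, low, high) for e in t]


def toy_batched_random_like(rand_func: Callable, like: Any, **kwargs) -> Any:
    # Assumption: batch sizes are prepended to like's shape; kwargs pass through.
    template = zeros(tuple(toy_batch_sizes.get()) + shape_of(like))
    return rand_func(template, **kwargs)


token = toy_batch_sizes.set((4,))  # pretend we are inside a vmap of size 4
out = toy_batched_random_like(toy_randint_like, [[0.0, 0.0, 0.0]], low=0, high=10)
toy_batch_sizes.reset(token)
print(shape_of(out))  # (4, 1, 3)
```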
- evox.core._vmap_fix._original_size
None
- evox.core._vmap_fix._original_rand
None
- evox.core._vmap_fix._original_randn
None
- evox.core._vmap_fix._original_randint
None
- evox.core._vmap_fix._original_randperm
None
- evox.core._vmap_fix._original_rand_like
None
- evox.core._vmap_fix._original_randn_like
None
- evox.core._vmap_fix._original_randint_like
None
- evox.core._vmap_fix._original_get_item
None
- evox.core._vmap_fix._original_set_item
None
- evox.core._vmap_fix._batch_fixing: contextvars.ContextVar[bool]
‘ContextVar(…)’
- evox.core._vmap_fix.align_vmap_tensor(value: Any, current_value: Any | None) → torch.Tensor
Aligns a tensor with the batching dimensions of a current batched tensor.
This function adjusts the input tensor value to match the batch dimensions of current_value, which is assumed to be a batched tensor. If value is already a batched tensor or current_value is not a batched tensor, it returns value unchanged.
- Parameters:
value – The tensor to be aligned. If not a torch.Tensor, it is returned unchanged.
current_value – The reference batched tensor. If None or not a batched tensor, value is returned unchanged.
- Returns:
The input value aligned with the batch dimensions of current_value, if applicable.
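The pass-through rules above and the alignment step can be traced on a stdlib-only toy model. A batched tensor is modeled as nested wrapper layers, each hiding one batch dimension. The sketch assumes that "aligning" means inserting each of the reference's batch sizes into `value`'s shape (as an expand would) and then re-wrapping with the same batch dims. Only shapes are tracked, not data. The names `ToyTensor`, `ToyBatched` and `toy_align` are hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Tuple, Union


@dataclass
class ToyTensor:
    """Stand-in for a plain tensor; only the shape is tracked."""
    shape: Tuple[int, ...]


@dataclass
class ToyBatched:
    """Hypothetical stand-in for one vmap wrapper layer around `inner`."""
    inner: Union["ToyBatched", ToyTensor]
    bdim: int  # position of the hidden batch dimension in inner.shape

    @property
    def shape(self) -> Tuple[int, ...]:
        s = self.inner.shape
        return s[: self.bdim] + s[self.bdim + 1 :]


def toy_align(value: Any, current_value: Any) -> Any:
    # Non-tensors, already-batched values and unbatched references pass through.
    if not isinstance(value, ToyTensor) or not isinstance(current_value, ToyBatched):
        return value
    # Record (bdim, batch size) of every layer of the reference, outermost first.
    peeled = []
    t = current_value
    while isinstance(t, ToyBatched):
        peeled.append((t.bdim, t.inner.shape[t.bdim]))
        t = t.inner
    # Insert each batch size into value's shape, outermost layer first...
    shape = value.shape
    for bdim, size in peeled:
        shape = shape[:bdim] + (size,) + shape[bdim:]
    # ...then re-wrap starting from the layer closest to the base tensor.
    out: Any = ToyTensor(shape)
    for bdim, _ in reversed(peeled):
        out = ToyBatched(out, bdim)
    return out


ref = ToyBatched(ToyBatched(ToyTensor((5, 4, 3)), 0), 1)  # logical shape (4,)
aligned = toy_align(ToyTensor((4,)), ref)
print(aligned.shape, aligned.inner.inner.shape)  # (4,) (5, 4, 3)
```

The aligned value ends up with the same layer structure as the reference. Its visible shape stays the shape that `value` had before alignment.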
- evox.core._vmap_fix.debug_print(format: str, arg: torch.Tensor) → torch.Tensor
Prints a formatted string with one positional tensor, for on-the-fly debugging inside JIT-traced functions.
When called under vmap, it unwraps the batched tensor to print the underlying values. Otherwise, it prints the string produced by format.format(arg).
- Parameters:
format – A string format.
arg – The positional tensor.
- Returns:
The unchanged tensor.
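The documented control flow can be imitated with a stdlib-only toy: show the underlying values when the argument is batched, format it directly otherwise, and always return the argument unchanged. In the toy, a batched argument is just an object holding per-sample values. Real batched tensors and JIT tracing are not modeled. The names `ToyBatched` and `toy_debug_print` are hypothetical. Because the argument is returned unchanged, the call can be chained in the data flow, as in `x = debug_print("x = {}", x)`.

```python
from typing import Any, List


class ToyBatched:
    """Hypothetical stand-in: per-sample values hidden behind one vmap level."""

    def __init__(self, values: List[Any]) -> None:
        self.values = values


def toy_debug_print(fmt: str, arg: Any) -> Any:
    # Under (pretend) vmap, show the underlying per-sample values;
    # otherwise format the argument as-is.
    shown = arg.values if isinstance(arg, ToyBatched) else arg
    print(fmt.format(shown))
    # Returning arg unchanged lets callers write `x = toy_debug_print("...", x)`.
    return arg


x = toy_debug_print("x = {}", ToyBatched([1.0, 2.0, 3.0]))  # prints: x = [1.0, 2.0, 3.0]
y = toy_debug_print("y = {}", 42)                           # prints: y = 42
```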