evox.core._vmap_fix

Module Contents

Functions

_set_func_id

unwrap_batch_tensor

Unwraps a batched tensor into its original tensor and the batch dimensions/sizes.

wrap_batch_tensor

Wraps an original tensor into its batched form with the given batch dimensions.

_get_batched_size

get_vmap_batch_sizes

vmap_increment_nesting

_flat_vmap

batched_random

Generate a batched tensor of random values.

batched_random_like

Generate a batched tensor of random values with the same shape as the given tensor.

_batch_size

_batch_rand

_batch_randn

_batch_randint

_batch_randperm

_batch_rand_like

_batch_randn_like

_batch_randint_like

_batch_getitem

_batch_setitem

use_batch_fixing

align_vmap_tensor

Aligns a tensor with the batching dimensions of a current batched tensor.

_debug_print

debug_print

Prints a formatted string with one positional tensor on the fly, for debugging inside JIT-traced functions.

Data

API

evox.core._vmap_fix._set_func_id(new_func, old_func)
evox.core._vmap_fix.unwrap_batch_tensor(tensor: torch.Tensor) → Tuple[torch.Tensor, Tuple[int, ...], Tuple[int, ...]]

Unwraps a batched tensor into its original tensor and the batch dimensions/sizes.

Parameters:

tensor – The batched tensor to be unwrapped.

Returns:

A tuple of the original tensor, the batch dimensions, and the batch sizes.

evox.core._vmap_fix.wrap_batch_tensor(tensor: torch.Tensor, in_dims: int | Tuple[int, ...]) → torch.Tensor

Wraps an original tensor into its batched form with the given batch dimensions.

Parameters:
  • tensor – The original tensor to be wrapped.

  • in_dims – The batch dimension(s).

Returns:

The batched tensor.
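The wrap/unwrap pair can be pictured with a toy model. This is a sketch assuming functorch-style nesting, in which each vmap level adds one wrapper that hides one dimension of the value it wraps. ToyTensor, ToyBatched, toy_wrap and toy_unwrap are hypothetical names, and reporting the dimensions outermost level first is an assumption of the toy, not something confirmed for evox.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass
class ToyTensor:
    # Hypothetical plain tensor: only its physical shape matters here.
    shape: Tuple[int, ...]


@dataclass
class ToyBatched:
    # Hypothetical batched wrapper (one per vmap level): it hides
    # dimension `bdim` of the wrapped value from the user.
    value: Union["ToyBatched", ToyTensor]
    bdim: int

    @property
    def shape(self) -> Tuple[int, ...]:
        dims = list(self.value.shape)
        del dims[self.bdim]
        return tuple(dims)


def toy_wrap(tensor, in_dims):
    # Apply one wrapper per batch dimension, outermost vmap level first.
    if isinstance(in_dims, int):
        in_dims = (in_dims,)
    for d in in_dims:
        tensor = ToyBatched(tensor, d)
    return tensor


def toy_unwrap(tensor):
    # Peel wrappers from the innermost vmap level outwards, recording each
    # hidden dimension and its size in the wrapped value.
    dims, sizes = [], []
    while isinstance(tensor, ToyBatched):
        dims.append(tensor.bdim)
        sizes.append(tensor.value.shape[tensor.bdim])
        tensor = tensor.value
    # Report outermost level first so the result can be fed back to toy_wrap.
    return tensor, tuple(reversed(dims)), tuple(reversed(sizes))
```

For example, wrapping a (3, 4, 5) tensor with in_dims (0, 1) yields a value whose visible shape is (4,); unwrapping it returns the original tensor together with dims (0, 1) and sizes (3, 5).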

evox.core._vmap_fix._get_batched_size(in_dim: int | Tuple[int, ...], original: torch.Tensor) → int | Tuple[int, ...] | None
evox.core._vmap_fix._vmap_batch_sizes: contextvars.ContextVar[List[int]]

‘ContextVar(…)’

evox.core._vmap_fix.get_vmap_batch_sizes()
evox.core._vmap_fix.vmap_increment_nesting(batch_size, randomness)
evox.core._vmap_fix._flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
evox.core._vmap_fix.batched_random(rand_func: Callable, *size: Tuple[int | torch.SymInt], **kwargs) → torch.Tensor

Generate a batched tensor of random values.

Given a random function (e.g. torch.rand or torch.randn) and its size arguments, this function generates a batched tensor of random values by calling the given function with the requested size extended by the current vmap batch size(s).

Parameters:
  • rand_func – A function that generates a tensor of random values.

  • *size – The size arguments to the given function.

  • **kwargs – The keyword arguments to the given function.

Returns:

The batched tensor of random values.

Usage:

import torch
from evox.core._vmap_fix import batched_random
device = "cpu"  # any torch device string works here
rand1 = batched_random(torch.rand, 2, 3, device=device)
rand2 = batched_random(torch.randn, 4, device=device, dtype=torch.float32)
rand3 = batched_random(torch.randint, 5, 6, low=0, high=10, device=device, dtype=torch.float32)

evox.core._vmap_fix.batched_random_like(rand_func: Callable, like_tensor: torch.Tensor, **kwargs) → torch.Tensor

Generate a batched tensor of random values with the same shape as the given tensor.

Given a random function (e.g. torch.rand_like or torch.randn_like) and a tensor, this function generates a batched tensor of random values by applying the given function to the tensor extended with the current vmap batch size(s).

Parameters:
  • rand_func – A function that generates a tensor of random values.

  • like_tensor – The tensor to generate random values like.

  • **kwargs – The keyword arguments to the given function.

Returns:

The batched tensor of random values.
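Why the batch sizes matter for the "like" variant can be shown by contrasting two toy generators on flat lists. This is an illustrative sketch, not evox code: toy_batched_random_like and toy_naive_random_like are hypothetical, shapes are plain tuples, and prepending the batch sizes is an assumption about ordering. A draw that only sees the logical shape of the "like" tensor ends up reused for every batch element, while extending the shape by the batch sizes gives each element its own values.

```python
import random
from math import prod
from typing import Callable, List, Sequence, Tuple


def toy_batched_random_like(
    sample: Callable[[], float], like_shape: Sequence[int], batch_sizes: Sequence[int]
) -> Tuple[Tuple[int, ...], List[float]]:
    # Only the per-sample (logical) shape of the "like" tensor is used; the
    # batch sizes are prepended, so in the row-major flat list each
    # consecutive chunk of prod(like_shape) values belongs to one batch element.
    full = tuple(batch_sizes) + tuple(like_shape)
    return full, [sample() for _ in range(prod(full))]


def toy_naive_random_like(
    sample: Callable[[], float], like_shape: Sequence[int], batch_sizes: Sequence[int]
) -> List[float]:
    # Ignores the batch: one draw for the logical shape, copied to every element.
    one = [sample() for _ in range(prod(like_shape))]
    return one * prod(batch_sizes)
```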

evox.core._vmap_fix._original_size

None

evox.core._vmap_fix._original_rand

None

evox.core._vmap_fix._original_randn

None

evox.core._vmap_fix._original_randint

None

evox.core._vmap_fix._original_randperm

None

evox.core._vmap_fix._original_rand_like

None

evox.core._vmap_fix._original_randn_like

None

evox.core._vmap_fix._original_randint_like

None

evox.core._vmap_fix._original_get_item

None

evox.core._vmap_fix._original_set_item

None

evox.core._vmap_fix._batch_size(tensor: torch.Tensor, dim: int | None = None)
evox.core._vmap_fix._batch_rand(*size, **kwargs)
evox.core._vmap_fix._batch_randn(*size, **kwargs)
evox.core._vmap_fix._batch_randint(low=None, high=None, size=None, **kwargs)
evox.core._vmap_fix._batch_randperm(n, **kwargs)
evox.core._vmap_fix._batch_rand_like(like_tensor, **kwargs)
evox.core._vmap_fix._batch_randn_like(like_tensor, **kwargs)
evox.core._vmap_fix._batch_randint_like(like_tensor, **kwargs)
evox.core._vmap_fix._batch_getitem(tensor: torch.Tensor, indices, dim=0)
evox.core._vmap_fix._batch_setitem(tensor: torch.Tensor, indices, values, dim=0)
evox.core._vmap_fix._batch_fixing: contextvars.ContextVar[bool]

‘ContextVar(…)’

evox.core._vmap_fix.use_batch_fixing(new_batch_fixing: bool = True)
evox.core._vmap_fix.align_vmap_tensor(value: Any, current_value: Any | None) → torch.Tensor

Aligns a tensor with the batching dimensions of a current batched tensor.

This function adjusts the input tensor value to match the batch dimensions of current_value, which is assumed to be a batched tensor. If value is already a batched tensor or current_value is not a batched tensor, it returns value unchanged.

Parameters:
  • value – The tensor to be aligned. If not a torch.Tensor, it is returned unchanged.

  • current_value – The reference batched tensor. If None or not a batched tensor, value is returned unchanged.

Returns:

The input value aligned with the batch dimensions of current_value, if applicable.
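The alignment rule above can be illustrated with a single-level toy. This is a sketch rather than evox's implementation: ToyBatch and toy_align are hypothetical names, a batched value is modeled as a list with one entry per batch element, and the torch.Tensor type check is omitted. An unbatched value is broadcast to every batch element of the reference; in all other cases it is returned unchanged.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class ToyBatch:
    # Hypothetical single-level stand-in for a vmap batched tensor:
    # one entry per batch element.
    items: List[Any]


def toy_align(value: Any, current: Optional[Any]) -> Any:
    if isinstance(value, ToyBatch):
        return value  # already batched: leave it alone
    if not isinstance(current, ToyBatch):
        return value  # no batched reference to align with
    # Broadcast the unbatched value so it matches the reference's batch.
    return ToyBatch([value] * len(current.items))
```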

evox.core._vmap_fix._debug_print(format: str, arg: torch.Tensor) → torch.Tensor
evox.core._vmap_fix.debug_print(format: str, arg: torch.Tensor) → torch.Tensor

Prints a formatted string with one positional tensor on the fly, for debugging inside JIT-traced functions.

When called under vmap, it unwraps the batched tensor to print the underlying values. Otherwise, it prints format.format(arg).

Parameters:
  • format – A string format.

  • arg – The positional tensor.

Returns:

The unchanged tensor.
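The contract described above (print, then hand back the argument unchanged so it can stay in the data flow of the traced function) can be sketched in plain Python. ToyBatch and toy_debug_print are hypothetical names; "unwrapping" here simply means printing the per-element list instead of the wrapper object, which only approximates what happens to a real batched tensor.

```python
from typing import Any, List


class ToyBatch:
    # Hypothetical stand-in for a batched tensor: one entry per batch element.
    def __init__(self, items: List[Any]) -> None:
        self.items = list(items)


def toy_debug_print(fmt: str, arg: Any) -> Any:
    # Show the underlying per-element values when the argument is batched.
    shown = arg.items if isinstance(arg, ToyBatch) else arg
    print(fmt.format(shown))
    # Return the argument itself so callers can write `x = debug_print(..., x)`.
    return arg
```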