evox.utils.op_register#

Module Contents#


Functions#

_default_vmap_wrap_inputs

_register_vmap_level

register_vmap_op

Register a function as a custom operator with (optional) vectorized-map (vmap) support. This function is a simple wrapper around torch.library.custom_op; see PyTorch Custom Op and torch.library.custom_op for more information.

Data#

T

API#

evox.utils.op_register.T#

TypeVar(...)

class evox.utils.op_register.FnCallable[source]#

Bases: typing.Protocol[evox.utils.op_register.T]

__call__(*args: evox.utils.op_register.T) → evox.utils.op_register.T | Tuple[evox.utils.op_register.T, ...][source]#
class evox.utils.op_register.VmapFnCallable[source]#

Bases: typing.Protocol[evox.utils.op_register.T]

__call__(info: torch._functorch.autograd_function.VmapInfo, in_dims: Tuple[int | None, ...], *args: evox.utils.op_register.T) → evox.utils.op_register.T | Tuple[evox.utils.op_register.T, ...][source]#
class evox.utils.op_register.VmapWrapInputsCallable[source]#

Bases: typing.Protocol[evox.utils.op_register.T]

__call__(info: torch._functorch.autograd_function.VmapInfo, in_dims: Tuple[int | None, ...], *args: evox.utils.op_register.T) → Tuple[evox.utils.op_register.T, ...][source]#
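
A minimal sketch of a callable satisfying this protocol, assuming plain tensor arguments (no pytree handling); it mimics the documented behavior of _default_vmap_wrap_inputs by moving each batched input's vmap dimension to dimension 0 and prepending a broadcast dimension for unbatched inputs. The helper name _move_batch_to_front and the use of info.batch_size are illustrative assumptions, not part of the module.

import torch
from typing import Optional, Tuple

def _move_batch_to_front(info, in_dims: Tuple[Optional[int], ...], *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    wrapped = []
    for arg, dim in zip(args, in_dims):
        if dim is None:
            # Unbatched input: prepend a broadcast batch dimension (size assumed from info.batch_size).
            wrapped.append(arg.unsqueeze(0).expand(info.batch_size, *arg.size()))
        else:
            # Batched input: move its vmap dimension to the front.
            wrapped.append(arg.movedim(dim, 0))
    return tuple(wrapped)
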
evox.utils.op_register._default_vmap_wrap_inputs(info: torch._functorch.autograd_function.VmapInfo, in_dims: Tuple[int | None, ...], *args)[source]#
evox.utils.op_register._register_vmap_level(name: str, vmap_fn, fake_vmap_fn, vmap_wrap_inputs, registered, vmap_out_dims, kwargs)[source]#
evox.utils.op_register.register_vmap_op(fn: evox.utils.op_register.FnCallable | None = None, /, *, fake_fn: evox.utils.op_register.FnCallable | None = None, vmap_fn: evox.utils.op_register.VmapFnCallable | None = None, fake_vmap_fn: evox.utils.op_register.VmapFnCallable | None = None, vmap_wrap_inputs: evox.utils.op_register.VmapWrapInputsCallable | None = None, vmap_out_dims: int | None | Tuple[int | None, ...] = 0, max_vmap_level: int | None = None, name: str = None, mutates_args: str | Sequence[str] = (), device_types: str | Sequence[str] | None = None, schema: str | None = None)[source]#

Register a function as a custom operator with (optional) vectorized-map (vmap) support. This function is a simple wrapper around torch.library.custom_op; see PyTorch Custom Op and torch.library.custom_op for more information.

Parameters:
  • fn -- The operator function to register.

  • fake_fn -- The fake (abstract evaluation) function to register to fn.

  • vmap_fn -- The vmap function to register to fn. Default None means no vmap support.

  • fake_vmap_fn -- The fake (abstract evaluation) vmap function to register to vmap_fn. Ignored if vmap_fn is None; cannot be None otherwise.

  • vmap_wrap_inputs -- The function that prepares the inputs for vmap_fn. Ignored if vmap_fn is None. Default None will be replaced by _default_vmap_wrap_inputs, which moves every input's vmap dimension to the first dimension (including pytree leaves) and prepends an additional broadcast dimension when no vmap dimension is present.

  • vmap_out_dims -- The outputs' vmap dimensions of vmap_fn. Ignored if vmap_fn is None.

  • max_vmap_level -- The maximum vmap level to support. Default None means no vmap level when vmap_fn is None, or 1 when vmap_fn is given.

  • name -- The name of the operator. Default None will be replaced by "evox::_custom_op_" + fn.__name__.

  • mutates_args -- The names of args that the function mutates. This MUST be accurate; otherwise, the behavior is undefined. See torch.library.custom_op for more information.

  • device_types -- The device types that the operator supports. See torch.library.custom_op for more information.

  • schema -- The schema of the operator. See torch.library.custom_op for more information.

Example

import torch
from evox.utils.op_register import register_vmap_op

# Fake (abstract) evaluation: only allocates an output of the right shape.
def _fake_eval(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a.new_empty(b.size())

# Fake (abstract) evaluation of the vmapped operator.
def _fake_vmap_eval(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return _fake_eval(a, b)

# Vectorized evaluation: inputs carry an extra leading batch dimension.
def _vmap_eval(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a * b.sum(dim=1, keepdim=True)

@register_vmap_op(fake_fn=_fake_eval, vmap_fn=_vmap_eval, fake_vmap_fn=_fake_vmap_eval)
def _custom_eval(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a * b.sum()