evox.utils.op_register
#
模块内容#
类#
函数#
Register a function as a custom operator with (optional) vectorized-map (vmap) support.
This function is a simple wrapper around `torch.library.custom_op`.
数据#
API#
- evox.utils.op_register.T#
TypeVar(...)
- class evox.utils.op_register.FnCallable[源代码]#
Bases: typing.Protocol[evox.utils.op_register.T]
- class evox.utils.op_register.VmapFnCallable[源代码]#
Bases: typing.Protocol[evox.utils.op_register.T]
- class evox.utils.op_register.VmapWrapInputsCallable[源代码]#
Bases: typing.Protocol[evox.utils.op_register.T]
- evox.utils.op_register._default_vmap_wrap_inputs(info: torch._functorch.autograd_function.VmapInfo, in_dims: Tuple[int | None, ...], *args)[源代码]#
- evox.utils.op_register._register_vmap_level(name: str, vmap_fn, fake_vmap_fn, vmap_wrap_inputs, registered, vmap_out_dims, kwargs)[源代码]#
- evox.utils.op_register.register_vmap_op(fn: evox.utils.op_register.FnCallable | None = None, /, *, fake_fn: evox.utils.op_register.FnCallable | None = None, vmap_fn: evox.utils.op_register.VmapFnCallable | None = None, fake_vmap_fn: evox.utils.op_register.VmapFnCallable | None = None, vmap_wrap_inputs: evox.utils.op_register.VmapWrapInputsCallable | None = None, vmap_out_dims: int | None | Tuple[int | None, ...] = 0, max_vmap_level: int | None = None, name: str = None, mutates_args: str | Sequence[str] = (), device_types: str | Sequence[str] | None = None, schema: str | None = None)[源代码]#
Register a function as a custom operator with (optional) vectorized-map (vmap) support. This function is a simple wrapper around `torch.library.custom_op`; see PyTorch Custom Op and `torch.library.custom_op` for more information.

- 参数:
  - fn -- The operator function to register.
  - fake_fn -- The fake (abstract evaluation) function to register to `fn`.
  - vmap_fn -- The vmap function to register to `fn`. Default None means no vmap support.
  - fake_vmap_fn -- The fake (abstract evaluation) vmap function to register to `vmap_fn`. Ignored if `vmap_fn` is None; cannot be None otherwise.
  - vmap_wrap_inputs -- The function that prepares the inputs for `vmap_fn`. Ignored if `vmap_fn` is None. Default None will be replaced by `_default_vmap_wrap_inputs`, which moves all inputs' vmap dimensions to the first dimension (including pytree leaves), and adds an additional broadcast dimension at the beginning if no vmap dimension is present.
  - vmap_out_dims -- The vmap dimensions of the outputs of `vmap_fn`. Ignored if `vmap_fn` is None.
  - max_vmap_level -- The maximum vmap level to support. Default None means no vmap level is supported if `vmap_fn` is None, or 1 if `vmap_fn` is not None.
  - name -- The name of the operator. Default None will be replaced by `"evox::_custom_op_" + fn.__name__`.
  - mutates_args -- The names of the arguments that the function mutates. This MUST be accurate; otherwise, the behavior is undefined. See `torch.library.custom_op` for more information.
  - device_types -- The device types that the operator supports. See `torch.library.custom_op` for more information.
  - schema -- The schema of the operator. See `torch.library.custom_op` for more information.
Example
```python
def _fake_eval(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a.new_empty(b.size())


def _fake_vmap_eval(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return _fake_eval(a, b)


def _vmap_eval(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a * b.sum(dim=1, keepdim=True)


@register_vmap_op(fake_fn=_fake_eval, vmap_fn=_vmap_eval, fake_vmap_fn=_fake_vmap_eval)
def _custom_eval(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a * b.sum()
```