evox.utils.op_register

模块内容

函数

_default_vmap_wrap_inputs

_register_vmap_level

register_vmap_op

将函数注册为带有(可选)矢量化映射(vmap)支持的自定义操作符。此函数是 torch.library.custom_op 的简单包装器,更多信息请参见 PyTorch Custom Optorch.library.custom_op

数据

T

API

evox.utils.op_register.T

TypeVar(...)

class evox.utils.op_register.FnCallable[源代码]

基础: typing.Protocol[evox.utils.op_register.T]

__call__(*args: evox.utils.op_register.T) evox.utils.op_register.T | Tuple[evox.utils.op_register.T, ...][源代码]
class evox.utils.op_register.VmapFnCallable[源代码]

基础: typing.Protocol[evox.utils.op_register.T]

__call__(info: torch._functorch.autograd_function.VmapInfo, in_dims: Tuple[int | None, ...], *args: evox.utils.op_register.T) evox.utils.op_register.T | Tuple[evox.utils.op_register.T, ...][源代码]
class evox.utils.op_register.VmapWrapInputsCallable[源代码]

基础: typing.Protocol[evox.utils.op_register.T]

__call__(info: torch._functorch.autograd_function.VmapInfo, in_dims: Tuple[int | None, ...], *args: evox.utils.op_register.T) Tuple[evox.utils.op_register.T, ...][源代码]
evox.utils.op_register._default_vmap_wrap_inputs(info: torch._functorch.autograd_function.VmapInfo, in_dims: Tuple[int | None, ...], *args)[源代码]
evox.utils.op_register._register_vmap_level(name: str, vmap_fn, fake_vmap_fn, vmap_wrap_inputs, registered, vmap_out_dims, kwargs)[源代码]
evox.utils.op_register.register_vmap_op(fn: evox.utils.op_register.FnCallable | None = None, /, *, fake_fn: evox.utils.op_register.FnCallable | None = None, vmap_fn: evox.utils.op_register.VmapFnCallable | None = None, fake_vmap_fn: evox.utils.op_register.VmapFnCallable | None = None, vmap_wrap_inputs: evox.utils.op_register.VmapWrapInputsCallable | None = None, vmap_out_dims: int | None | Tuple[int | None, ...] = 0, max_vmap_level: int | None = None, name: str = None, mutates_args: str | Sequence[str] = (), device_types: str | Sequence[str] | None = None, schema: str | None = None)[源代码]

将函数注册为带有(可选)矢量化映射(vmap)支持的自定义操作符。此函数是 torch.library.custom_op 的简单包装器,更多信息请参见 PyTorch Custom Optorch.library.custom_op

参数:
  • fn -- 要注册的操作函数。

  • fake_fn -- 假的(抽象评估)函数,用于注册到 fn。

  • vmap_fn -- 将 vmap 函数注册到 fn。默认值为 None,表示不支持 vmap

  • fake_vmap_fn -- 假(抽象评估)vmap函数,用于注册到vmap_fn。如果vmap_fn为None,则忽略;否则不能为None。

  • vmap_wrap_inputs -- 处理 vmap_fn 输入的函数。如果 vmap_fn 为 None,则忽略。默认值 None 将被 _default_vmap_wrap_inputs 替代,该函数会将所有输入的 vmap 维度移动到第一维度(包括 pytree 的叶子部分),并在最前面添加额外的广播维度(如果不存在 vmap 维度)。

  • vmap_out_dims -- vmap_fn 的输出 vmap 维度。如果 vmap_fn 为 None,则忽略。

  • max_vmap_level -- 支持的最大 vmap 等级。默认值 None 表示如果 vmap_fn 为 None,则没有 vmap 等级;如果 vmap_fn 不为 None,则为 1。

  • name -- 运算符的名称。默认值为 None,将被替换为 "evox::custom_op" + fn.name

  • mutates_args -- 函数会修改的参数名称。这个必须是准确的,否则行为将是未定义的。请参阅 torch.library.custom_op 了解更多信息。

  • device_types -- The device types that the operator supports. See torch.library.custom_op for more information.

  • schema -- 操作符的结构。有关更多信息,请参见 torch.library.custom_op。

Example

def _fake_eval(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a.new_empty(b.size())

def _fake_vmap_eval(a: torch.Tensor, b: torch.Tensor) -> Tuple[torch.Tensor, int]:
    return _fake_eval(a, b)

def _vmap_eval(a: torch.Tensor, b: torch.Tensor) -> Tuple[torch.Tensor, int]:
    return a * b.sum(dim=1, keepdim=True)

@register_vmap_op(fake_fn=_fake_eval, vmap_fn=_vmap_eval)
def _custom_eval(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a * b.sum()