evox.utils.op_register
¶
模块内容¶
类¶
函数¶
将函数注册为带有(可选)矢量化映射(vmap)支持的自定义操作符。此函数是 |
数据¶
API¶
- evox.utils.op_register.T¶
TypeVar(...)
- class evox.utils.op_register.FnCallable[源代码]¶
基础:
typing.Protocol
[evox.utils.op_register.T
]
- class evox.utils.op_register.VmapFnCallable[源代码]¶
基础:
typing.Protocol
[evox.utils.op_register.T
]
- class evox.utils.op_register.VmapWrapInputsCallable[源代码]¶
基础:
typing.Protocol
[evox.utils.op_register.T
]
- evox.utils.op_register._default_vmap_wrap_inputs(info: torch._functorch.autograd_function.VmapInfo, in_dims: Tuple[int | None, ...], *args)[源代码]¶
- evox.utils.op_register._register_vmap_level(name: str, vmap_fn, fake_vmap_fn, vmap_wrap_inputs, registered, vmap_out_dims, kwargs)[源代码]¶
- evox.utils.op_register.register_vmap_op(fn: evox.utils.op_register.FnCallable | None = None, /, *, fake_fn: evox.utils.op_register.FnCallable | None = None, vmap_fn: evox.utils.op_register.VmapFnCallable | None = None, fake_vmap_fn: evox.utils.op_register.VmapFnCallable | None = None, vmap_wrap_inputs: evox.utils.op_register.VmapWrapInputsCallable | None = None, vmap_out_dims: int | None | Tuple[int | None, ...] = 0, max_vmap_level: int | None = None, name: str = None, mutates_args: str | Sequence[str] = (), device_types: str | Sequence[str] | None = None, schema: str | None = None)[源代码]¶
将函数注册为带有(可选)矢量化映射(vmap)支持的自定义操作符。此函数是
torch.library.custom_op
的简单包装器,更多信息请参见 PyTorch Custom Op 和torch.library.custom_op
。- 参数:
fn -- 要注册的操作函数。
fake_fn -- 假的(抽象评估)函数,用于注册到 fn。
vmap_fn -- 将
vmap
函数注册到fn
。默认值为None
,表示不支持vmap
。fake_vmap_fn -- 假(抽象评估)vmap函数,用于注册到vmap_fn。如果vmap_fn为None,则忽略;否则不能为None。
vmap_wrap_inputs -- 处理 vmap_fn 输入的函数。如果 vmap_fn 为 None,则忽略。默认值 None 将被 _default_vmap_wrap_inputs 替代,该函数会将所有输入的 vmap 维度移动到第一维度(包括 pytree 的叶子部分),并在最前面添加额外的广播维度(如果不存在 vmap 维度)。
vmap_out_dims -- vmap_fn 的输出 vmap 维度。如果 vmap_fn 为 None,则忽略。
max_vmap_level -- 支持的最大 vmap 等级。默认值 None 表示如果 vmap_fn 为 None,则没有 vmap 等级;如果 vmap_fn 不为 None,则为 1。
name -- 运算符的名称。默认值为 None,将被替换为 "evox::custom_op" + fn.name。
mutates_args -- 函数会修改的参数名称。这个必须是准确的,否则行为将是未定义的。请参阅 torch.library.custom_op 了解更多信息。
device_types -- The device types that the operator supports. See
torch.library.custom_op
for more information.schema -- 操作符的结构。有关更多信息,请参见 torch.library.custom_op。
Example
def _fake_eval(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return a.new_empty(b.size()) def _fake_vmap_eval(a: torch.Tensor, b: torch.Tensor) -> Tuple[torch.Tensor, int]: return _fake_eval(a, b) def _vmap_eval(a: torch.Tensor, b: torch.Tensor) -> Tuple[torch.Tensor, int]: return a * b.sum(dim=1, keepdim=True) @register_vmap_op(fake_fn=_fake_eval, vmap_fn=_vmap_eval) def _custom_eval(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return a * b.sum()