evox.utils.parameters_and_vector
#
模块内容#
类#
将(批处理)参数字典转换为向量的类,反之亦然。 |
API#
- class evox.utils.parameters_and_vector.ParamsAndVector(dummy_model: torch.nn.Module)[源代码]#
Bases:
evox.core.ModuleBase
将(批处理)参数字典转换为向量的类,反之亦然。
初始化
初始化 ParamsAndVector 实例。
- 参数:
dummy_model -- 一个 PyTorch 模型,其参数将用于初始化参数和向量转换属性。必须是已初始化的 PyTorch 模型。
- to_vector(params: Dict[str, torch.nn.Parameter]) torch.Tensor [源代码]#
将输入参数字典转换为单个向量。
- 参数:
params -- 输入参数字典。
- 返回:
通过连接展平的参数获得的输出向量。
- batched_to_vector(batched_params: Dict[str, torch.nn.Parameter]) torch.Tensor [源代码]#
将批量参数字典转换为向量的批量。
输入字典的值必须是批处理参数,即它们在第一维度上必须具有相同的形状。
- 参数:
batched_params -- 输入批处理参数字典。
- 返回:
通过连接展平的批处理参数获得的输出向量。输出向量的第一个维度对应于批处理大小。
- to_params(vector: torch.Tensor) Dict[str, torch.nn.Parameter] [源代码]#
将向量转换回参数字典。
- 参数:
vector -- 表示展平模型参数的输入向量。
- 返回:
重建的参数字典。