evox.utils.parameters_and_vector#

模块内容#

#

ParamsAndVector

将(批处理)参数字典转换为向量的类,反之亦然。

API#

class evox.utils.parameters_and_vector.ParamsAndVector(dummy_model: torch.nn.Module)[源代码]#

Bases: evox.core.ModuleBase

将(批处理)参数字典转换为向量的类,反之亦然。

初始化

初始化 ParamsAndVector 实例。

参数:

dummy_model -- 一个 PyTorch 模型,其参数将用于初始化参数和向量转换属性。必须是已初始化的 PyTorch 模型。

to_vector(params: Dict[str, torch.nn.Parameter]) torch.Tensor[源代码]#

将输入参数字典转换为单个向量。

参数:

params -- 输入参数字典。

返回:

通过连接展平的参数获得的输出向量。

batched_to_vector(batched_params: Dict[str, torch.nn.Parameter]) torch.Tensor[源代码]#

将批量参数字典转换为向量的批量。

输入字典的值必须是批处理参数,即它们在第一维度上必须具有相同的形状。

参数:

batched_params -- 输入批处理参数字典。

返回:

通过连接展平的批处理参数获得的输出向量。输出向量的第一个维度对应于批处理大小。

to_params(vector: torch.Tensor) Dict[str, torch.nn.Parameter][源代码]#

将向量转换回参数字典。

参数:

vector -- 表示展平模型参数的输入向量。

返回:

重建的参数字典。

batched_to_params(vectors: torch.Tensor) Dict[str, torch.nn.Parameter][源代码]#

将一批向量转换回批量参数字典。

参数:

vectors -- 输入批量向量表示展平的模型参数。张量的第一个维度对应于批量大小。

返回:

重构后的批处理参数字典,其张量的第一个维度对应于批处理大小。

forward(x: torch.Tensor) Dict[str, torch.nn.Parameter][源代码]#

ParamsAndVector 模块的 forward 函数是 batched_to_params 的别名,用于应对 `StdWorkflow