evox.utils.parameters_and_vector

Module Contents

Classes

ParamsAndVector

The class to convert a (batched) parameters dictionary to vector(s) and vice versa.

API

class evox.utils.parameters_and_vector.ParamsAndVector(dummy_model: torch.nn.Module)

Bases: evox.core.ModuleBase

The class to convert a (batched) parameters dictionary to vector(s) and vice versa.

Initialization

Initialize the ParamsAndVector instance.

Parameters:

dummy_model – A PyTorch model whose parameters are used to set up the attributes needed to convert between parameters and vectors. Must be an initialized PyTorch model.
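Conceptually, converting between a parameters dictionary and a vector only requires the layout of the dummy model's parameters: their names, shapes, and element counts. The sketch below records that layout with plain PyTorch; it is not EvoX's implementation, and the helper `param_layout` is hypothetical.

```python
import torch
from torch import nn


def param_layout(model: nn.Module):
    """Hypothetical helper: record (name, shape, numel) for each parameter."""
    return [(name, p.shape, p.numel()) for name, p in model.named_parameters()]


# A small dummy model: Linear(4 -> 3), ReLU, Linear(3 -> 2).
dummy_model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
layout = param_layout(dummy_model)

for name, shape, numel in layout:
    print(name, tuple(shape), numel)  # e.g. "0.weight (3, 4) 12"

# Length of the flattened parameter vector: 12 + 3 + 6 + 2.
print(sum(numel for _, _, numel in layout))  # 23
```

The ReLU layer has no parameters, so it contributes nothing to the layout.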

to_vector(params: Dict[str, torch.nn.Parameter]) → torch.Tensor

Convert the input parameters dictionary to a single vector.

Parameters:

params – The input parameters dictionary.

Returns:

The output vector obtained by concatenating the flattened parameters.
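A minimal sketch of the flatten-and-concatenate idea in plain PyTorch (not EvoX's code): each parameter is flattened and the pieces are joined in dictionary order. For comparison, `torch.nn.utils.parameters_to_vector` produces the same layout when given the parameters in the same order.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
params = dict(model.named_parameters())

# Flatten every parameter tensor and concatenate them in dictionary order.
vector = torch.cat([p.reshape(-1) for p in params.values()])
print(vector.shape)  # torch.Size([23])

# PyTorch's built-in utility yields the same vector for the same parameter order.
reference = torch.nn.utils.parameters_to_vector(model.parameters())
print(torch.equal(vector, reference))  # True
```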

batched_to_vector(batched_params: Dict[str, torch.nn.Parameter]) → torch.Tensor

Convert a batched parameters dictionary to a batch of vectors.

The input dictionary values must be batched parameters, i.e., they must all have the same size along the first (batch) dimension.

Parameters:

batched_params – The input batched parameters dictionary.

Returns:

The output vectors, obtained by flattening each batched parameter (all dimensions except the first) and concatenating the results. The first dimension of the output tensor corresponds to the batch size.
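The batched case keeps the leading dimension and flattens only the rest. The sketch below is plain PyTorch, not EvoX's implementation; the dictionary keys and the population size of 5 are made-up values for a single `Linear(4 -> 3)` layer.

```python
import torch

# Hypothetical batched parameters: 5 copies of a Linear(4 -> 3) layer.
# Every value shares the same leading (batch) dimension.
pop_size = 5
batched_params = {
    "weight": torch.randn(pop_size, 3, 4),
    "bias": torch.randn(pop_size, 3),
}

# Flatten all dimensions except the batch dimension, then concatenate along dim 1.
vectors = torch.cat(
    [p.reshape(p.shape[0], -1) for p in batched_params.values()], dim=1
)
print(vectors.shape)  # torch.Size([5, 15]), since 3*4 + 3 = 15
```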

to_params(vector: torch.Tensor) → Dict[str, torch.nn.Parameter]

Convert a vector back to a parameters dictionary.

Parameters:

vector – The input vector representing flattened model parameters.

Returns:

The reconstructed parameters dictionary.
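Reconstruction reverses the flattening: split the vector into chunks sized by each parameter's element count, then restore each original shape. This is a plain-PyTorch sketch, not EvoX's code; unlike the documented return type, it yields ordinary tensors rather than `torch.nn.Parameter` objects.

```python
import torch
from torch import nn

model = nn.Linear(4, 3)
names = [n for n, _ in model.named_parameters()]
shapes = [p.shape for _, p in model.named_parameters()]
sizes = [p.numel() for _, p in model.named_parameters()]

vector = torch.arange(sum(sizes), dtype=torch.float32)  # 15 elements: 0.0 .. 14.0

# Split the vector into per-parameter chunks and restore each original shape.
chunks = torch.split(vector, sizes)
params = {n: c.reshape(s) for n, c, s in zip(names, chunks, shapes)}
print({n: tuple(p.shape) for n, p in params.items()})
# {'weight': (3, 4), 'bias': (3,)}
```

Flattening the result again in the same order recovers the original vector exactly.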

batched_to_params(vectors: torch.Tensor) → Dict[str, torch.nn.Parameter]

Convert a batch of vectors back to a batched parameters dictionary.

Parameters:

vectors – The input batch of vectors representing flattened model parameters. The first dimension of the tensor corresponds to the batch size.

Returns:

The reconstructed batched parameters dictionary whose tensors’ first dimensions correspond to the batch size.
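The batched inverse splits along the feature dimension and keeps dimension 0 as the batch dimension. A plain-PyTorch sketch (not EvoX's implementation), assuming a population of 5 vectors for a `Linear(4 -> 3)` layer:

```python
import torch
from torch import nn

model = nn.Linear(4, 3)
layout = [(n, p.shape, p.numel()) for n, p in model.named_parameters()]

vectors = torch.randn(5, sum(numel for _, _, numel in layout))  # shape (5, 15)

# Split along dim 1 (features); dim 0 stays the batch dimension.
chunks = torch.split(vectors, [numel for _, _, numel in layout], dim=1)
batched_params = {
    name: chunk.reshape(vectors.shape[0], *shape)
    for (name, shape, _), chunk in zip(layout, chunks)
}
print({n: tuple(p.shape) for n, p in batched_params.items()})
# {'weight': (5, 3, 4), 'bias': (5, 3)}
```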

forward(x: torch.Tensor) → Dict[str, torch.nn.Parameter]

The forward function of the ParamsAndVector module is an alias of batched_to_params, provided for compatibility with StdWorkflow.
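Because forward is an alias, calling the module on a population of vectors returns batched parameter dictionaries, so it can be used wherever a plain `nn.Module` call is expected. The sketch below uses a hypothetical stand-in class, `VectorToParams`, rather than EvoX itself. The evaluation step with `torch.func.vmap` and `functional_call` only illustrates how such dictionaries could be consumed; it is an assumption, not a description of what StdWorkflow does internally.

```python
import torch
from torch import nn
from torch.func import functional_call, vmap


class VectorToParams(nn.Module):
    """Hypothetical stand-in for ParamsAndVector: forward() aliases batched_to_params()."""

    def __init__(self, dummy_model: nn.Module):
        super().__init__()
        self.layout = [(n, p.shape, p.numel()) for n, p in dummy_model.named_parameters()]

    def batched_to_params(self, vectors: torch.Tensor) -> dict:
        chunks = torch.split(vectors, [numel for _, _, numel in self.layout], dim=1)
        return {
            name: chunk.reshape(vectors.shape[0], *shape)
            for (name, shape, _), chunk in zip(self.layout, chunks)
        }

    def forward(self, x: torch.Tensor) -> dict:
        # Calling the module is the same as calling batched_to_params.
        return self.batched_to_params(x)


model = nn.Linear(4, 2)
adapter = VectorToParams(model)
population = torch.randn(8, 10)  # 8 candidates, each 4*2 + 2 = 10 values
batched_params = adapter(population)

# Illustration only: evaluate every candidate on the same inputs.
inputs = torch.randn(16, 4)
outputs = vmap(lambda p: functional_call(model, p, (inputs,)))(batched_params)
print(outputs.shape)  # torch.Size([8, 16, 2])
```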