# Source code for evox.workflows.std_workflow

# Public API of this module: only `StdWorkflow` is exported by a star-import.
__all__ = ["StdWorkflow"]


from typing import Any

import torch

from evox.core import Algorithm, Monitor, Problem, Workflow


class OptDirectionTransform(torch.nn.Module):
    """Multiply fitness values by a fixed optimization-direction factor.

    `StdWorkflow` builds this factor as ``1`` for "min" and ``-1`` for "max"
    (a scalar tensor, or one entry per element of a list), so multiplying by it
    turns a maximization objective into a minimization one.
    """

    def __init__(self, opt_direction):
        super().__init__()
        # Stored as given; it is not registered as a parameter or a buffer.
        self.opt_direction = opt_direction

    def forward(self, fitness: torch.Tensor) -> torch.Tensor:
        """Return `fitness` scaled (with broadcasting) by the stored direction."""
        direction = self.opt_direction
        scaled = fitness * direction
        return scaled


class _CallableAsModule(torch.nn.Module):
    """Adapt a plain callable so it can be placed inside `torch.nn.Sequential`."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fn(x)


class StdWorkflow(Workflow):
    """The standard workflow.

    ## Usage:
    ```
    algo = BasicAlgorithm(10)
    prob = BasicProblem()

    class solution_transform(nn.Module):
        def forward(self, x: torch.Tensor):
            return x / 5

    class fitness_transform(nn.Module):
        def forward(self, f: torch.Tensor):
            return -f

    monitor = EvalMonitor(full_sol_history=True)
    workflow = StdWorkflow(
        algo,
        prob,
        monitor=monitor,
        solution_transform=solution_transform(),
        fitness_transform=fitness_transform(),
    )
    workflow.init_step()
    print(monitor.get_topk_fitness())
    workflow.step()
    print(monitor.get_topk_fitness())
    # run rest of the steps ...
    ```
    """

    def __init__(
        self,
        algorithm: Algorithm,
        problem: Problem,
        monitor: Monitor | None = None,
        opt_direction: str | list[str] = "min",
        solution_transform: torch.nn.Module | None = None,
        fitness_transform: torch.nn.Module | None = None,
        device: str | torch.device | int | None = None,
        enable_distributed: bool = False,
        group: Any = None,
    ):
        """Initialize the standard workflow with static arguments.

        :param algorithm: The algorithm to be used in the workflow.
        :param problem: The problem to be used in the workflow.
        :param monitor: The monitor to be used in the workflow. Defaults to None, in which case a plain `Monitor()` is used.
        :param opt_direction: The optimization direction: "min" or "max", or a list (or tuple) of such strings,
            which is multiplied element-wise (with broadcasting) against the fitness. Defaults to "min".
            Every "max" entry negates the corresponding fitness values before `fitness_transform` is applied.
        :param solution_transform: The solution transformation function. MUST be a compile-compatible module/function.
            Defaults to None (identity).
        :param fitness_transform: The fitness transformation function. MUST be a compile-compatible module/function.
            Defaults to None (identity). A plain function is wrapped in a module internally.
        :param device: The device of the workflow. Defaults to None, meaning `torch.get_default_device()`.
        :param enable_distributed: Whether to enable distributed workflow. Defaults to False.
        :param group: The process group passed to the `torch.distributed` calls in the distributed workflow.
            Defaults to None.

        :raises TypeError: If `opt_direction` is neither a string nor a list/tuple of strings.

        ```{note}
        The `algorithm`, `problem`, `monitor`, `solution_transform`, and `fitness_transform` will be IN-PLACE moved
        to the device specified by `device`.
        ```

        ```{note}
        The `opt_direction` parameter determines the optimization direction.
        Since EvoX algorithms are designed to minimize by default, setting `opt_direction="max"` will cause
        the fitness values to be negated before being passed to `fitness_transform` and to the monitor's `pre_tell`.
        The monitor's `post_eval` still receives the raw fitness returned by the problem; a user-supplied monitor
        is informed of the direction through `monitor.set_config(opt_direction=...)`.
        ```
        """
        super().__init__()
        if device is None:
            device = torch.get_default_device()

        # Encode the direction as a multiplicative factor: +1 keeps minimization, -1 turns "max" into "min".
        if isinstance(opt_direction, str):
            assert opt_direction in [
                "min",
                "max",
            ], f"Expect optimization direction to be `min` or `max`, got {opt_direction}"
            self.opt_direction = torch.tensor(1 if opt_direction == "min" else -1, device=device)
        elif isinstance(opt_direction, (list, tuple)):
            assert all(d in ["min", "max"] for d in opt_direction), (
                f"Expect optimization direction to be `min` or `max`, got {opt_direction}"
            )
            self.opt_direction = torch.tensor([1 if d == "min" else -1 for d in opt_direction], device=device)
        else:
            # Without this, `self.opt_direction` would stay unset and fail later with an obscure AttributeError.
            raise TypeError(
                f"Expect `opt_direction` to be a str or a list of str, got {type(opt_direction).__name__}"
            )

        # transform
        if solution_transform is None:
            solution_transform = torch.nn.Identity()
        if fitness_transform is None:
            fitness_transform = torch.nn.Identity()
        assert callable(solution_transform), f"Expect solution transform to be callable, got {solution_transform}"
        assert callable(fitness_transform), f"Expect fitness transform to be callable, got {fitness_transform}"
        # `torch.nn.Sequential` only accepts modules, so adapt a plain function before chaining it.
        if not isinstance(fitness_transform, torch.nn.Module):
            fitness_transform = _CallableAsModule(fitness_transform)
        # The direction factor is applied first, so user transforms always see minimization-oriented fitness.
        fitness_transform = torch.nn.Sequential(OptDirectionTransform(self.opt_direction), fitness_transform)
        if isinstance(solution_transform, torch.nn.Module):
            solution_transform.to(device=device)
        fitness_transform.to(device=device)

        if monitor is None:
            monitor = Monitor()
        else:
            monitor.set_config(opt_direction=self.opt_direction)
        algorithm.to(device=device)
        monitor.to(device=device)
        problem.to(device=device)

        # set algorithm evaluate
        # Only dispatch to init_step/final_step when the concrete algorithm overrides them.
        self._has_init_ = type(algorithm).init_step != Algorithm.init_step
        self._has_final_ = type(algorithm).final_step != Algorithm.final_step

        class _SubAlgorithm(type(algorithm)):
            # Reuse the algorithm's own class (and thus its step methods) but route `evaluate`
            # through this workflow. `super(Algorithm, ...)` starts the MRO lookup after
            # `Algorithm`, skipping both the user class's and `Algorithm`'s `__init__`; the state
            # is then shallow-copied from the already-constructed `algorithm` instance.
            def __init__(self_algo):
                super(Algorithm, self_algo).__init__()
                self_algo.__dict__.update(algorithm.__dict__)

            def evaluate(self_algo, pop: torch.Tensor) -> torch.Tensor:
                return self._evaluate(pop)

        # set submodules
        self.algorithm = _SubAlgorithm()
        self.monitor = monitor
        self.problem = problem
        self.solution_transform = solution_transform
        self.fitness_transform = fitness_transform
        self.enable_distributed = enable_distributed
        self.group = group

    def get_submodule(self, target: str) -> Any:
        return super().get_submodule(target)

    def _evaluate(self, population: torch.Tensor) -> torch.Tensor:
        """Evaluate a population on the problem; installed as the algorithm's `evaluate`.

        Calls the monitor hooks in order: `post_ask` (raw population), `pre_eval` (transformed
        population, the local shard when distributed), `post_eval` (raw fitness from the problem)
        and `pre_tell` (fitness after direction and `fitness_transform`).

        :param population: The population proposed by the algorithm.
        :return: The transformed fitness handed back to the algorithm.
        """
        self.monitor.post_ask(population)
        if self.enable_distributed:
            pop_size = population.size(0)
            world_size = torch.distributed.get_world_size(group=self.group)
            rank = torch.distributed.get_rank(group=self.group)
            # Each rank evaluates only its own contiguous shard of the population.
            population = population.tensor_split(world_size, dim=0)[rank]
        population = self.solution_transform(population)
        self.monitor.pre_eval(population)
        if self.enable_distributed:
            # When using distributed, we need to make sure that the random number generator is forked.
            # Otherwise, since the evaluation process for different individuals is not independent,
            # the random number generator could get polluted.
            with torch.random.fork_rng():
                fitness = self.problem.evaluate(population)
            # construct a list of tensors to gather all fitness; the split sizes match the shards above
            all_fitness = torch.zeros(pop_size, *fitness.shape[1:], device=fitness.device, dtype=fitness.dtype)
            all_fitness = list(all_fitness.tensor_split(world_size, dim=0))
            # gather all fitness
            torch.distributed.all_gather(all_fitness, fitness, group=self.group)
            fitness = torch.cat(all_fitness, dim=0)
        else:
            fitness = self.problem.evaluate(population)
        self.monitor.post_eval(fitness)
        fitness = self.fitness_transform(fitness)
        self.monitor.pre_tell(fitness)
        return fitness

    def _step(self, init: bool = False, final: bool = False):
        """Run one algorithm step, using `init_step`/`final_step` when requested and overridden."""
        if init and self._has_init_:
            self.algorithm.init_step()
        elif final and self._has_final_:
            self.algorithm.final_step()
        else:
            self.algorithm.step()
        # If the monitor has override the `record_auxiliary` method, it will be called here.
        # NOTE(review): only the monitor's concrete class __dict__ is checked, so an override inherited
        # from an intermediate monitor subclass is not detected — confirm this is intended.
        if "record_auxiliary" in self.monitor.__class__.__dict__:
            self.monitor.record_auxiliary(self.algorithm.record_step())

    def init_step(self):
        """
        Perform the first optimization step of the workflow.

        Calls the `init_step` of the algorithm if overwritten; otherwise, its `step` method will be invoked.
        """
        self._step(init=True, final=False)

    def final_step(self):
        """
        Perform the last optimization step of the workflow.

        Calls the `final_step` of the algorithm if overwritten; otherwise, its `step` method will be invoked.
        """
        self._step(init=False, final=True)

    def step(self):
        """Perform a single optimization step using the algorithm and the problem."""
        self._step()