__all__ = [
"Algorithm",
"Problem",
"Workflow",
"Monitor",
]
from abc import ABC
from typing import Any, Dict
import torch
from evox.core.module import ModuleBase
[docs]
class Algorithm(ModuleBase, ABC):
"""Base class for all algorithms"""
def __init__(self):
super().__init__()
[docs]
def step(self) -> None:
"""Execute the algorithm procedure for one step."""
pass
[docs]
def init_step(self) -> None:
"""Initialize the algorithm and execute the algorithm procedure for the first step."""
self.step()
[docs]
def final_step(self) -> None:
"""Execute the algorithm procedure for the final step."""
self.step()
[docs]
def evaluate(self, pop: torch.Tensor) -> torch.Tensor:
"""Evaluate the fitness at given points.
This function is a proxy function of `Problem.evaluate` set by workflow.
By default, this functions raises `NotImplementedError`.
:param pop: The population.
:return: The fitness.
"""
raise NotImplementedError(
"Evaluate function is not implemented. It is a proxy function of `Problem.evaluate` set by workflow."
)
[docs]
def record_step(self) -> None:
"""Record the current step."""
return {"pop": self.pop, "fit": self.fit}
[docs]
class Problem(ModuleBase, ABC):
"""Base class for all problems"""
def __init__(self):
super().__init__()
[docs]
def evaluate(self, pop: torch.Tensor) -> torch.Tensor:
"""Evaluate the fitness at given points
:param pop: The population.
:return: The fitness.
## Notice
If this function contains external evaluations that cannot be compiled by `torch.compile`, please wrap it with `torch.compiler.disable` or use `evox.utils.register_vmap_op` to register external functions as operators.
"""
return torch.empty(0)
[docs]
class Workflow(ModuleBase, ABC):
"""The base class for workflow."""
[docs]
def init_step(self) -> None:
"""Perform the first optimization step of the workflow."""
return self.step()
[docs]
def step(self) -> None:
"""The basic function to step a workflow."""
pass
[docs]
def final_step(self) -> None:
"""Perform the final optimization step of the workflow."""
return self.step()
class Agent(ModuleBase, ABC):
"""The base class for agents,
representing an individual entity with its own state (memory) and behavior (act method).
The state of the agent is stored in `self`, the act method defines how the agent interacts with its environment."""
def act(self, *args, **kwargs) -> Any:
"""The function to act according to the observation.
:param observation: The observation.
:return: The action.
"""
pass
[docs]
class Monitor(ModuleBase, ABC):
"""
The monitor base class.
Monitors are used to monitor the evolutionary process.
They contains a set of callbacks,
which will be called at specific points during the execution of the workflow.
Monitor itself lives outside the main workflow, so jit is not required.
To implements a monitor, implement your own callbacks and override the hooks method.
The hooks method should return a list of strings, which are the names of the callbacks.
Currently the supported callbacks are:
`post_ask`, `pre_eval`, `post_eval`, and `pre_tell`.
"""
[docs]
def set_config(self, **config) -> "Monitor":
"""Set the static variables according to `config`.
:param config: The configuration.
:return: This module.
"""
return self
[docs]
def record_auxiliary(self, aux: Dict[str, torch.Tensor]) -> None:
"""Record the auxiliary information.
:param aux: The auxiliary information.
"""
pass
[docs]
def post_ask(self, candidate_solution: torch.Tensor) -> None:
"""The hook function to be executed before the solution transformation.
:param candidate_solution: The population (candidate solutions) before the solution transformation.
"""
pass
[docs]
def pre_eval(self, transformed_candidate_solution: Any) -> None:
"""The hook function to be executed after the solution transformation.
:param transformed_candidate_solution: The population (candidate solutions) after the solution transformation.
"""
pass
[docs]
def post_eval(self, fitness: torch.Tensor) -> None:
"""The hook function to be executed before the fitness transformation.
:param fitness: The fitnesses before the fitness transformation.
"""
pass
[docs]
def pre_tell(self, transformed_fitness: torch.Tensor) -> None:
"""The hook function to be executed after the fitness transformation.
:param transformed_fitness: The fitnesses after the fitness transformation.
"""
pass