__all__ = ["EvalMonitor"]
import warnings
import weakref
from enum import IntEnum
from typing import Dict, List, NamedTuple, Tuple
import torch
from evox.core import Monitor, Mutable
from evox.operators.selection import non_dominate_rank
from evox.utils import register_vmap_op
try:
from evox.vis_tools import plot
except ImportError:
plot = None
# Helper that also reports first-occurrence indices alongside `torch.unique`'s outputs; see
# https://github.com/pytorch/pytorch/issues/36748
def unique(x: torch.Tensor, dim=0):
    """Deduplicate ``x`` along ``dim`` and report where each unique slice first occurs.

    The input is passed through ``Tensor.nan_to_num`` before comparison, so the
    returned unique values come from that sanitized copy rather than from ``x`` itself.

    :param x: The tensor to deduplicate.
    :param dim: The dimension along which slices are compared. Defaults to 0.

    :return: A tuple ``(unique, inverse, counts, index)`` where ``unique``, ``inverse`` and ``counts``
        are the outputs of ``torch.unique`` and ``index[i]`` is the smallest position along ``dim``
        at which ``unique[i]`` occurs.
    """
    sanitized = x.nan_to_num()
    values, inverse, counts = torch.unique(
        sanitized,
        dim=dim,
        sorted=True,
        return_inverse=True,
        return_counts=True,
    )
    # A stable sort groups the positions by unique id while preserving the original
    # order inside every group, so the first member of a group is the first occurrence.
    grouped_positions = inverse.argsort(stable=True)
    # Exclusive prefix sum of the group sizes: offset of each group's first member.
    group_offsets = counts.cumsum(dim=0) - counts
    first_index = grouped_positions[group_offsets]
    return values, inverse, counts, first_index
class HistoryType(IntEnum):
    """Kind of data stored in a monitor history.

    The integer values are used as positional indices into a ``MonitorHistory`` tuple,
    so they must stay aligned with the field order of that tuple.
    """

    # Index of the fitness history list.
    FITNESS = 0
    # Index of the solution history list.
    SOLUTION = 1
    # Index of the auxiliary history list.
    AUXILIARY = 2
class MonitorHistory(NamedTuple):
    """Per-monitor storage of the recorded tensors.

    The field order matters: the lists are addressed positionally with the
    integer values of ``HistoryType`` (fitness, solution, auxiliary).
    """

    # One entry per recorded fitness tensor.
    fit_history: List[torch.Tensor]
    # One entry per recorded solution tensor.
    sol_history: List[torch.Tensor]
    # Flat list of recorded auxiliary tensors.
    aux_history: List[torch.Tensor]
# Module-level registry of recorded histories, keyed by an integer monitor id.
# The data-sink functions below receive only that integer id, so the history
# lists are looked up here instead of being reached through a monitor object.
__monitor_history__: Dict[int, MonitorHistory] = {}
def _fake_data_sink(monitor_id: int, data: torch.Tensor, data_type: int, token: torch.Tensor) -> torch.Tensor:
return token.new_empty(token.size())
def _fake_vmap_data_sink(
monitor_id: int,
data: torch.Tensor,
data_type: int,
token: torch.Tensor,
) -> torch.Tensor:
return token.new_empty(token.size())
def _vmap_data_sink(
    monitor_id: int,
    data: torch.Tensor,
    data_type: int,
    token: torch.Tensor,
) -> torch.Tensor:
    """Vmap variant of the data sink: append ``data`` to the selected history list.

    :param monitor_id: Key of the monitor's entry in ``__monitor_history__``.
    :param data: The tensor to record; it is appended as-is.
    :param data_type: Positional index of the history list inside the monitor's entry.
    :param token: The ordering token received from the previous sink call.

    :return: ``token + 1``, to be passed to the next sink call.
    """
    history = __monitor_history__[monitor_id]
    history[data_type].append(data)
    next_token = token + 1
    return next_token
@register_vmap_op(fake_fn=_fake_data_sink, vmap_fn=_vmap_data_sink, fake_vmap_fn=_fake_vmap_data_sink)
def _data_sink(monitor_id: int, data: torch.Tensor, data_type: int, token: torch.Tensor) -> torch.Tensor:
    """Record the data into the monitor history log.

    The token threads a data dependency through successive calls: every call consumes the
    token returned by the previous one, which keeps the recorded values tracked and ordered.

    :param monitor_id: Key of the monitor's entry in ``__monitor_history__``.
    :param data: The tensor to record; it is appended as-is.
    :param data_type: Positional index of the history list inside the monitor's entry.
    :param token: The ordering token received from the previous sink call.

    :return: ``token + 1``, to be passed to the next sink call.
    """
    history = __monitor_history__[monitor_id]
    history[data_type].append(data)
    next_token = token + 1
    return next_token
# [docs]
class EvalMonitor(Monitor):
    """Evaluation monitor.

    Used for both single-objective and multi-objective workflow.
    Hooked around the evaluation process (``post_ask`` / ``pre_tell``), it keeps the latest
    candidate solutions and their fitness, maintains the top-k elites for single-objective
    runs, and can optionally record the full fitness / solution / auxiliary histories, from
    which the best solution or an approximate pareto front can be retrieved.
    """

    def __init__(
        self,
        multi_obj: bool = False,
        full_fit_history: bool = True,
        full_sol_history: bool = False,
        full_pop_history: bool = False,
        topk: int = 1,
        device: torch.device | None = None,
        history_device: torch.device | None = None,
    ):
        """Initialize the monitor.

        :param multi_obj: Whether the optimization is multi-objective. Defaults to False.
            Note that `pre_tell` overwrites this flag according to the rank of the fitness tensor.
        :param full_fit_history: Whether to record the full history of fitness value. Default to True. Setting it to False may reduce memory usage.
        :param full_sol_history: Whether to record the full history of solutions. Default to False. Setting it to True may increase memory usage.
        :param full_pop_history: Whether to record the tensors passed to `record_auxiliary`. Default to False. Setting it to True may increase memory usage.
        :param topk: Only affect Single-objective optimization. The number of elite solutions to record. Default to 1, which will record the best individual.
        :param device: The device of the monitor. Defaults to None, in which case the default torch device is used.
        :param history_device: The device to record the history. Defaults to None. If None, it will use cpu.

        ```{tip}
        Setting the `history_device` to the same device as the monitor will save the data transfer time,
        but may increase the memory usage on the device.
        ```

        ```{note}
        When `opt_direction="max"` is used,
        fitness values are internally multiplied by -1 to ensure that optimization logic always treats the best fitness as the minimum value.
        As a result, raw fitness values (e.g., `monitor.topk_fitness`, `monitor.fitness_history`, etc.) will appear negated.
        However, access methods such as `monitor.get_best_fitness()` and `monitor.get_pf_fitness()` multiply by `opt_direction` again and therefore return the original, unmodified values.
        ```
        """
        super().__init__()
        device = torch.get_default_device() if device is None else device
        history_device = torch.device("cpu") if history_device is None else history_device
        self.multi_obj = multi_obj
        self.full_fit_history = full_fit_history
        self.full_sol_history = full_sol_history
        self.full_pop_history = full_pop_history
        self.opt_direction = torch.tensor(1)
        self.topk = topk
        self.device = device
        self.history_device = history_device
        # Keys of the auxiliary dict; fixed by the first `record_auxiliary` call.
        self.aux_keys = []
        # mutable
        self.latest_solution = Mutable(torch.empty(0, device=device))
        self.latest_fitness = Mutable(torch.empty(0, device=device))
        self.topk_solutions = Mutable(torch.empty(0, device=device))
        self.topk_fitness = Mutable(torch.empty(0, device=device))
        self._id_ = id(self)
        # Token threaded through the data sink calls to keep them ordered.
        self._token = Mutable(torch.tensor(0, device=device))
        __monitor_history__[self._id_] = MonitorHistory([], [], [])
        # Drop this monitor's entry from the module-level registry once it is garbage collected.
        weakref.finalize(
            self,
            __monitor_history__.pop,
            self._id_,
            None,
        )

    @property
    def fitness_history(self) -> List[torch.Tensor]:
        """The raw recorded fitness tensors (not multiplied by `opt_direction`)."""
        return __monitor_history__[self._id_][HistoryType.FITNESS]

    @property
    def fit_history(self) -> List[torch.Tensor]:
        """Alias for `fitness_history`."""
        return self.fitness_history

    @property
    def solution_history(self) -> List[torch.Tensor]:
        """The recorded solution tensors."""
        return __monitor_history__[self._id_][HistoryType.SOLUTION]

    @property
    def sol_history(self) -> List[torch.Tensor]:
        """Alias for `solution_history`."""
        return self.solution_history

    @property
    def aux_history(self) -> Dict[str, List[torch.Tensor]]:
        """Alias for `auxiliary_history`."""
        return self.auxiliary_history

    @property
    def auxiliary_history(self) -> Dict[str, List[torch.Tensor]]:
        """The recorded auxiliary tensors, grouped by key.

        The raw auxiliary log is a flat list in which every `record_auxiliary` call appends
        one tensor per key, in `aux_keys` order; it is de-interleaved here.

        :return: A dict mapping each auxiliary key to its list of recorded tensors,
            or an empty dict if no key has been registered yet.
        """
        raw_aux_history = __monitor_history__[self._id_][HistoryType.AUXILIARY]
        n_keys = len(self.aux_keys)
        if n_keys == 0:
            return {}
        assert len(raw_aux_history) % n_keys == 0
        # Every n_keys-th entry, starting at the key's position, belongs to that key.
        return {key: raw_aux_history[i::n_keys] for i, key in enumerate(self.aux_keys)}

    def set_config(self, **config):
        """Update the monitor's configuration in place.

        Recognized keys are `multi_obj`, `full_fit_history`, `full_sol_history`,
        `full_pop_history`, `topk` and `opt_direction`; any other key is ignored.

        :param config: The configuration values to override.

        :return: The monitor itself.
        """
        for key in (
            "multi_obj",
            "full_fit_history",
            "full_sol_history",
            "full_pop_history",
            "topk",
            "opt_direction",
        ):
            if key in config:
                setattr(self, key, config[key])
        return self

    def record_auxiliary(self, aux: Dict[str, torch.Tensor]):
        """Record auxiliary tensors into the auxiliary history.

        Does nothing unless `full_pop_history` is enabled. The set of keys is fixed by the
        first call; later calls must provide at least those keys.

        :param aux: A dict mapping names to the tensors to record.
        """
        if self.full_pop_history:
            if len(self.aux_keys) == 0:
                self.aux_keys = list(aux.keys())
            for key in self.aux_keys:
                assert isinstance(aux[key], torch.Tensor)
                self._token = _data_sink(
                    self._id_, aux[key].to(self.history_device, non_blocking=True), HistoryType.AUXILIARY, self._token
                )

    def post_ask(self, candidate_solution: torch.Tensor):
        """Store the candidate solutions as the latest solutions.

        :param candidate_solution: The candidate solutions.
        """
        self.latest_solution = candidate_solution

    def pre_tell(self, fitness: torch.Tensor):
        """Store the latest fitness, update the top-k elites and record the histories.

        :param fitness: The fitness of the latest solutions. A 1-D tensor is treated as
            single-objective, a 2-D tensor as multi-objective.

        :raises ValueError: If `fitness` is neither 1-D nor 2-D.
        """
        self.latest_fitness = fitness
        if fitness.ndim == 1:
            # single-objective
            self.multi_obj = False
            assert fitness.size(0) >= self.topk
            if self.topk_solutions.ndim <= 1:
                # No elites stored yet: rank the latest batch only.
                candidate_solutions = self.latest_solution
                candidate_fitness = fitness
            else:
                # Let the stored elites compete against the latest batch.
                candidate_solutions = torch.concatenate([self.topk_solutions, self.latest_solution])
                candidate_fitness = torch.concatenate([self.topk_fitness, fitness])
            rank = torch.topk(candidate_fitness, self.topk, largest=False)[1]
            self.topk_fitness = candidate_fitness[rank]
            self.topk_solutions = candidate_solutions[rank]
        elif fitness.ndim == 2:
            # multi-objective
            self.multi_obj = True
            # In multi-objective, we can't simply take the topk solutions.
            # Instead, we need to record the solutions and fitness values.
            # And in the end, we can get the pareto front.
        else:
            raise ValueError(f"Invalid fitness shape: {fitness.shape}")
        if self.full_fit_history or self.full_sol_history:
            self.record_history()

    def record_history(self):
        """Append the latest solutions and/or fitness to the enabled histories.

        The tensors are moved to `history_device` before being recorded.
        """
        if self.full_sol_history:
            latest_solution = self.latest_solution.to(self.history_device, non_blocking=True)
            assert isinstance(latest_solution, torch.Tensor)
            self._token = _data_sink(self._id_, latest_solution, HistoryType.SOLUTION, self._token)
        if self.full_fit_history:
            latest_fitness = self.latest_fitness.to(self.history_device, non_blocking=True)
            assert isinstance(latest_fitness, torch.Tensor)
            self._token = _data_sink(self._id_, latest_fitness, HistoryType.FITNESS, self._token)

    def get_latest_fitness(self) -> torch.Tensor:
        """Get the fitness values from the latest iteration."""
        opt_dir = self.opt_direction.to(self.device)
        return opt_dir * self.latest_fitness

    def get_latest_solution(self) -> torch.Tensor:
        """Get the solution from the latest iteration."""
        return self.latest_solution

    def get_topk_fitness(self) -> torch.Tensor:
        """Get the topk fitness values so far."""
        opt_dir = self.opt_direction.to(self.device)
        return opt_dir * self.topk_fitness

    def get_topk_solutions(self) -> torch.Tensor:
        """Get the topk solutions so far.

        :raises ValueError: If the monitor is in multi-objective mode.
        """
        if self.multi_obj:
            raise ValueError("Multi-objective optimization does not have a single best solution. Please use get_pf_solutions")
        return self.topk_solutions

    def get_best_solution(self) -> torch.Tensor:
        """Get the best solution so far.

        :raises ValueError: If the monitor is in multi-objective mode.
        """
        if self.multi_obj:
            raise ValueError("Multi-objective optimization does not have a single best solution. Please use get_pf_solutions")
        return self.topk_solutions[0]

    def get_best_fitness(self) -> torch.Tensor:
        """Get the best fitness value so far.

        :raises ValueError: If the monitor is in multi-objective mode.
        """
        if self.multi_obj:
            raise ValueError("Multi-objective optimization does not have a single best fitness. Please use get_pf_fitness")
        opt_dir = self.opt_direction.to(self.device)
        return opt_dir * self.topk_fitness[0]

    def get_pf_fitness(self, deduplicate=True) -> torch.Tensor:
        """Get the approximate pareto front fitness values of all the solutions evaluated so far.
        Requires enabling `full_fit_history`.

        :param deduplicate: Whether to remove duplicated fitness rows first. Defaults to True.

        :return: The fitness values of the non-dominated individuals, multiplied by `opt_direction`.

        :raises ValueError: If the monitor is not in multi-objective mode.
        """
        if not self.multi_obj:
            raise ValueError("get_pf_fitness is only available for multi-objective optimization.")
        if not self.full_fit_history:
            warnings.warn("`get_pf_fitness` requires enabling `full_fit_history`.")
        all_fitness = self.fitness_history
        all_fitness = torch.cat(all_fitness, dim=0)
        if deduplicate:
            all_fitness = torch.unique(all_fitness, dim=0)
        rank = non_dominate_rank(all_fitness)
        pf_fit = all_fitness[rank == 0]
        opt_dir = self.opt_direction.to(self.history_device)
        return pf_fit * opt_dir

    def get_pf_solutions(self, deduplicate=True) -> torch.Tensor:
        """Get the approximate pareto front solutions of all the solutions evaluated so far.
        Requires enabling both `full_fit_history` and `full_sol_history`.
        If `deduplicate` is set to True, the duplicated solutions will be removed.

        :raises ValueError: If the monitor is not in multi-objective mode.
        """
        if not self.multi_obj:
            raise ValueError("get_pf_solutions is only available for multi-objective optimization.")
        pf_solutions, _pf_fitness = self.get_pf(deduplicate)
        return pf_solutions

    def get_pf(self, deduplicate=True) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get the approximate pareto front solutions and fitness values of all the solutions evaluated so far.
        Requires enabling both `full_fit_history` and `full_sol_history`.
        If `deduplicate` is set to True, the duplicated solutions will be removed.

        :return: A tuple `(pf_solutions, pf_fitness)`; the fitness is multiplied by `opt_direction`.

        :raises ValueError: If the monitor is not in multi-objective mode.
        """
        if not self.multi_obj:
            raise ValueError("get_pf is only available for multi-objective optimization.")
        if not self.full_fit_history or not self.full_sol_history:
            warnings.warn("`get_pf` requires enabling both `full_fit_history` and `full_sol_history`.")
        all_solutions = self.get_solution_history()
        all_solutions = torch.cat(all_solutions, dim=0)
        all_fitness = self.fitness_history
        all_fitness = torch.cat(all_fitness, dim=0)
        if deduplicate:
            # `unique` returns (unique, inverse, counts, index). Only `index` (the position of the
            # first occurrence of every distinct solution) picks one row per distinct solution;
            # `inverse` has one entry per input row and indexes into the unique values instead.
            _, _, _, unique_index = unique(all_solutions)
            all_solutions = all_solutions[unique_index]
            all_fitness = all_fitness[unique_index]
        rank = non_dominate_rank(all_fitness)
        on_front = rank == 0
        pf_fitness = all_fitness[on_front]
        pf_solutions = all_solutions[on_front]
        opt_dir = self.opt_direction.to(self.history_device)
        return pf_solutions, pf_fitness * opt_dir

    def get_fitness_history(self) -> List[torch.Tensor]:
        """Get the full history of fitness values, multiplied by `opt_direction`."""
        opt_dir = self.opt_direction.to(self.history_device)
        return [opt_dir * fit for fit in self.fitness_history]

    def get_solution_history(self) -> List[torch.Tensor]:
        """Get the full history of solutions."""
        return self.solution_history

    @torch.compiler.disable
    def plot(self, problem_pf=None, source="eval", **kwargs):
        """Plot the fitness history.
        If the problem's Pareto front is provided, it will be plotted as well.

        :param problem_pf: The Pareto front of the problem. Default to None.
        :param source: The source of the data, either "eval" or "pop", default to "eval".
            When "eval", the fitness from the problem evaluation side will be plotted, representing what the problem sees.
            When "pop", the fitness recorded under the "fit" key of the auxiliary history will be plotted, representing what the algorithm sees.
        :param kwargs: Additional arguments for the plot.

        :return: The object returned by the plotting function, or None when nothing can be plotted.

        :raises ValueError: If `source` is neither "eval" nor "pop".
        """
        if not self.fitness_history and not self.aux_history:
            warnings.warn("No fitness history recorded, return None")
            return
        if plot is None:
            warnings.warn('No visualization tool available, return None. Hint: pip install "evox[vis]"')
            return
        if source == "pop":
            fitness_history = self.aux_history.get("fit", [])
        elif source == "eval":
            fitness_history = self.get_fitness_history()
        else:
            raise ValueError(f"Invalid source argument: {source}, expect 'eval' or 'pop'.")
        if not fitness_history:
            # The other source may have data while the requested one has none.
            warnings.warn(f"No fitness history recorded for source '{source}', return None")
            return
        fitness_history = [f.cpu().numpy() for f in fitness_history]
        # Derive the number of objectives from the data that is actually plotted.
        first_record = fitness_history[0]
        n_objs = 1 if first_record.ndim == 1 else first_record.shape[1]
        if n_objs == 1:
            return plot.plot_obj_space_1d(fitness_history, **kwargs)
        elif n_objs == 2:
            return plot.plot_obj_space_2d(fitness_history, problem_pf, **kwargs)
        elif n_objs == 3:
            return plot.plot_obj_space_3d(fitness_history, problem_pf, **kwargs)
        else:
            warnings.warn("Not supported yet.")