Source code for evox.algorithms.mo.nsga3

from typing import Callable, Optional

import torch

from evox.core import Algorithm, Mutable, vmap
from evox.operators.crossover import simulated_binary
from evox.operators.mutation import polynomial_mutation
from evox.operators.sampling import uniform_sampling
from evox.operators.selection import non_dominate_rank, tournament_selection_multifit
from evox.utils import clamp


[docs] def _get_table_row_inner(bool_ref_candidate: torch.Tensor, upper_bound: torch.Tensor): true_indices = torch.where( bool_ref_candidate, torch.arange(bool_ref_candidate.size(0), dtype=torch.int32, device=torch.get_default_device()), upper_bound, ) true_indices = torch.sort(true_indices, dim=0).values return true_indices.to(torch.int32)
# Batched variant of _get_table_row_inner: maps over dim 0 of a boolean
# matrix (one row per reference point) while broadcasting the same scalar
# `upper_bound` sentinel to every call (in_dims=(0, None)).
vmap_get_table_row = vmap(
    _get_table_row_inner,
    in_dims=(0, None),
)
[docs] def _select_from_index_by_min_inner( group_id: torch.Tensor, group_dist: torch.Tensor, idx: torch.Tensor, ): min_idx = torch.argmin(torch.where(group_id == idx.unsqueeze(0), group_dist, torch.inf)).to(torch.int32) return min_idx
# Batched variant of _select_from_index_by_min_inner: the same group-id and
# distance vectors are shared across calls, while the group index `idx`
# varies along dim 0 (in_dims=(None, None, 0)) — one nearest member per group.
vmap_select_from_index_by_min = vmap(
    _select_from_index_by_min_inner,
    in_dims=(None, None, 0),
)
[docs] def _get_extreme_inner(norm_fit: torch.Tensor, w: torch.Tensor): return torch.argmin(torch.max(norm_fit / w.unsqueeze(0), dim=1).values)
# Batched variant of _get_extreme_inner: the normalized fitness matrix is
# shared, while the weight vector varies along dim 0 (in_dims=(None, 0)) —
# yields one extreme-point index per objective axis.
vmap_get_extreme = vmap(
    _get_extreme_inner,
    in_dims=(None, 0),
)
class NSGA3(Algorithm):
    """
    An implementation of the tensorized NSGA-III for many-objective optimization problems.

    :references:
        [1] K. Deb and H. Jain, "An Evolutionary Many-Objective Optimization Algorithm Using
            Reference-Point-Based Nondominated Sorting Approach, Part I: Solving Problems With Box
            Constraints," IEEE Transactions on Evolutionary Computation, vol. 18, no. 4, pp. 577-601, 2014.
            Available: https://ieeexplore.ieee.org/document/6600851

        [2] H. Li, Z. Liang, and R. Cheng, "GPU-accelerated Evolutionary Many-objective Optimization
            Using Tensorized NSGA-III," in 2025 IEEE Congress on Evolutionary Computation, 2025.
    """

    def __init__(
        self,
        pop_size: int,
        n_objs: int,
        lb: torch.Tensor,
        ub: torch.Tensor,
        selection_op: Optional[Callable] = None,
        mutation_op: Optional[Callable] = None,
        crossover_op: Optional[Callable] = None,
        data_type: Optional[torch.dtype] = None,
        device: torch.device | None = None,
    ):
        """Initializes the NSGA-III algorithm.

        :param pop_size: The size of the population.
        :param n_objs: The number of objective functions in the optimization problem.
        :param lb: The lower bounds for the decision variables (1D tensor).
        :param ub: The upper bounds for the decision variables (1D tensor).
        :param selection_op: The selection operation for evolutionary strategy (optional).
        :param mutation_op: The mutation operation, defaults to `polynomial_mutation` if not provided (optional).
        :param crossover_op: The crossover operation, defaults to `simulated_binary` if not provided (optional).
        :param data_type: The data type for the decision variables (optional). Defaults to torch.float32.
        :param device: The device on which computations should run (optional). Defaults to PyTorch's default device.
        """
        super().__init__()
        self.pop_size = pop_size
        self.n_objs = n_objs
        if device is None:
            device = torch.get_default_device()
        # check
        assert lb.shape == ub.shape and lb.ndim == 1 and ub.ndim == 1
        assert lb.dtype == ub.dtype and lb.device == ub.device
        self.dim = lb.shape[0]
        # write to self
        self.lb = lb.to(device=device)
        self.ub = ub.to(device=device)
        self.selection = selection_op
        self.mutation = mutation_op
        self.crossover = crossover_op
        if self.selection is None:
            self.selection = tournament_selection_multifit
        if self.mutation is None:
            self.mutation = polynomial_mutation
        if self.crossover is None:
            self.crossover = simulated_binary
        if data_type == torch.bool:
            # Binary encoding: threshold uniform noise at 0.5.
            population = torch.rand(self.pop_size, self.dim, device=device)
            population = population > 0.5
        else:
            # Real encoding: uniform samples scaled into [lb, ub].
            length = ub - lb
            population = torch.rand(self.pop_size, self.dim, device=device)
            population = length * population + lb
        self.pop = Mutable(population)
        self.fit = Mutable(torch.full((self.pop_size, self.n_objs), torch.inf, device=device))
        # NOTE(review): rank starts as a float tensor (inf) but is later replaced
        # by non_dominate_rank's output in init_step — confirm Mutable tolerates
        # the dtype change.
        self.rank = Mutable(torch.full((self.pop_size,), torch.inf, device=device))
        # NOTE(review): uniform_sampling is called without an explicit device —
        # confirm the reference points land on `device`.
        self.ref = uniform_sampling(self.pop_size, self.n_objs)[0]

    def init_step(self) -> None:
        """
        Perform the initialization step of the workflow.

        Calls the `init_step` of the algorithm if overwritten; otherwise, its `step` method will be invoked.
        """
        self.fit = self.evaluate(self.pop)
        self.rank = non_dominate_rank(self.fit)

    def step(self) -> None:
        """Perform the optimization step of the workflow."""
        # --- Variation: select parents, crossover, mutate, clamp to bounds, evaluate.
        mating_pool = self.selection(self.pop_size, [self.rank])
        crossovered = self.crossover(self.pop[mating_pool])
        offspring = self.mutation(crossovered, self.lb, self.ub)
        offspring = clamp(offspring, self.lb, self.ub)
        off_fit = self.evaluate(offspring)
        # --- Merge parents and offspring, then shuffle to break ties randomly.
        merge_pop = torch.cat([self.pop, offspring], dim=0)
        merge_fit = torch.cat([self.fit, off_fit], dim=0)
        # NOTE(review): randperm is created on the default device — confirm it
        # matches self.pop.device when running on an accelerator.
        shuffled_idx = torch.randperm(merge_pop.shape[0])
        merge_pop = merge_pop[shuffled_idx]
        merge_fit = merge_fit[shuffled_idx]
        rank = non_dominate_rank(merge_fit)
        # worst_rank is the rank of the (pop_size+1)-th best individual: fronts
        # strictly better than it are kept whole; the front AT worst_rank is the
        # partial front to be niched below.
        worst_rank = torch.topk(rank, self.pop_size + 1, largest=False)[0][-1]
        candi_idx = torch.where(rank <= worst_rank)[0]
        merge_pop = merge_pop[candi_idx]
        merge_fit = merge_fit[candi_idx]
        rank = rank[candi_idx]
        device = self.pop.device
        # --- Normalize objectives (NSGA-III adaptive normalization).
        ideal_point = torch.min(merge_fit, dim=0)[0]
        norm_fit = merge_fit - ideal_point
        # Slightly perturbed identity gives one (near-)axis weight vector per objective.
        weight = torch.eye(self.n_objs, device=device) + 1e-6
        ex_idx = vmap_get_extreme(norm_fit, weight)
        extreme = norm_fit[ex_idx]
        if torch.linalg.matrix_rank(extreme) == self.n_objs:
            # Extreme points span the objective space: intercepts come from the
            # hyperplane through them (solve E @ h = 1, intercepts = 1/h).
            hyperplane = torch.linalg.solve(extreme, torch.ones(self.n_objs, device=device))
            intercepts = 1.0 / hyperplane
        else:
            # Degenerate case: fall back to per-objective maxima.
            intercepts = torch.max(norm_fit, dim=0).values
        norm_fit = norm_fit / intercepts.unsqueeze(0)
        # Shuffle reference points so ties in niche counts are broken randomly.
        shuffled_idx = torch.randperm(self.ref.shape[0])
        ref = self.ref[shuffled_idx]
        # Calculate distances by cosine similarity
        distances = self._compute_distances(norm_fit, ref)
        # Associate each solution with its nearest reference point
        group_dist, group_id = torch.min(distances, dim=1)
        # count the number of individuals for each group id
        # rho: niche counts from the already-selected (better-ranked) fronts.
        selected_group_id = group_id[rank < worst_rank]
        rho = torch.bincount(selected_group_id, minlength=ref.shape[0]).to(torch.int32)
        selected_num = torch.sum(rho, dtype=torch.int32)
        # rho_last: niche counts among the partial (worst_rank) front candidates.
        candi_group_id = group_id[rank == worst_rank]
        rho_last = torch.bincount(candi_group_id, minlength=ref.shape[0]).to(torch.int32)
        # Sentinel strictly larger than any valid index/count used below.
        upper_bound = torch.tensor(
            merge_pop.shape[0] + merge_pop.shape[1] + merge_fit.shape[1] + 1, dtype=torch.int32, device=device
        )
        # Disable reference points with no remaining candidates.
        rho = torch.where(rho_last == 0, upper_bound, rho)
        # Mask group ids of non-candidates so they never match a reference row.
        group_id = torch.where(rank == worst_rank, group_id, upper_bound).to(torch.int32)
        row_indices = torch.arange(ref.shape[0], device=device).to(torch.int32)
        # first selection stage
        # For every empty niche (rho == 0), pick its closest candidate in one
        # vectorized sweep and promote it (rank -> worst_rank - 1 marks "kept").
        rho_level = 0
        _selected_ref = rho == rho_level
        selected_ref = torch.where(_selected_ref, row_indices, upper_bound)
        candi_idx = vmap_select_from_index_by_min(group_id, group_dist, selected_ref)
        rank[candi_idx[_selected_ref]] = worst_rank - 1
        rho_last = torch.where(_selected_ref, rho_last - 1, rho_last)
        rho = torch.where(_selected_ref, rho_level + 1, rho)
        rho = torch.where(rho_last == 0, upper_bound, rho)
        selected_num += torch.sum(_selected_ref)
        # second selection stage
        # Remove already-taken candidates, then precompute, per reference point,
        # the front-packed table of its remaining candidates' indices.
        group_id[candi_idx[_selected_ref]] = upper_bound
        bool_ref_candidates = row_indices[:, None] == group_id[None, :]
        ref_candidates = vmap_get_table_row(bool_ref_candidates, upper_bound)
        ref_cand_idx = torch.zeros_like(rho)
        # Repeatedly promote one candidate from every least-crowded niche until
        # pop_size individuals are kept (may overshoot within the last sweep).
        while selected_num < self.pop_size:
            rho_level = torch.min(rho)
            _selected_ref = rho == rho_level
            candi_idx = ref_candidates[row_indices, ref_cand_idx]
            rank[candi_idx[_selected_ref]] = worst_rank - 1
            ref_cand_idx = torch.where(_selected_ref, ref_cand_idx + 1, ref_cand_idx)
            rho_last = torch.where(_selected_ref, rho_last - 1, rho_last)
            rho = torch.where(_selected_ref, rho_level + 1, rho)
            rho = torch.where(rho_last == 0, upper_bound, rho)
            selected_num += torch.sum(_selected_ref)
        # truncate to pop_size
        # Undo the overshoot: demote `dif` of the last sweep's promotions.
        dif = selected_num - self.pop_size
        candi_idx = torch.where(_selected_ref, candi_idx, upper_bound)
        sorted_index = torch.sort(candi_idx, stable=False)[0]
        rank[sorted_index[:dif]] = worst_rank
        # get final pop and fit
        self.pop = merge_pop[rank < worst_rank]
        self.fit = merge_fit[rank < worst_rank]
        self.rank = rank[rank < worst_rank]

    def _get_extreme(self, norm_fit: torch.Tensor, w: torch.Tensor):
        # Per-axis extreme point via the achievement scalarizing function.
        # NOTE(review): appears unused within this file — `step` uses the
        # module-level vmap_get_extreme instead; confirm before removing.
        return torch.argmin(torch.max(norm_fit / w.unsqueeze(0), dim=1).values)

    def _compute_distances(self, fit: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Perpendicular distance from each solution to each reference direction.

        :param fit: 2D tensor of (normalized) objective vectors, one row per solution.
        :param ref: 2D tensor of reference points, one row per reference direction.
        :return: 2D tensor of distances, shape (solutions, reference points).
        """
        # Normalize solutions and reference points to unit vectors
        fit_magnitude = torch.norm(fit, dim=1, keepdim=True).clamp_min(1e-10)
        fit_norm = fit / fit_magnitude
        ref_norm = ref / torch.norm(ref, dim=1, keepdim=True).clamp_min(1e-10)
        # Compute cosine similarity (dot product of normalized vectors)
        cosine_sim = torch.matmul(fit_norm, ref_norm.T)
        # Compute the angular distance component (sqrt(1 - cosine_similarity^2))
        angular_distance = torch.sqrt((1 - cosine_sim**2).clamp_min(1e-10))
        # Compute the final distance by multiplying magnitude with angular distance
        distances = fit_magnitude * angular_distance
        return distances