# Source code for evox.algorithms.so.es_variants.des

import torch
import torch.nn.functional as F

from evox.core import Algorithm, Mutable, Parameter


class DES(Algorithm):
    """The implementation of the DES algorithm.

    Reference:
    Discovering Evolution Strategies via Meta-Black-Box Optimization
    (https://arxiv.org/abs/2211.11260)

    This code has been inspired by or utilizes the algorithmic implementation from evosax.
    More information about evosax can be found at the following URL:
    GitHub Link: https://github.com/RobertTLange/evosax
    """

    def __init__(
        self,
        pop_size: int,
        center_init: torch.Tensor,
        temperature: float = 12.5,
        sigma_init: float = 0.1,
        device: torch.device | None = None,
    ):
        """Initialize the DES algorithm with the given parameters.

        :param pop_size: The size of the population. Must be greater than 1.
        :param center_init: The initial center of the population. Must be a 1D tensor.
        :param temperature: The temperature parameter for the softmax. Defaults to 12.5.
        :param sigma_init: The initial standard deviation of the noise. Defaults to 0.1.
        :param device: The device to use for the tensors. Defaults to None.

        :raises AssertionError: If ``pop_size <= 1`` or ``center_init`` is not 1D.
        """
        super().__init__()
        # ``pop_size - 1`` is used as a divisor when building the ranks below.
        assert pop_size > 1
        # The per-dimension sigma and the sampling in ``step`` rely on a flat center.
        assert center_init.ndim == 1, "center_init must be a 1D tensor"
        dim = center_init.shape[0]
        # set hyperparameters
        self.temperature = Parameter(temperature, device=device)
        self.sigma_init = Parameter(sigma_init, device=device)
        self.lrate_mean = Parameter(1.0, device=device)
        self.lrate_sigma = Parameter(0.1, device=device)
        # set value
        # Evenly spaced values from -0.5 (first rank) to 0.5 (last rank).
        ranks = torch.arange(pop_size, device=device) / (pop_size - 1) - 0.5
        self.dim = dim
        self.ranks = ranks
        self.pop_size = pop_size
        # setup
        center_init = center_init.to(device=device)
        self.center = Mutable(center_init)
        # One standard deviation per dimension, all starting at ``sigma_init``.
        self.sigma = Mutable(sigma_init * torch.ones(self.dim, device=device))

    def step(self):
        """Step the DES algorithm by sampling the population, evaluating the fitness, and updating the center."""
        device = self.center.device

        # Sample candidates around the current center with per-dimension scale.
        noise = torch.randn(self.pop_size, self.dim, device=device)
        population = self.center + noise * self.sigma

        # ``argsort`` is ascending, so row 0 holds the lowest-fitness candidate.
        # NOTE(review): assumes ``evaluate`` returns one fitness value per row -- confirm.
        fitness = self.evaluate(population)
        population = population[fitness.argsort()]

        # Weights decrease monotonically with rank, so lower fitness gets more weight.
        weight = F.softmax(-20 * F.sigmoid(self.temperature * self.ranks), dim=0)
        # Shape (pop_size, 1): broadcasting against (pop_size, dim) yields the same
        # result as tiling the weights, without materializing the full matrix.
        weight = weight[:, None]

        weight_mean = (weight * population).sum(dim=0)
        # Weighted spread is measured around the current (pre-update) center;
        # the 1e-6 keeps the square root away from zero.
        weight_sigma = torch.sqrt((weight * (population - self.center) ** 2).sum(dim=0) + 1e-6)

        # Move center and sigma toward their weighted estimates by the learning rates.
        center = self.center + self.lrate_mean * (weight_mean - self.center)
        sigma = self.sigma + self.lrate_sigma * (weight_sigma - self.sigma)
        self.center = center
        self.sigma = sigma

    def record_step(self):
        """Return the current ``center`` and ``sigma`` in a dictionary."""
        return {
            "center": self.center,
            "sigma": self.sigma,
        }