Source code for evox.algorithms.so.es_variants.nes

import math

import torch

from evox.core import Algorithm, Mutable, Parameter


[docs] class XNES(Algorithm): """The implementation of the xNES algorithm. Reference: Exponential Natural Evolution Strategies (https://dl.acm.org/doi/abs/10.1145/1830483.1830557) """ def __init__( self, init_mean: torch.Tensor, init_covar: torch.Tensor, pop_size: int | None = None, recombination_weights: torch.Tensor | None = None, learning_rate_mean: float | None = None, learning_rate_var: float | None = None, learning_rate_B: float | None = None, covar_as_cholesky: bool = False, device: torch.device | None = None, ): """Initialize the xNES algorithm with the given parameters. :param init_mean: The initial mean vector of the population. Must be a 1D tensor. :param init_covar: The initial covariance matrix of the population. Must be a 2D tensor. :param pop_size: The size of the population. Defaults to None. :param recombination_weights: The recombination weights of the population. Defaults to None. :param learning_rate_mean: The learning rate for the mean vector. Defaults to None. :param learning_rate_var: The learning rate for the variance vector. Defaults to None. :param learning_rate_B: The learning rate for the B matrix. Defaults to None. :param covar_as_cholesky: Whether to use the covariance matrix as a Cholesky factorization result. Defaults to False. :param device: The device to use for the tensors. Defaults to None. """ super().__init__() dim = init_mean.shape[0] if pop_size is None: pop_size = 4 + math.floor(3 * math.log(self.dim)) assert pop_size > 0 if learning_rate_mean is None: learning_rate_mean = 1 if learning_rate_var is None: learning_rate_var = (9 + 3 * math.log(dim)) / 5 / math.pow(dim, 1.5) if learning_rate_B is None: learning_rate_B = learning_rate_var assert learning_rate_mean > 0 and learning_rate_var > 0 and learning_rate_B > 0 if not covar_as_cholesky: init_covar = torch.linalg.cholesky(init_covar) if recombination_weights is None: recombination_weights = torch.arange(1, pop_size + 1) recombination_weights = torch.clip(math.log(pop_size / 2 + 1) - torch.log(recombination_weights), 0) recombination_weights = recombination_weights / torch.sum(recombination_weights) - 1 / pop_size assert ( recombination_weights[1:] <= recombination_weights[:-1] ).all(), "recombination_weights must be in descending order" # set hyperparameters self.learning_rate_mean = Parameter(learning_rate_mean, device=device) self.learning_rate_var = Parameter(learning_rate_var, device=device) self.learning_rate_B = Parameter(learning_rate_B, device=device) # set value recombination_weights = recombination_weights.to(device=device) self.dim = dim self.pop_size = pop_size self.recombination_weights = recombination_weights # setup init_mean = init_mean.to(device=device) init_covar = init_covar.to(device=device) sigma = torch.pow(torch.prod(torch.diag(init_covar)), 1 / self.dim) self.sigma = Mutable(sigma) self.mean = Mutable(init_mean) self.B = Mutable(init_covar / sigma)
[docs] def step(self): """Run one step of the xNES algorithm. The function will sample a population, evaluate their fitness, and then update the center and covariance of the algorithm using the sampled population. """ pass device = self.mean.device noise = torch.randn(self.pop_size, self.dim, device=device) population = self.mean + self.sigma * (noise @ self.B.T) fitness = self.evaluate(population) order = torch.argsort(fitness) fitness, noise = fitness[order], noise[order] weights = self.recombination_weights Ind = torch.eye(self.dim, device=device) grad_delta = torch.sum(weights[:, None] * noise, dim=0) grad_M = (weights * noise.T) @ noise - torch.sum(weights) * Ind grad_sigma = torch.trace(grad_M) / self.dim grad_B = grad_M - grad_sigma * Ind mean = self.mean + self.learning_rate_mean * self.sigma * self.B @ grad_delta sigma = self.sigma * torch.exp(self.learning_rate_var / 2 * grad_sigma) B = self.B @ torch.linalg.matrix_exp(self.learning_rate_B / 2 * grad_B) self.sigma = sigma self.mean = mean self.B = B
[docs] def record_step(self): return {"mean": self.mean, "sigma": self.sigma, "B": self.B}
[docs] class SeparableNES(Algorithm): """The implementation of the Separable NES algorithm. Reference: Natural Evolution Strategies (https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf) """ def __init__( self, init_mean: torch.Tensor, init_std: torch.Tensor, pop_size: int | None = None, recombination_weights: torch.Tensor | None = None, learning_rate_mean: float | None = None, learning_rate_var: float | None = None, device: torch.device | None = None, ): """Initialize the Separable NES algorithm with the given parameters. :param init_mean: The initial mean vector of the population. Must be a 1D tensor. :param init_std: The initial standard deviation for each dimension. Must be a 1D tensor. :param pop_size: The size of the population. Defaults to None. :param recombination_weights: The recombination weights of the population. Defaults to None. :param learning_rate_mean: The learning rate for the mean vector. Defaults to None. :param learning_rate_var: The learning rate for the variance vector. Defaults to None. :param device: The device to use for the tensors. Defaults to None. """ super().__init__() dim = init_mean.shape[0] assert init_std.shape == (dim,) if pop_size is None: pop_size = 4 + math.floor(3 * math.log(self.dim)) assert pop_size > 0 if learning_rate_mean is None: learning_rate_mean = 1 if learning_rate_var is None: learning_rate_var = (3 + math.log(dim)) / 5 / math.sqrt(dim) assert learning_rate_mean > 0 and learning_rate_var > 0 if recombination_weights is None: recombination_weights = torch.arange(1, pop_size + 1) recombination_weights = torch.clip(math.log(pop_size / 2 + 1) - torch.log(recombination_weights), 0) recombination_weights = recombination_weights / torch.sum(recombination_weights) - 1 / pop_size assert recombination_weights.shape == (pop_size,) # set hyperparameters self.learning_rate_mean = Parameter(learning_rate_mean, device=device) self.learning_rate_var = Parameter(learning_rate_var, device=device) # set value recombination_weights = recombination_weights.to(device=device) self.dim = dim self.pop_size = pop_size self.weight = recombination_weights # setup init_std = init_std.to(device=device) init_mean = init_mean.to(device=device) self.mean = Mutable(init_mean) self.sigma = Mutable(init_std)
[docs] def step(self): """Run one step of the Separable NES algorithm. The function will sample a population, evaluate their fitness, and then update the center and covariance of the algorithm using the sampled population. """ device = self.mean.device zero_mean_pop = torch.randn(self.pop_size, self.dim, device=device) population = self.mean + zero_mean_pop * self.sigma fitness = self.evaluate(population) order = torch.argsort(fitness) fitness, population, zero_mean_pop = fitness[order], population[order], zero_mean_pop[order] weight = torch.tile(self.weight[:, None], (1, self.dim)) grad_μ = torch.sum(weight * zero_mean_pop, dim=0) grad_sigma = torch.sum(weight * (zero_mean_pop * zero_mean_pop - 1), dim=0) mean = self.mean + self.learning_rate_mean * self.sigma * grad_μ sigma = self.sigma * torch.exp(self.learning_rate_var / 2 * grad_sigma) self.mean = mean self.sigma = sigma
[docs] def record_step(self): return {"mean": self.mean, "sigma": self.sigma}