"""Module ``evox.core.algorithm``: the base ``Algorithm`` class and init-ask/tell helpers."""

import types
from typing import Tuple

import jax
import jax.numpy as jnp

from .module import *
from .state import State


class Algorithm(Stateful):
    """Base class for all algorithms.

    An algorithm is driven through an ask/tell loop: :meth:`ask` proposes
    candidate solutions and :meth:`tell` receives their fitness. Both methods
    take the current ``State`` and return a new ``State`` instead of mutating
    the algorithm object.

    Subclasses may additionally define ``init_ask`` / ``init_tell`` for a first
    generation that must be handled differently; use :func:`has_init_ask` and
    :func:`has_init_tell` to detect them.
    """

    def ask(self, state: State) -> Tuple[jax.Array, State]:
        """Ask the algorithm for points to explore.

        This base implementation is a placeholder: it returns an empty array
        and an empty ``State``. Concrete algorithms override it.

        Parameters
        ----------
        state
            The state of this algorithm.

        Returns
        -------
        population
            The candidate solutions.
        state
            The new state of the algorithm.
        """
        # Placeholder result (empty population, empty state) rather than an
        # exception; this is the existing base-class behaviour.
        return jnp.zeros(0), State()

    def tell(self, state: State, fitness: jax.Array) -> State:
        """Tell the algorithm the fitness of the points it proposed.

        This base implementation is a placeholder that returns an empty
        ``State``. Concrete algorithms override it.

        Parameters
        ----------
        state
            The state of this algorithm.
        fitness
            The fitness of the candidate solutions returned by :meth:`ask`.

        Returns
        -------
        state
            The new state of the algorithm.
        """
        return State()
def has_init_ask(algorithm) -> bool:
    """Return whether ``algorithm`` provides a callable ``init_ask`` method.

    ``init_ask`` is an optional hook with the signature
    ``init_ask(self, state: State) -> Tuple[jax.Array, State]`` that asks the
    algorithm for its initial population. Override it when the initial
    population must be produced differently from later ones.

    For example, a Genetic Algorithm evaluates the fitness of an initial
    population of size N, but afterwards only evaluates offspring of size M,
    with N != M. JAX requires functions to return arrays of static shape, so
    this case needs two functions: the regular ``ask`` and ``init_ask``.

    ``init_ask`` returns ``(population, state)``: the candidate solutions and
    the new state of the algorithm.

    Parameters
    ----------
    algorithm
        The object to inspect.

    Returns
    -------
    bool
        ``True`` if ``algorithm.init_ask`` exists and is callable.
    """
    # A single lookup: a missing attribute (AttributeError) yields the None
    # default, which is not callable.
    return callable(getattr(algorithm, "init_ask", None))


def has_init_tell(algorithm) -> bool:
    """Return whether ``algorithm`` provides a callable ``init_tell`` method.

    ``init_tell`` is an optional hook with the signature
    ``init_tell(self, state: State, fitness: jax.Array) -> State``. It tells
    the algorithm the fitness of the initial population and is used in pair
    with ``init_ask``. It returns the new state of the algorithm.

    Parameters
    ----------
    algorithm
        The object to inspect.

    Returns
    -------
    bool
        ``True`` if ``algorithm.init_tell`` exists and is callable.
    """
    return callable(getattr(algorithm, "init_tell", None))