Custom algorithms and problems in EvoX

In this chapter, we introduce how to implement your own algorithms and problems in EvoX.

The Algorithm Class

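The Algorithm class inherits from Stateful. In addition to the methods required by Stateful, you should also implement an ask and a tell method. In total, there are four methods you need to implement (__init__, setup, ask and tell), plus two optional ones (init_ask and init_tell) for handling the initial population.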

| Method | Signature | Usage |
|---|---|---|
| __init__ | (self, ...) | Initialize hyperparameters that are fixed throughout the optimization process, for example, the population size. |
| setup | (self, PRNGKey) -> State | Initialize the mutable state, for example, the momentum. |
| ask | (self, State) -> Array, State | Give a candidate population for evaluation. |
| tell | (self, State, Array) -> State | Receive the fitness of the candidate population and update the algorithm’s state. |
| init_ask (optional) | (self, State) -> Array, State | Give the initial population for evaluation. This population can have a different shape from the one returned by ask. |
| init_tell (optional) | (self, State, Array) -> State | Receive the fitness of the initial population and update the algorithm’s state. |
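To make the split between these methods concrete, here is a minimal, framework-free sketch in plain Python rather than EvoX itself. It assumes an integer seed in place of a JAX PRNGKey and a frozen dataclass (the hypothetical RSState) in place of EvoX’s State; RandomSearch is likewise a hypothetical toy algorithm. What carries over is the shape of the interface: hyperparameters live on self, everything mutable lives in the state, and ask/tell return a new state instead of mutating the old one.

```python
import random
from dataclasses import dataclass, replace
from typing import Sequence, Tuple

Point = Tuple[float, ...]


@dataclass(frozen=True)
class RSState:
    # Hypothetical immutable state, standing in for EvoX's State.
    key: int                     # seed for the next batch of candidates
    best_x: Point                # best solution seen so far
    best_fit: float              # its fitness (lower is better)
    last_pop: Tuple[Point, ...]  # candidates returned by the latest ask


class RandomSearch:
    """Hypothetical toy random search split into __init__ / setup / ask / tell."""

    def __init__(self, pop_size: int, dim: int, lb: float, ub: float):
        # Hyperparameters: fixed for the whole run, so they live on self.
        self.pop_size = pop_size
        self.dim = dim
        self.lb = lb
        self.ub = ub

    def setup(self, key: int) -> RSState:
        # Everything that changes during the run goes into the state.
        return RSState(key=key, best_x=(), best_fit=float("inf"), last_pop=())

    def ask(self, state: RSState) -> Tuple[Tuple[Point, ...], RSState]:
        rng = random.Random(state.key)
        pop = tuple(
            tuple(rng.uniform(self.lb, self.ub) for _ in range(self.dim))
            for _ in range(self.pop_size)
        )
        # Return the candidates and a new state; the input state is not mutated.
        return pop, replace(state, key=state.key + 1, last_pop=pop)

    def tell(self, state: RSState, fitness: Sequence[float]) -> RSState:
        # Fitness is minimized; remember the best candidate from the last ask.
        i = min(range(len(fitness)), key=lambda j: fitness[j])
        if fitness[i] < state.best_fit:
            return replace(state, best_x=state.last_pop[i], best_fit=fitness[i])
        return state


alg = RandomSearch(pop_size=8, dim=2, lb=-1.0, ub=1.0)
state0 = alg.setup(0)
pop, state1 = alg.ask(state0)
fitness = [sum(v * v for v in x) for x in pop]  # stands in for Problem.evaluate
state2 = alg.tell(state1, fitness)
```

Keeping the state immutable is what lets EvoX JIT-compile these methods as pure functions; the sketch only imitates that discipline.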

Migrate from traditional EC libraries

In a traditional EC library, algorithms usually call the objective function internally, which gives the following layout:

Algorithm
|
+--Problem

In EvoX, by contrast, the layout is flat:

Algorithm.ask -- Problem.evaluate -- Algorithm.tell
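The difference between the two layouts is mostly about who owns the loop. The plain-Python sketch below implements the same random search both ways; all names are hypothetical and a dict stands in for the state. Because both versions draw each generation from the same seed, they return the same result, but in the flat version the algorithm never calls the objective itself.

```python
import random


def sphere(x):
    # Objective to minimize: sum of squares, optimum 0 at the origin.
    return sum(v * v for v in x)


def sample(seed, dim, pop_size):
    # Draw pop_size points uniformly from [-1, 1]^dim with the given seed.
    rng = random.Random(seed)
    return [[rng.uniform(-1.0, 1.0) for _ in range(dim)] for _ in range(pop_size)]


# Nested layout: the algorithm owns the loop and calls the objective itself.
def nested_random_search(objective, dim, pop_size, iters, seed=0):
    best = float("inf")
    for it in range(iters):
        pop = sample(seed + it, dim, pop_size)
        best = min(best, min(objective(x) for x in pop))  # problem called inside
    return best


# Flat layout: the algorithm only proposes (ask) and learns (tell);
# an outside driver is the one that calls the problem.
def ask(state, dim, pop_size):
    return sample(state["seed"], dim, pop_size), {**state, "seed": state["seed"] + 1}


def tell(state, fitness):
    return {**state, "best": min(state["best"], min(fitness))}


def flat_random_search(objective, dim, pop_size, iters, seed=0):
    state = {"seed": seed, "best": float("inf")}
    for _ in range(iters):
        pop, state = ask(state, dim, pop_size)  # Algorithm.ask
        fitness = [objective(x) for x in pop]   # Problem.evaluate
        state = tell(state, fitness)            # Algorithm.tell
    return state["best"]
```

In the flat version, swapping in a different problem, or evaluating the whole population in one batched call, requires no change to ask or tell.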

Here is pseudocode for a genetic algorithm:

Set hyperparameters
Generate the initial population
Do
    Generate Offspring
        Selection
        Crossover
        Mutation
    Compute fitness
    Replace the population
Until stopping criterion

And here is how each part of the algorithm maps onto EvoX:

Set hyperparameters # __init__
Generate the initial population # setup
Do
    # ask
    Generate Offspring
        Mating Selection
        Crossover
        Mutation

    # problem.evaluate (not part of the algorithm)
    Compute fitness

    # tell
    Survivor Selection
Until stopping criterion

The Problem Class

The Problem class is quite simple: besides __init__ and setup, the only required method is evaluate.

Migrate from traditional EC libraries

There is one thing to note here: evaluate is a stateful function, meaning it should accept a state and return a new state. So if you are working with numerical benchmark functions, which don’t need to be stateful, you can simply ignore the state, but you still have to use this stateful interface.

| Method | Signature | Usage |
|---|---|---|
| __init__ | (self, ...) | Initialize the settings of the problem. |
| setup | (self, PRNGKey) -> State | Initialize the mutable state of this problem. |
| evaluate | (self, State, Array) -> Array, State | Evaluate the fitness of the given candidate solutions. |
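For a numerical benchmark, following the stateful interface just means threading the state through untouched. Here is a minimal sketch, assuming plain Python lists for the population and an empty dict as the state; Sphere and its shift parameter are hypothetical and not meant to mirror any built-in EvoX problem.

```python
class Sphere:
    """Hypothetical stateless benchmark behind the stateful evaluate interface."""

    def __init__(self, shift=0.0):
        # Fixed problem setting, decided once in __init__.
        self.shift = shift

    def setup(self, key):
        # A pure benchmark has nothing mutable to track, so the state is empty.
        return {}

    def evaluate(self, state, pop):
        # pop is a list of candidate solutions; one fitness value per candidate.
        fitness = [sum((v - self.shift) ** 2 for v in x) for x in pop]
        # The state is ignored, but it must still be returned.
        return fitness, state


problem = Sphere(shift=1.0)
state = problem.setup(0)
fitness, new_state = problem.evaluate(state, [[1.0, 1.0], [0.0, 3.0]])
# fitness == [0.0, 5.0], and new_state is the same object as state
```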

More on the problem’s state

If you are still wondering what the problem’s state is for, here is an explanation.

Unlike numerical benchmark functions, real-life problems are often more complex and may require stateful computation. Here are some examples:

  • When training an ANN, we often have training, validation and testing phases. This implies that the same solution can have different fitness values in different phases, so we clearly can’t model evaluate as a stateless pure function anymore. To implement this mechanism, simply put a value in the state that indicates the current phase.

  • Virtual batch norm is an effective trick, especially when dealing with RL tasks. To implement this mechanism, the problem must be stateful, as it has to remember the initial batch norm parameters from the first run.
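The phase example can be sketched in a few lines of plain Python. Everything here is hypothetical: ToyRegression scores a scalar slope w by its mean squared error on y = w * x, PhaseState holds the current phase, and set_phase is an illustrative helper. The point is that the same candidate receives different fitness values depending only on the state.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class PhaseState:
    phase: str  # "train" or "valid"


class ToyRegression:
    """Hypothetical problem: score a slope w by its MSE on y = w * x."""

    def __init__(self):
        # Fixed data splits; they are settings, not state.
        self.data = {
            "train": [(1.0, 2.0), (2.0, 4.0)],
            "valid": [(1.0, 2.5), (2.0, 4.5)],
        }

    def setup(self, key):
        # Start in the training phase.
        return PhaseState(phase="train")

    def evaluate(self, state, weights):
        # Which split is used depends on the state, not on the candidate.
        pairs = self.data[state.phase]
        fitness = [sum((w * x - y) ** 2 for x, y in pairs) / len(pairs) for w in weights]
        return fitness, state


def set_phase(state, phase):
    # Hypothetical helper: switching phase is just a state update.
    return replace(state, phase=phase)


problem = ToyRegression()
state = problem.setup(0)
train_fit, state = problem.evaluate(state, [2.0])                     # [0.0]
valid_fit, state = problem.evaluate(set_phase(state, "valid"), [2.0])  # [0.25]
```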

Example

Here we give an example of implementing the OneMax problem, along with a genetic algorithm that solves it. The problem itself is straightforward: the fitness is the sum of all the bits in a fixed-length bitstring, i.e. the number of ones. For example, “100111” gives 4 and “000101” gives 2.
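As a quick check of this definition, here is a plain-Python version that works on strings; the onemax helper is hypothetical, whereas the JAX implementation works on boolean arrays. It negates the count by default because EvoX algorithms minimize the fitness.

```python
def onemax(bits: str, neg_fitness: bool = True) -> int:
    # Fitness = number of '1' characters; negated by default because
    # the optimizer minimizes the fitness.
    ones = sum(int(b) for b in bits)
    return -ones if neg_fitness else ones


print(onemax("100111", neg_fitness=False))  # 4
print(onemax("000101", neg_fitness=False))  # 2
print(onemax("100111"))                     # -4
```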

Let’s start by implementing the OneMax problem. In JAX, a bitstring can easily be represented as an array of type bool.

import jax.numpy as jnp
from evox import Problem, jit_class


@jit_class
class OneMax(Problem):
    def __init__(self, neg_fitness=True) -> None:
        super().__init__()
        self.neg_fitness = neg_fitness

    def evaluate(self, state, bitstrings):
        # bitstrings has shape (pop_size, num_bits),
        # so sum along axis 1.
        fitness = jnp.sum(bitstrings, axis=1)
        # In EvoX, algorithms minimize the fitness,
        # so return the negative value.
        if self.neg_fitness:
            fitness = -fitness
        # This problem is stateless, so the state is returned unchanged.
        return fitness, state

Then we implement a genetic algorithm that uses bitflip mutation and one-point crossover.

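import jax.numpy as jnp
from jax import random
from evox import Algorithm, State, jit_class
from evox.operators import crossover, mutation, selection


@jit_class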
class ExampleGA(Algorithm):
    def __init__(self, pop_size, ndim, flip_prob):
        super().__init__()
        # these are hyperparameters that stay fixed.
        self.pop_size = pop_size
        self.ndim = ndim
        # the probability of flipping each bit
        self.flip_prob = flip_prob

    def setup(self, key):
        # initialize the state
        # the state holds mutable data such as the population and the offspring
        # the population is randomly initialized
        # there is no offspring yet, but we allocate a placeholder
        # because JAX requires statically shaped arrays
        key, subkey = random.split(key)
        pop = random.uniform(subkey, (self.pop_size, self.ndim)) < 0.5
        return State(
            pop=pop,
            # bool dtype, matching the bitstrings produced by ask
            offsprings=jnp.empty((self.pop_size * 2, self.ndim), dtype=bool),
            fit=jnp.full((self.pop_size,), jnp.inf),
            key=key,
        )

    def init_ask(self, state):
        # ask for the fitness of the initial population
        return state.pop, state

    def init_tell(self, state, fitness):
        # update the fitness for the initial population
        return state.update(fit=fitness)

    def ask(self, state):
        key, mut_key, x_key = random.split(state.key, 3)
        # here we do mutation and crossover (reproduction)
        # for simplicity, we don't use any mating selection,
        # so the offspring is twice as large as the population
        offsprings = jnp.concatenate(
            (
                mutation.bitflip(mut_key, state.pop, self.flip_prob),
                crossover.one_point(x_key, state.pop),
            ),
            axis=0,
        )
        # return the candidate solution and update the state
        return offsprings, state.update(offsprings=offsprings, key=key)

    def tell(self, state, fitness):
        # here we do survivor selection: keep the best pop_size individuals
        merged_pop = jnp.concatenate([state.pop, state.offsprings])
        merged_fit = jnp.concatenate([state.fit, fitness])
        new_pop, new_fit = selection.topk_fit(merged_pop, merged_fit, self.pop_size)
        # replace the old population
        return state.update(pop=new_pop, fit=new_fit)
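The operators mutation.bitflip, crossover.one_point and selection.topk_fit are used as black boxes above. As a rough guide to what they compute, here is a plain-Python sketch of textbook bitflip mutation, one-point crossover and top-k survivor selection. The signatures, and the choice to pair consecutive parents in one_point, are assumptions for illustration, not EvoX’s implementation.

```python
import random
from typing import List, Sequence, Tuple

Bits = List[bool]


def bitflip(rng: random.Random, pop: Sequence[Bits], flip_prob: float) -> List[Bits]:
    # Flip every bit independently with probability flip_prob.
    return [[(not b) if rng.random() < flip_prob else b for b in ind] for ind in pop]


def one_point(rng: random.Random, pop: Sequence[Bits]) -> List[Bits]:
    # Pair consecutive parents (assumes an even population size); each pair
    # swaps its tails after one random cut point and yields two children.
    children: List[Bits] = []
    for p1, p2 in zip(pop[0::2], pop[1::2]):
        cut = rng.randrange(1, len(p1))
        children.append(p1[:cut] + p2[cut:])
        children.append(p2[:cut] + p1[cut:])
    return children


def topk_fit(pop: Sequence, fit: Sequence[float], k: int) -> Tuple[list, list]:
    # Keep the k individuals with the lowest (best) fitness, best first.
    order = sorted(range(len(fit)), key=lambda i: fit[i])[:k]
    return [pop[i] for i in order], [fit[i] for i in order]
```

Both variation operators return as many individuals as they receive, which is why concatenating their outputs in ask yields an offspring array twice the population size.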

Now, you can assemble a workflow and run it.
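In EvoX, the workflow is what drives the algorithm and the problem. To make the calling order concrete without depending on a particular workflow API, here is a framework-free sketch: a hypothetical run driver, a plain-Python PyOneMax, and a simplified PyMutationGA that keeps ExampleGA’s structure but uses bitflip mutation only. It assumes the first generation goes through init_ask and init_tell, and that later generations use ask and tell.

```python
import random


class PyOneMax:
    # Hypothetical plain-Python OneMax; the state is empty and passed through.
    def setup(self, key):
        return {}

    def evaluate(self, state, pop):
        # Negated count of ones, so smaller is better.
        return [-sum(ind) for ind in pop], state


class PyMutationGA:
    # Hypothetical simplified GA: bitflip mutation plus (mu + lambda) selection.
    def __init__(self, pop_size, ndim, flip_prob):
        self.pop_size, self.ndim, self.flip_prob = pop_size, ndim, flip_prob

    def setup(self, key):
        rng = random.Random(key)
        pop = [[rng.random() < 0.5 for _ in range(self.ndim)] for _ in range(self.pop_size)]
        return {"pop": pop, "fit": None, "offsprings": None, "key": key + 1}

    def init_ask(self, state):
        return state["pop"], state

    def init_tell(self, state, fitness):
        return {**state, "fit": fitness}

    def ask(self, state):
        rng = random.Random(state["key"])
        offsprings = [
            [(not b) if rng.random() < self.flip_prob else b for b in ind]
            for ind in state["pop"]
        ]
        return offsprings, {**state, "offsprings": offsprings, "key": state["key"] + 1}

    def tell(self, state, fitness):
        # Merge parents and offspring, then keep the pop_size best.
        merged_pop = state["pop"] + state["offsprings"]
        merged_fit = state["fit"] + fitness
        order = sorted(range(len(merged_fit)), key=lambda i: merged_fit[i])[: self.pop_size]
        return {**state, "pop": [merged_pop[i] for i in order],
                "fit": [merged_fit[i] for i in order]}


def run(algorithm, problem, key, steps):
    # Hypothetical driver that mirrors the flat ask/evaluate/tell layout.
    alg_state, prob_state = algorithm.setup(key), problem.setup(key)
    pop, alg_state = algorithm.init_ask(alg_state)
    fit, prob_state = problem.evaluate(prob_state, pop)
    alg_state = algorithm.init_tell(alg_state, fit)
    for _ in range(steps):
        pop, alg_state = algorithm.ask(alg_state)
        fit, prob_state = problem.evaluate(prob_state, pop)
        alg_state = algorithm.tell(alg_state, fit)
    return alg_state


state = run(PyMutationGA(pop_size=20, ndim=16, flip_prob=1 / 16), PyOneMax(), key=0, steps=200)
print(min(state["fit"]))  # best (negated) OneMax fitness found
```

Because selection keeps parents alongside offspring, the best fitness never gets worse from one generation to the next.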