Custom algorithm and problem

Custom algorithm and problem

In this notebook, we will show how to use the Algorithm and Problem classes to create a custom algorithm and problem. We will use the one-max problem as an example. The one-max problem is a simple problem where the goal is to maximize the number of ones in a binary string. For example, the string 01011 has a fitness of 3.

# install evox, skip it if you have already installed evox
try:
    import evox
except ImportError:
    !pip install --disable-pip-version-check --upgrade -q evox
    import evox
from evox import Algorithm, Problem, State, jit_class, monitors, workflows
from evox.operators import mutation, crossover, selection
from jax import random
import jax.numpy as jnp
@jit_class
class OneMax(Problem):
    """The one-max problem: a bitstring's fitness is its number of ones."""

    def __init__(self) -> None:
        super().__init__()

    def evaluate(self, state, bitstrings):
        """Score every bitstring in the batch.

        Args:
            state: the problem state; it is returned unchanged.
            bitstrings: array of shape (pop_size, num_bits).

        Returns:
            A ``(fitness, state)`` pair, where ``fitness`` holds one
            count-of-ones per row of ``bitstrings``.
        """
        # Reduce over axis 1 (the bit axis), leaving one value per individual.
        return jnp.sum(bitstrings, axis=1), state
@jit_class
class CustomGA(Algorithm):
    """A minimal genetic algorithm over fixed-length bitstrings.

    Every generation, ``ask`` builds the offspring by concatenating a
    bit-flip-mutated copy of the population with a one-point-crossover copy
    (no mating selection), and ``tell`` merges parents and offspring and
    keeps ``pop_size`` of them via ``selection.topk_fit``.
    """

    def __init__(self, pop_size, ndim, flip_prob):
        """Store the hyperparameters.

        Args:
            pop_size: number of individuals kept in the population.
            ndim: number of bits in each individual.
            flip_prob: the probability of flipping each bit during mutation.
        """
        super().__init__()
        # those are hyperparameters that stay fixed.
        self.pop_size = pop_size
        self.ndim = ndim
        # the probability of flipping each bit
        self.flip_prob = flip_prob

    def setup(self, key):
        """Create the initial state.

        Args:
            key: a JAX PRNG key; it is split so that the state keeps a fresh
                key for later calls.

        Returns:
            A ``State`` holding the random population ``pop``, a placeholder
            ``offsprings`` array, the fitness array ``fit`` and the PRNG
            ``key``.
        """
        # state are mutable data like the population, offsprings
        # the population is randomly initialized.
        key, subkey = random.split(key)
        pop = random.uniform(subkey, (self.pop_size, self.ndim)) < 0.5
        return State(
            pop=pop,
            # we don't have any offspring now, but initialize it as a
            # placeholder because jax want static shaped arrays.
            # The placeholder uses the population's dtype (boolean) so that
            # the state keeps the same shapes *and* dtypes once `ask` stores
            # real offspring; a float placeholder would change dtype after
            # the first step and force jitted code to be traced again.
            offsprings=jnp.empty((self.pop_size * 2, self.ndim), dtype=pop.dtype),
            # NOTE(review): +inf presumably marks the not-yet-evaluated
            # initial population as worst under a minimizing selection --
            # confirm against selection.topk_fit.
            fit=jnp.full((self.pop_size,), jnp.inf),
            key=key,
        )

    def ask(self, state):
        """Generate the candidate solutions for this generation.

        Args:
            state: the current algorithm state.

        Returns:
            An ``(offsprings, state)`` pair: the candidates to evaluate and
            the state updated with those candidates and a fresh PRNG key.
        """
        key, mut_key, x_key = random.split(state.key, 3)
        # here we do mutation and crossover (reproduction)
        # for simplicity, we didn't use any mating selections
        # so the offspring is twice as large as the population
        offsprings = jnp.concatenate(
            (
                mutation.bitflip(mut_key, state.pop, self.flip_prob),
                crossover.one_point(x_key, state.pop),
            ),
            axis=0,
        )
        # return the candidate solution and update the state
        return offsprings, state.update(offsprings=offsprings, key=key)

    def tell(self, state, fitness):
        """Select the next population from parents and offspring.

        Args:
            state: the state returned by the preceding ``ask``.
            fitness: fitness values for ``state.offsprings``, in the same
                order.

        Returns:
            The state with ``pop`` and ``fit`` replaced by the selection.
        """
        # here we do selection
        # parents and offspring are merged in the same order for both arrays,
        # so row i of merged_pop always pairs with entry i of merged_fit.
        merged_pop = jnp.concatenate([state.pop, state.offsprings])
        merged_fit = jnp.concatenate([state.fit, fitness])
        new_pop, new_fit = selection.topk_fit(merged_pop, merged_fit, self.pop_size)
        # replace the old population
        return state.update(pop=new_pop, fit=new_fit)
# the problem to solve, and a monitor to record what gets evaluated
problem = OneMax()
monitor = monitors.EvalMonitor()
# a GA with 128 individuals of 100 bits each, flipping each bit with
# probability 0.1 during mutation
algorithm = CustomGA(pop_size=128, ndim=100, flip_prob=0.1)
# create a workflow that ties the three together; one-max is a
# maximization problem, hence opt_direction="max"
workflow = workflows.StdWorkflow(
    algorithm, problem, monitors=[monitor], opt_direction="max"
)
# init the workflow
# a fixed seed (42) is used to create the PRNG key handed to the workflow
key = random.PRNGKey(42)
state = workflow.init(key)

# run the workflow for 20 iterations
# each step returns a new state, which is fed into the next step
for i in range(20):
    state = workflow.step(state)
# best fitness recorded by the monitor so far.
# NOTE(review): the value shown below is negative (-93 for a bitstring with
# 93 ones); presumably fitness is negated internally because of
# opt_direction="max" -- confirm against the EvoX documentation.
monitor.get_best_fitness()
Array(-93, dtype=int32)
# the best bitstring recorded by the monitor so far (boolean array, shown below)
monitor.get_best_solution()
Array([ True,  True,  True,  True,  True,  True, False,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True, False,  True,  True,  True,  True,
        True,  True,  True,  True, False,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True, False,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True, False,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True, False,  True,  True,  True,  True,
        True,  True, False,  True,  True,  True,  True,  True,  True,
        True], dtype=bool)
# run the workflow for another 20 iterations
# this continues from the state left by the previous loop (40 steps in total)
for i in range(20):
    state = workflow.step(state)
# best fitness recorded by the monitor so far.
# NOTE(review): -100 below corresponds to a bitstring of 100 ones; the sign
# is presumably due to opt_direction="max" -- confirm against the EvoX
# documentation.
monitor.get_best_fitness()
Array(-100, dtype=int32)
# the best bitstring recorded by the monitor so far (boolean array, shown below)
monitor.get_best_solution()
Array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True], dtype=bool)