Custom algorithm and problem
In this notebook, we show how to use the `Algorithm` and `Problem` classes to create a custom algorithm and a custom problem.
We will use the one-max problem as an example.
The one-max problem is a simple problem where the goal is to maximize the number of ones in a binary string.
For example, the string 01011 has a fitness of 3.
# install evox, skip it if you have already installed evox
try:
import evox
except ImportError:
!pip install --disable-pip-version-check --upgrade -q evox
import evox
from evox import Algorithm, Problem, State, jit_class, monitors, workflows
from evox.operators import mutation, crossover, selection
from jax import random
import jax.numpy as jnp
@jit_class
class OneMax(Problem):
    """The one-max problem: a bitstring's fitness is its number of ones."""

    def __init__(self) -> None:
        # Nothing to configure; only the base-class initialisation is needed.
        super().__init__()

    def evaluate(self, state, bitstrings):
        """Count the ones in every bitstring of the population.

        Args:
            state: Problem state; returned unchanged.
            bitstrings: Population of shape ``(pop_size, num_bits)``.

        Returns:
            A ``(fitness, state)`` pair, where ``fitness[i]`` is the number of
            ones in row ``i`` of ``bitstrings``.
        """
        ones_per_row = jnp.sum(bitstrings, axis=1)  # reduce over the bit axis
        return ones_per_row, state
@jit_class
class CustomGA(Algorithm):
    """A minimal genetic algorithm for bitstring problems.

    Each generation, ``ask`` stacks a bit-flip-mutated copy of the population
    with one-point-crossover children (no mating selection). ``tell`` then
    keeps ``pop_size`` individuals out of the current population plus those
    offspring, as chosen by ``selection.topk_fit``.

    Args:
        pop_size: Number of individuals kept between generations.
        ndim: Length of each bitstring.
        flip_prob: Probability of flipping each bit during mutation.
    """

    def __init__(self, pop_size, ndim, flip_prob):
        super().__init__()
        # Hyperparameters: these stay fixed for the lifetime of the algorithm.
        self.pop_size = pop_size
        self.ndim = ndim
        # The probability of flipping each bit.
        self.flip_prob = flip_prob

    def setup(self, key):
        """Create the initial state.

        The state holds mutable data: the population, the offspring, their
        fitness, and the PRNG key. The population is randomly initialised.
        No offspring exist yet, but a placeholder is allocated because JAX
        wants every state entry to keep a static shape and dtype across steps.
        """
        key, subkey = random.split(key)
        pop = random.uniform(subkey, (self.pop_size, self.ndim)) < 0.5
        return State(
            pop=pop,
            # Use the same dtype (bool) as the real offspring built in `ask`.
            # A float placeholder would change the state's dtype after the
            # first step, which can force jitted code to retrace.
            offsprings=jnp.empty((self.pop_size * 2, self.ndim), dtype=bool),
            # The initial population is never evaluated (only offspring are
            # returned by `ask`). +inf marks it as the worst fitness, so it is
            # displaced by evaluated offspring in the first `tell`. This
            # presumably relies on `topk_fit` keeping the lowest values; the
            # workflow negates fitness for opt_direction="max". Verify.
            fit=jnp.full((self.pop_size,), jnp.inf),
            key=key,
        )

    def ask(self, state):
        """Produce the offspring to evaluate (mutation + crossover).

        For simplicity no mating selection is used. The mutated population and
        the crossover children are concatenated, so the offspring batch is twice
        as large as the population (matching the placeholder from `setup`).

        Returns:
            The offspring batch, and the state updated with those offspring
            and a fresh PRNG key.
        """
        key, mut_key, x_key = random.split(state.key, 3)
        offsprings = jnp.concatenate(
            (
                mutation.bitflip(mut_key, state.pop, self.flip_prob),
                crossover.one_point(x_key, state.pop),
            ),
            axis=0,
        )
        return offsprings, state.update(offsprings=offsprings, key=key)

    def tell(self, state, fitness):
        """Select the next population from parents plus evaluated offspring.

        Args:
            state: Current state; ``state.offsprings`` are the candidates
                that were just evaluated.
            fitness: Fitness of ``state.offsprings``, in the same order.

        Returns:
            The state with ``pop`` and ``fit`` replaced by the
            ``pop_size`` survivors chosen by ``selection.topk_fit``.
        """
        merged_pop = jnp.concatenate([state.pop, state.offsprings])
        merged_fit = jnp.concatenate([state.fit, fitness])
        new_pop, new_fit = selection.topk_fit(merged_pop, merged_fit, self.pop_size)
        # Replace the old population with the survivors.
        return state.update(pop=new_pop, fit=new_fit)
# Build the three components: the GA, the one-max problem, and a monitor
# that records every evaluated candidate.
algorithm = CustomGA(pop_size=128, ndim=100, flip_prob=0.1)
problem = OneMax()
monitor = monitors.EvalMonitor()

# Combine them in a standard workflow. We want as many ones as possible,
# hence opt_direction="max". The monitor then reports negated fitness
# values (e.g. -93 for a bitstring containing 93 ones).
workflow = workflows.StdWorkflow(
    algorithm,
    problem,
    monitors=[monitor],
    opt_direction="max",
)

# Initialise the workflow state from a fixed seed so the run is reproducible.
key = random.PRNGKey(42)
state = workflow.init(key)

# Evolve for 20 generations.
for _generation in range(20):
    state = workflow.step(state)

monitor.get_best_fitness()
Array(-93, dtype=int32)
# Best bitstring found so far (boolean array; each True bit counts as a one).
monitor.get_best_solution()
Array([ True, True, True, True, True, True, False, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, False, True, True, True, True, True, True,
True, True, True, True, True, True, True, False, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, False, True, True, True, True, True, True,
True], dtype=bool)
# Continue from the current state for another 20 generations.
for _generation in range(20):
    state = workflow.step(state)

# Negated best fitness: -100 means all 100 bits are ones.
monitor.get_best_fitness()
Array(-100, dtype=int32)
# Best bitstring after 40 generations in total.
monitor.get_best_solution()
Array([ True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True], dtype=bool)