Working with extended applications

Working with extended applications in EvoX is easy.

# install evox, skip it if you have already installed evox
try:
    import evox
except ImportError:
    !pip install --disable-pip-version-check --upgrade -q evox brax
    import evox
from evox import algorithms, problems, workflows, monitors, utils, use_state

import jax.numpy as jnp
from jax import jit, random
from jax.tree_util import tree_map
from flax import linen as nn

from IPython.display import HTML, display

Neuroevolution Tasks

Here we focus on neuroevolution tasks, where one needs to evolve a neural network that suits a certain task.

Brax

To begin with, we will use Brax, a GPU-accelerated physics engine that is also written in JAX. Because both libraries are built on JAX, running EvoX with Brax is quite easy.

We will demonstrate this using the “swimmer” environment in Brax.

First, we need to decide how to evolve the neural network. In this case, we use a fixed-size ANN and evolve only its weights.

# construct an ANN using flax.
# "swimmer" environment has 8 observations and 2 actions
# and the actions are in (-1.0, 1.0)
class SwimmerPolicy(nn.Module):
    """A simple model for Swimmer"""

    @nn.compact
    def __call__(self, x):
        x = x.astype(jnp.float32)
        x = x.reshape(-1)
        x = nn.Dense(32)(x)
        x = nn.tanh(x)
        x = nn.Dense(32)(x)
        x = nn.tanh(x)
        x = nn.Dense(2)(x)
        x = nn.tanh(x)

        return x

model = SwimmerPolicy()
weights = model.init(random.PRNGKey(42), jnp.zeros((8, )))
print(tree_map(lambda x: x.shape, weights)) # print the structure of the weights
{'params': {'Dense_0': {'bias': (32,), 'kernel': (8, 32)}, 'Dense_1': {'bias': (32,), 'kernel': (32, 32)}, 'Dense_2': {'bias': (2,), 'kernel': (32, 2)}}}

However, if we inspect the weights of this network, we see that they form a nested group of parameter sets, and EC algorithms cannot work directly with data in this format.

Thankfully, EvoX provides some useful utilities to help us bridge the gap. In this case, TreeAndVector converts a tree-like structure into a vector and back.

adapter = utils.TreeAndVector(weights)

Now, adapter can help us convert the data back and forth.

  • to_vector can convert a tree into a vector.

  • to_tree can convert a vector back to a tree.

There are also batched versions of these conversions.

  • batched_to_vector can convert a batch of trees into a batch of vectors.

  • batched_to_tree can convert a batch of vectors into a batch of trees.

vector_form_weights = adapter.to_vector(weights)
print(vector_form_weights.shape) # it's a single vector!
(1410,)
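
The length 1410 is simply the total number of entries across the six parameter arrays printed earlier. The following dependency-free sketch checks that arithmetic and illustrates the idea behind the conversion: each leaf of the tree is assigned one contiguous slice of the flat vector. The helpers leaf_items, flatten_tree, and unflatten_vector are hypothetical stand-ins written for this illustration, not EvoX functions. Leaves are kept as flat lists rather than shaped arrays, and the leaf ordering that TreeAndVector actually uses may differ.

```python
from math import prod

# Shapes copied from the printed structure of the weights.
shapes = {
    "Dense_0": {"bias": (32,), "kernel": (8, 32)},
    "Dense_1": {"bias": (32,), "kernel": (32, 32)},
    "Dense_2": {"bias": (2,), "kernel": (32, 2)},
}

def leaf_items(tree, prefix=()):
    """Yield (path, leaf) pairs in a fixed, sorted-key order."""
    for key in sorted(tree):
        value = tree[key]
        if isinstance(value, dict):
            yield from leaf_items(value, prefix + (key,))
        else:
            yield prefix + (key,), value

# Assign each leaf a contiguous [start, stop) slice of the flat vector.
layout = {}
offset = 0
for path, shape in leaf_items(shapes):
    size = prod(shape)
    layout[path] = (offset, offset + size)
    offset += size
total = offset
print(total)  # 1410

def unflatten_vector(vector):
    """Rebuild the nested dict; each leaf is the matching slice of `vector`."""
    tree = {}
    for path, (start, stop) in layout.items():
        node = tree
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = vector[start:stop]
    return tree

def flatten_tree(tree):
    """Concatenate all leaves in the same order used to build `layout`."""
    return [x for _, leaf in leaf_items(tree) for x in leaf]

vector = [float(i) for i in range(total)]
round_trip = flatten_tree(unflatten_vector(vector))
print(round_trip == vector)  # True
```

The round trip is lossless because both directions walk the leaves in the same order.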

Now we can create an algorithm object.

# we wish the weights to be in the range [-10, 10]
lower_bound = jnp.full_like(vector_form_weights, -10.0)
upper_bound = jnp.full_like(vector_form_weights, 10.0)

# You can also use any other algorithms
algorithm = algorithms.PSO(
    lb=lower_bound,
    ub=upper_bound,
    pop_size=1024, # don't worry, it's fast
    mean=vector_form_weights, # initialize the population around the current weights
    stdev=0.1, # and with a small gaussian noise
)
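
For intuition about what lb, ub, mean, and stdev control, here is a minimal pure-Python sketch of a textbook PSO: candidates are sampled from a Gaussian around the mean and clipped to the bounds, and each particle then moves toward its own best position and the swarm's best position. This is not EvoX's implementation. The function names and the coefficient values w, c1, and c2 are assumptions chosen only for illustration.

```python
import random

def init_population(mean, stdev, lb, ub, pop_size, rng):
    """Sample pop_size candidates around `mean`, clipped to [lb, ub]."""
    return [
        [min(max(rng.gauss(m, stdev), lo), hi) for m, lo, hi in zip(mean, lb, ub)]
        for _ in range(pop_size)
    ]

def pso_step(x, v, pbest, gbest, lb, ub, rng, w=0.7, c1=1.5, c2=1.5):
    """One textbook PSO update for a single particle (illustrative coefficients)."""
    new_v = [
        w * vi + c1 * rng.random() * (pb - xi) + c2 * rng.random() * (gb - xi)
        for xi, vi, pb, gb in zip(x, v, pbest, gbest)
    ]
    # Move the particle, then clip it back into the search bounds.
    new_x = [min(max(xi + vi, lo), hi) for xi, vi, lo, hi in zip(x, new_v, lb, ub)]
    return new_x, new_v

rng = random.Random(0)
dim = 4
mean = [0.0] * dim
lb, ub = [-10.0] * dim, [10.0] * dim

population = init_population(mean, 0.1, lb, ub, pop_size=8, rng=rng)
velocity = [0.0] * dim
# Treat particle 0 as its own personal best and particle 1 as the global best.
new_x, new_v = pso_step(population[0], velocity, population[0], population[1], lb, ub, rng)
print(len(population), len(new_x))  # 8 4
```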

Now create the Brax-based problem. Here max_episode_length is the maximum number of steps in each episode, and num_episodes is the number of episodes run for each evaluation. In this case, each episode lasts at most 1000 steps, and the average reward over 3 episodes is returned as the fitness value.

problem = problems.neuroevolution.Brax(
    env_name="swimmer",
    policy=jit(model.apply),
    max_episode_length=1000,
    num_episodes=3,
    # Vanilla PSO doesn't handle noisy fitness values well, so we disable key
    # rotation: the same policy is always evaluated with the same seed.
    rotate_key=False,
)
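
The way the episode settings turn into a fitness value can be sketched without Brax. In the toy example below, one evaluation runs a fixed number of episodes, each capped at max_episode_length steps, and averages their total rewards. Reusing the same seeds for every evaluation mimics the effect of rotate_key=False: the same policy always receives the same fitness. The 1-D environment and the names run_episode, evaluate, go_home, and stand_still are hypothetical, and Brax's actual seeding mechanism may differ.

```python
import random

def run_episode(policy, seed, max_episode_length):
    """Roll out one episode of a toy 1-D task and return the total reward."""
    rng = random.Random(seed)
    position = rng.uniform(-1.0, 1.0)  # random start, determined by the seed
    total_reward = 0.0
    for _ in range(max_episode_length):
        action = max(-1.0, min(1.0, policy(position)))  # keep actions in [-1, 1]
        position += 0.1 * action
        total_reward -= abs(position)  # reward is higher closer to the origin
    return total_reward

def evaluate(policy, seeds, max_episode_length):
    """Fitness = average total reward over one episode per seed."""
    returns = [run_episode(policy, seed, max_episode_length) for seed in seeds]
    return sum(returns) / len(returns)

seeds = [0, 1, 2]  # three episodes; reusing these seeds keeps the fitness noise-free

def go_home(position):
    return -position  # steer back toward the origin

def stand_still(position):
    return 0.0

fitness_a = evaluate(go_home, seeds, max_episode_length=50)
fitness_b = evaluate(go_home, seeds, max_episode_length=50)
fitness_idle = evaluate(stand_still, seeds, max_episode_length=50)
print(fitness_a == fitness_b)    # True
print(fitness_a > fitness_idle)  # True
```

With fresh seeds on every evaluation, the same policy could score differently from one call to the next, which is the noise that the comment on rotate_key refers to.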

Assemble our workflow and fire it!

Notice the solution_transforms option. It is used to convert the candidate solutions into the tree-like structure that represents a neural network’s weights.

monitor = monitors.EvalMonitor()
workflow = workflows.StdWorkflow(
    algorithm,
    problem,
    monitors=[monitor],
    solution_transforms=[adapter.batched_to_tree],
    opt_direction="max", # we want to maximize the reward; the default is "min", so we need to change it
)
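
To see why the batched variant is the natural transform here, recall that the algorithm hands over the whole population at once, with one flat vector per candidate. The sketch below mimics, in plain Python, what a batched vector-to-tree conversion is expected to produce: a single tree whose leaves carry an extra leading population axis. The names layout, unflatten, and batched_unflatten are hypothetical, and the two-leaf layout is a toy rather than the swimmer network.

```python
# Toy layout: a 2-entry bias followed by a 4-entry kernel (hypothetical).
layout = {"bias": (0, 2), "kernel": (2, 6)}

def unflatten(vector):
    """Split one flat candidate into named leaves."""
    return {name: vector[start:stop] for name, (start, stop) in layout.items()}

def batched_unflatten(population):
    """Convert a list of flat candidates into one tree of stacked leaves."""
    trees = [unflatten(vector) for vector in population]
    return {name: [tree[name] for tree in trees] for name in layout}

# Three candidates with six parameters each.
population = [[float(6 * i + j) for j in range(6)] for i in range(3)]
batched = batched_unflatten(population)

print(len(batched["bias"]), len(batched["bias"][0]))      # 3 2
print(len(batched["kernel"]), len(batched["kernel"][0]))  # 3 4
```

Each leaf now holds one entry per candidate, so a policy evaluated over the leading axis can score the whole population in one pass.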

Run the workflow and see the magic!

Note

The following block will take around 10 mins to run. The time may vary depending on your hardware.

state = workflow.init(random.PRNGKey(123))

# run the workflow for 100 iterations
for i in range(100):
    state = workflow.step(state)
best_weight, state = use_state(monitor.get_best_solution)(state)
# shout out to Brax's team for making the html renderer
html_result = problem.visualize(random.key(0), adapter.to_tree(best_weight))
# you can use display(HTML(html_result)) if your notebook supports it
monitor.plot()

The algorithm is making progress in optimizing the policy; however, the results are still far from ideal. It’s important to remember that this is just a demonstration, and vanilla PSO wasn’t designed for this type of task, so its performance limitations here are expected.