Classic control with Gym
In this notebook, we use EvoX's Gym problem to evolve a small neural-network policy with CMA-ES that solves classic control problems.
# Install EvoX (together with gymnasium and flax) only when it cannot be imported yet;
# if EvoX is already installed, the import succeeds and nothing is installed.
try:
    import evox
except ImportError:
    # `!pip` is IPython/Jupyter shell syntax: this cell only works inside a notebook.
    !pip install --disable-pip-version-check --upgrade -q evox gymnasium flax
    import evox
from evox import workflows, algorithms, problems
from evox.monitors import EvalMonitor
from evox.utils import TreeAndVector
import jax
import jax.numpy as jnp
from flax import linen as nn
def tanh2(x):
    """Squash `x` into the open interval (-2, 2).

    Used as the output transform for Pendulum-v1 below; presumably this matches
    that environment's action bound of +/-2 (TODO confirm against the env spec).
    """
    return nn.tanh(x) * 2


# Policy-head configuration per environment:
#   env name -> (number of network outputs, observation shape, output transform)
# Environments paired with `jnp.argmax` pick the index of the largest output;
# the others squash a single output with tanh (scaled by 2 for Pendulum-v1).
policy_params = {
    "Acrobot-v1": (3, (6,), jnp.argmax),
    "CartPole-v1": (2, (4,), jnp.argmax),
    "MountainCarContinuous-v0": (1, (2,), nn.tanh),
    "MountainCar-v0": (3, (2,), jnp.argmax),
    "Pendulum-v1": (1, (3,), tanh2),
}

# choose a setup: must be one of the keys of `policy_params` above
gym_name = "Pendulum-v1"
# Policy network: a single hidden layer MLP whose output size and output
# transform are taken from `policy_params[gym_name]`.
class ClassicPolicy(nn.Module):
    """A simple model for Classic Control problem"""

    @nn.compact
    def __call__(self, x):
        # (number of outputs, observation shape, output transform) for the chosen env
        num_outputs, _obs_shape, head = policy_params[gym_name]
        # Normalization: observation component 1 is scaled by 10.
        # NOTE(review): the same index is scaled for every environment; confirm
        # this is intended for whichever `gym_name` is selected.
        scaled = x.at[1].multiply(10)
        hidden = nn.relu(nn.Dense(16)(scaled))
        out = nn.Dense(num_outputs)(hidden)
        return head(out)
# Fixed root seed, split into one key for model initialization and one for the workflow.
key = jax.random.PRNGKey(42)
model_key, workflow_key = jax.random.split(key)

# Initialize the policy from a zero-filled dummy observation of the chosen
# environment's observation shape (taken from `policy_params`).
model = ClassicPolicy()
obs_shape = policy_params[gym_name][1]
params = model.init(model_key, jnp.zeros(obs_shape))

# Adapter used to convert between the parameter pytree and a flat vector.
adapter = TreeAndVector(params)
# Number of parallel Gym evaluation workers. Adjust according to your need:
# it should not exceed the number of CPU cores, because each worker reserves one CPU.
num_workers = 16

# Monitor that records evaluation results (queried later for the best fitness).
monitor = EvalMonitor()

problem = problems.neuroevolution.Gym(
    env_name=gym_name,
    policy=jax.jit(model.apply),
    num_workers=num_workers,
    # The controller only coordinates the workers, so it reserves no CPU or GPU.
    controller_options={
        "num_cpus": 0,
        "num_gpus": 0,
    },
    # One CPU per worker. The GPU share is derived from `num_workers` so that all
    # workers together request exactly one GPU. A hard-coded fraction such as 1/16
    # would over-subscribe the GPU as soon as `num_workers` is changed.
    worker_options={"num_cpus": 1, "num_gpus": 1 / num_workers},
    # Presumably: the policy receives one unbatched observation at a time, which
    # matches the unbatched dummy input used for `model.init` (TODO confirm in EvoX docs).
    batch_policy=False,
)
# Flatten the freshly initialized parameter pytree into one vector; CMA-ES
# searches in this flat space, starting around these initial weights.
center = adapter.to_vector(params)

# create a workflow
workflow = workflows.StdWorkflow(
    algorithm=algorithms.CMAES(center_init=center, init_stdev=1, pop_size=64),
    problem=problem,
    # Map the population of flat vectors back to parameter pytrees for the policy.
    sol_transforms=[adapter.batched_to_tree],
    monitors=[monitor],
    # Higher fitness (presumably the episode return) is better.
    opt_direction="max",  # trailing comma was missing here -> SyntaxError
    # The Gym problem dispatches episodes to Ray workers, so it is not jit-compiled.
    jit_problem=False,
)
2023-10-24 15:54:46,501 INFO worker.py:1553 -- Started a local Ray instance.
Now run the workflow. You may see warnings like the following:
CUDA backend failed to initialize: Unable to load CUDA.
This is expected behavior. The controller that manages the group of Gym workers is configured with `"num_gpus": 0`, so it does not use the GPU, and its failure to load CUDA is harmless.
If the program hangs, check whether `num_workers` is larger than the number of CPU cores available on your computer. Each worker reserves one CPU (`"num_cpus": 1`), so workers beyond the core count cannot be scheduled.
# Number of workflow steps to run.
num_steps = 100

# Build the initial workflow state from the dedicated workflow key, then
# thread the state through `num_steps` successive calls to `workflow.step`.
state = workflow.init(workflow_key)
for i in range(num_steps):
    state = workflow.step(state)

# Report the best fitness recorded by the monitor during the run.
best_fitness = monitor.get_best_fitness()
print(best_fitness)
(Controller pid=641434) CUDA backend failed to initialize: Unable to load CUDA. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
-0.114485