TensorFlow Datasets in EvoX
Use `tensorflow-datasets` to load a machine learning dataset and train a model with EvoX.
from evox import workflows, algorithms, problems
from evox.monitors import EvalMonitor
from evox.utils import TreeAndVector, rank_based_fitness
import jax
import jax.numpy as jnp
from flax import linen as nn
from tqdm.notebook import trange
2024-04-24 11:26:20.868999: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
BATCH_SIZE = 128  # samples per mini-batch drawn from the dataset


class MyNet(nn.Module):
    """A small two-layer MLP classifier for 28x28 grayscale images (e.g. Fashion-MNIST).

    The network outputs per-class probabilities (softmax over 10 classes).
    """

    @nn.compact
    def __call__(self, x):
        # Scale raw pixel values; assumes uint8-range (0-255) inputs.
        x = x / 255.0
        # Flatten every non-batch axis. Using the runtime leading dimension
        # instead of the global BATCH_SIZE lets the model accept any batch
        # size (e.g. a smaller final batch or single-sample inference).
        x = x.reshape(x.shape[0], -1)
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        x = nn.Dense(10)(x)
        return jax.nn.softmax(x)
# Seed one PRNG key and split it: one half initialises the network weights,
# the other is handed to the evolutionary workflow.
key = jax.random.PRNGKey(42)
model_key, workflow_key = jax.random.split(key)

model = MyNet()
# Flax infers the parameter shapes by tracing the model on a dummy all-zeros
# batch shaped (BATCH_SIZE, 28, 28, 1).
params = model.init(
    model_key,
    jnp.zeros((BATCH_SIZE, 28, 28, 1)),
)
# Loss used by the dataset problem: called with (network parameters, batch).
# The layout of each batch dict is documented at
# https://www.tensorflow.org/datasets/catalog/fashion_mnist
@jax.jit
def loss_func(weight, data):
    """Mean squared error between predicted class probabilities and one-hot labels.

    This is a deliberately crude placeholder objective; replace it with your
    own (e.g. cross-entropy) for real training.

    Args:
        weight: Flax parameter pytree forwarded to ``model.apply``.
        data: a batch dict providing ``"image"`` and ``"label"`` entries.

    Returns:
        Scalar mean of the squared differences.
    """
    targets = jax.nn.one_hot(data["label"], 10)
    predictions = model.apply(weight, data["image"])
    difference = predictions - targets
    return jnp.mean(difference ** 2)
# Neuroevolution problem backed by TensorFlow Datasets.
# Downloading and preparing the dataset requires `tensorflow` to be
# installed; once the data is prepared, `tensorflow` is no longer needed.
# NOTE(review): the original note says the data is stored under
# ~/tensorflow-dataset — confirm the actual default data directory.
problem = problems.neuroevolution.TensorflowDataset(
    dataset="fashion_mnist",
    loss_func=loss_func,
    batch_size=BATCH_SIZE,
    # presumably forwarded to tfds' `try_gcs`; False means the pre-built
    # dataset is not fetched from Google Cloud Storage — TODO confirm
    try_gcs=False,
)
WARNING:absl:OpenCV is not installed. We recommend using OpenCV because it is faster according to our benchmarks. Defaulting to PIL to decode images...
# Converts between the Flax parameter pytree and a flat vector; PGPE
# searches over the flat vector, starting from the initial parameters.
adapter = TreeAndVector(params)
center = adapter.to_vector(params)
monitor = EvalMonitor()

# Assemble the optimisation loop: PGPE proposes 256 candidate vectors per
# step, each batch is turned back into parameter trees before evaluation on
# `problem`, and raw fitness values are passed through rank_based_fitness.
workflow = workflows.StdWorkflow(
    algorithm=algorithms.PGPE(
        pop_size=256,
        center_init=center,
        stdev_init=0.1,
        optimizer="adam",
    ),
    problem=problem,
    monitors=[monitor],
    sol_transforms=[adapter.batched_to_tree],
    fit_transforms=[rank_based_fitness],
)

# Initialise the workflow state, then advance it 100 times.
state = workflow.init(workflow_key)
for i in trange(100):
    state = workflow.step(state)

# Query the best fitness recorded by the monitor and plot its history.
best_fitness = monitor.get_best_fitness()
monitor.plot()