Multidevice Algorithm
This guide will show you how to write algorithms that can run on multiple devices (multiple GPUs) in EvoX.
```python
import jax
import jax.numpy as jnp
from evox import dataclass, pytree_field, problems, workflows, monitors, algorithms, use_state
from evox.core.distributed import ShardingType
```
In this example, we consider the following simple setup:
```text
      Node1
        |
   +----+----+
   |         |
  GPU       GPU
```
Here we have a single node with multiple GPUs, which communicate with each other over PCIe or NVLink. When running in a distributed setup like this, we need to decide how to place the data on these GPUs.
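To make the placement decision concrete, the sketch below uses plain JAX sharding APIs (`jax.sharding.Mesh`, `NamedSharding`, `PartitionSpec`), not EvoX's own interface, to show the two basic choices on a single node: replicate an array on every device, or split its first axis across devices. The mesh axis name `"POP"` is our own choice, picked to match the mesh name in EvoX's printout later in this guide; the array sizes are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over all local devices (the GPUs of the single node).
# The axis name "POP" is an illustrative choice.
mesh = Mesh(np.array(jax.devices()), axis_names=("POP",))
n_devices = mesh.devices.size

pop_size, dim = 4 * n_devices, 2  # pop_size must be divisible by n_devices
population = jnp.zeros((pop_size, dim))

# Option 1: replicate, so every device stores the full (pop_size, dim) array.
replicated = jax.device_put(population, NamedSharding(mesh, P()))
# Option 2: shard the first axis, so each device stores pop_size / n_devices rows.
sharded = jax.device_put(population, NamedSharding(mesh, P("POP")))

for name, arr in [("replicated", replicated), ("sharded", sharded)]:
    # Shape of the piece of the array held by each device.
    print(name, [shard.data.shape for shard in arr.addressable_shards])
```

On a two-GPU node this prints two `(8, 2)` shards for the replicated array and two `(4, 2)` shards for the sharded one; on a machine with a single device, both arrays consist of one `(4, 2)` shard.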
Here, we use the vanilla PSO algorithm as an example. In PSO, the local information of a particle (its position, velocity, fitness, and personal best) is updated using only that particle's own data plus the current global best, so each GPU can update its own particles independently. Updating the global information (the global best), on the other hand, requires communication between GPUs, but this can be done efficiently with an all-reduce operation.
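The split between local and global work can be sketched in plain JAX as follows. This is a minimal illustration, not EvoX's PSO implementation: `update_bests` is a hypothetical helper that performs only the personal-best and global-best bookkeeping, and it uses a toy sphere objective instead of a real benchmark. The per-particle part is elementwise along the sharded axis; the `argmin` over all particles is where cross-device communication is needed, which XLA's partitioner handles under `jax.jit` by inserting the required collectives automatically.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("POP",))
by_particle = NamedSharding(mesh, P("POP"))  # split the population axis


@jax.jit
def update_bests(population, fitness, local_best_location, local_best_fitness):
    """Hypothetical helper: personal-best and global-best update only."""
    # Local information: each particle compares only against its own history.
    # These ops are elementwise along the population axis, so each device can
    # update its own shard without communicating.
    improved = fitness < local_best_fitness
    local_best_fitness = jnp.where(improved, fitness, local_best_fitness)
    local_best_location = jnp.where(improved[:, None], population, local_best_location)
    # Global information: a reduction over all particles. With sharded inputs,
    # XLA inserts the cross-device collectives that combine per-device results.
    best = jnp.argmin(local_best_fitness)
    return (local_best_location, local_best_fitness,
            local_best_location[best], local_best_fitness[best])


n_devices = mesh.devices.size
pop_size, dim = 4 * n_devices, 2
population = jax.device_put(
    jax.random.uniform(jax.random.PRNGKey(0), (pop_size, dim), minval=-32.0, maxval=32.0),
    by_particle,
)
fitness = jnp.sum(population ** 2, axis=1)  # toy sphere objective, not Ackley
local_best_location = population
local_best_fitness = jax.device_put(jnp.full((pop_size,), jnp.inf), by_particle)

local_loc, local_fit, global_loc, global_fit = update_bests(
    population, fitness, local_best_location, local_best_fitness
)
print(global_fit, global_loc)
```

Since every personal best starts at infinity, the first call simply copies the current fitness into the personal bests, and the global best is the particle with the lowest fitness.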
The diagram below illustrates the population in PSO. It is a two-dimensional array: one axis is the number of particles (population size), and the other is the problem dimension.
```text
  Problem Dimension
+-------------------+
|                   |
|       GPU 0       |  Population Size
|   All particles   |
|                   |
+-------------------+
```
After sharding it across the population dimension, we have the following:
```text
  Problem Dimension
+-------------------+
|                   |
|       GPU 0       |  Population Size / 2
+-------------------+
|                   |
|       GPU 1       |  Population Size / 2
+-------------------+
```
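Similarly, we can shard the velocity, as well as the other per-particle arrays (the fitness and each particle's local best). With two GPUs, each GPU then stores only half of each of these arrays, halving the memory they occupy per GPU.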
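A quick way to see the memory saving is to place the arrays with a first-axis sharding and count the bytes each device actually holds. The sketch below does this in plain JAX rather than through EvoX; `bytes_per_device` is a hypothetical helper, and the sizes are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("POP",))
n_devices = mesh.devices.size
pop_size, dim = 1000 * n_devices, 2  # illustrative sizes; pop_size divisible by n_devices

first_dim = NamedSharding(mesh, P("POP"))  # shard along the population axis
population = jax.device_put(jnp.zeros((pop_size, dim), jnp.float32), first_dim)
velocity = jax.device_put(jnp.zeros((pop_size, dim), jnp.float32), first_dim)


def bytes_per_device(x):
    """Hypothetical helper: bytes of `x` physically stored on each device."""
    return {shard.device: shard.data.nbytes for shard in x.addressable_shards}


total_bytes = population.nbytes + velocity.nbytes
pop_bytes, vel_bytes = bytes_per_device(population), bytes_per_device(velocity)
for device in pop_bytes:
    print(device, pop_bytes[device] + vel_bytes[device], "of", total_bytes, "bytes")
```

On two GPUs it reports 16000 of 32000 bytes per device, so each GPU holds half of `population` and `velocity`.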
```python
# The only change from the regular PSO state:
# add the sharding metadata to the per-particle fields
@dataclass
class SpecialPSOState:
    population: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)
    velocity: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)
    fitness: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)
    local_best_location: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)
    local_best_fitness: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)
    # no sharding metadata: these fields are replicated on every device
    global_best_location: jax.Array
    global_best_fitness: jax.Array
    key: jax.random.PRNGKey
```
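Judging from the sharding printout later in this guide, fields marked with `ShardingType.SHARED_FIRST_DIM` end up split along the `POP` mesh axis, while unmarked fields are replicated. The sketch below illustrates that idea of metadata-driven placement with the standard-library `dataclasses` module. It is not EvoX's implementation, and `SHARD_FIRST_DIM`, `ToyPSOState`, and `place_state` are hypothetical names.

```python
import dataclasses

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical marker playing the role of ShardingType.SHARED_FIRST_DIM.
SHARD_FIRST_DIM = "shard_first_dim"


@dataclasses.dataclass
class ToyPSOState:  # hypothetical, trimmed-down state
    population: jax.Array = dataclasses.field(metadata={"sharding": SHARD_FIRST_DIM})
    velocity: jax.Array = dataclasses.field(metadata={"sharding": SHARD_FIRST_DIM})
    global_best_location: jax.Array = dataclasses.field()  # no marker: replicated


def place_state(state, mesh, axis_name="POP"):
    """Hypothetical helper: shard marked fields along axis_name, replicate the rest."""
    placed = {}
    for field in dataclasses.fields(state):
        marked = field.metadata.get("sharding") == SHARD_FIRST_DIM
        spec = P(axis_name) if marked else P()
        placed[field.name] = jax.device_put(getattr(state, field.name), NamedSharding(mesh, spec))
    return dataclasses.replace(state, **placed)


mesh = Mesh(np.array(jax.devices()), axis_names=("POP",))
n_devices = mesh.devices.size
state = ToyPSOState(
    population=jnp.zeros((4 * n_devices, 2)),
    velocity=jnp.zeros((4 * n_devices, 2)),
    global_best_location=jnp.zeros((2,)),
)
state = place_state(state, mesh)
print({f.name: getattr(state, f.name).sharding.spec for f in dataclasses.fields(state)})
```

Keeping the placement decision in the field declarations means the update code does not need to change; only the data layout does.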
```python
# Inherit from the base PSO algorithm and override `setup` so that it
# returns SpecialPSOState, which carries the sharding metadata
@dataclass
class PSO(algorithms.PSO):
    def setup(self, key):
        state_key, init_pop_key, init_v_key = jax.random.split(key, 3)
        length = self.ub - self.lb
        # initial positions: uniform in [lb, ub]
        population = jax.random.uniform(
            init_pop_key, shape=(self.pop_size, self.dim)
        )
        population = population * length + self.lb
        # initial velocities: uniform in [-length, length]
        velocity = jax.random.uniform(init_v_key, shape=(self.pop_size, self.dim))
        velocity = velocity * length * 2 - length
        return SpecialPSOState(
            population=population,
            velocity=velocity,
            fitness=jnp.full((self.pop_size,), jnp.inf),
            local_best_location=population,
            local_best_fitness=jnp.full((self.pop_size,), jnp.inf),
            global_best_location=population[0],
            global_best_fitness=jnp.array([jnp.inf]),
            key=state_key,
        )
```
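```python
# 2-D search space bounded by [-32, 32] in each dimension;
# keep pop_size divisible by the number of devices so each GPU gets an equal shard
pso = PSO(
    lb=jnp.full(shape=(2,), fill_value=-32),
    ub=jnp.full(shape=(2,), fill_value=32),
    pop_size=100,
)
ackley = problems.numerical.Ackley()
monitor = monitors.EvalMonitor()
workflow = workflows.StdWorkflow(
    pso,
    ackley,
    monitors=[monitor],
)

key = jax.random.PRNGKey(42)
state = workflow.init(key)
# distribute the state across the available devices according to the sharding metadata
state = workflow.enable_multi_devices(state)
```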
```python
# check if the state is correctly sharded
jax.tree.map(lambda x: x.sharding, state)
```
On a machine with two GPUs (note `Mesh('POP': 2)`), this returns the following. The per-particle fields are partitioned along the `POP` mesh axis (`PartitionSpec('POP',)`), while `generation`, the global best, and the random key are replicated (`PartitionSpec()`):
```text
State(StdWorkflowState(generation=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec(), memory_kind=device), first_step=True), {
  'algorithm': State(SpecialPSOState(
    population=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device),
    velocity=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device),
    fitness=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device),
    local_best_location=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device),
    local_best_fitness=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device),
    global_best_location=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec(), memory_kind=device),
    global_best_fitness=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec(), memory_kind=device),
    key=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec(), memory_kind=device)), {}),
  'monitors0': State(EvalMonitorState(first_step=True, latest_solution=None, latest_fitness=None, topk_solutions=None, topk_fitness=None), {}),
  'problem': State({}, {})})
```
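```python
# run the workflow for 50 steps
for i in range(50):
    state = workflow.step(state)

best_solution, _state = use_state(monitor.get_best_solution)(state)
print(best_solution)
```
The best solution found is close to the global optimum of the Ackley function at the origin:
```text
[ 0.0002041 -0.00019218]
```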