用EvoX解决Brax问题#

EvoX 深入研究使用 Brax 的神经演化。这里我们将展示一个在 EvoX 中解决 Brax 问题的例子。

# install EvoX and Brax, skip it if you have already installed EvoX or Brax
from importlib.util import find_spec
from IPython.display import HTML

if find_spec("evox") is None:
    %pip install evox
if find_spec("brax") is None:
    %pip install brax
# The dependent packages or functions in this example
import torch
import torch.nn as nn

from evox.algorithms import PSO
from evox.problems.neuroevolution.brax import BraxProblem
from evox.utils import ParamsAndVector
from evox.workflows import EvalMonitor, StdWorkflow

使用 EvoX 解决 Neuroevolution 任务#

神经演化是一种优化方法,它将神经网络与演化算法结合起来,以演化神经网络的结构和参数。通过模拟自然选择和遗传机制,神经演化旨在优化神经网络的架构和权重,解决复杂问题,如游戏AI、机器人控制等。

在我们的神经演化任务示例中,需要使用 Brax。因此,如果您想复制此示例,建议安装 Brax。

什么是Brax#

Brax 是一个快速且完全可微分的物理引擎,用于机器人学、人类感知、材料科学、强化学习和其他需要大量模拟的应用的研究和开发。

在这里,我们将演示Brax的“swimmer”环境。有关更多信息,您可以浏览Github of Brax

设计一个神经网络类#

要开始,我们需要决定要构建哪个神经网络。

这里我们将给出一个简单的多层感知器(MLP)类。

# Construct an MLP using PyTorch.
# This MLP has 3 layers.


class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.features = nn.Sequential(nn.Linear(8, 4), nn.Tanh(), nn.Linear(4, 2))

    def forward(self, x):
        x = self.features(x)
        return torch.tanh(x)

初始化模型#

通过SimpleMLP类,我们可以初始化一个MLP模型。

# Make sure that the model is on the same device, better to be on the GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Reset the random seed
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Initialize the MLP model
model = SimpleMLP().to(device)

启动适配器#

一个转换器可以帮助我们将数据在不同形式间来回转换。

  • to_vector 可以将参数字典转换为向量。

  • to_params 可以将一个向量转换回参数字典。

还有批量版本的转换。

adapter = ParamsAndVector(dummy_model=model)

使用适配器,我们可以开始进行这个神经演化任务。

设置运行过程#

初始化一个算法和一个问题#

我们初始化一个PSO 算法,问题是Brax 问题中的“swimmer”环境。

# Set the population size
POP_SIZE = 1024

# Get the bound of the PSO algorithm
model_params = dict(model.named_parameters())
pop_center = adapter.to_vector(model_params)
lower_bound = torch.full_like(pop_center, -5)
upper_bound = torch.full_like(pop_center, 5)

# Initialize the PSO, and you can also use any other algorithms
algorithm = PSO(
    pop_size=POP_SIZE,
    lb=lower_bound,
    ub=upper_bound,
    device=device,
)
algorithm.setup()

# Initialize the Brax problem
problem = BraxProblem(
    policy=model,
    env_name="swimmer",
    max_episode_length=1000,
    num_episodes=3,
    pop_size=POP_SIZE,
    device=device,
)

在这种情况下,我们将为每个 episode 使用 1000 步,并返回 3 个 episode 的平均奖励作为适应度值。

设置一个monitor#

# set an monitor, and it can record the top 3 best fitnesses
monitor = EvalMonitor(
    topk=3,
    device=device,
)
monitor.setup()
EvalMonitor()

启动一个工作流#

# Initiate an workflow
workflow = StdWorkflow(opt_direction="max")
workflow.setup(
    algorithm=algorithm,
    problem=problem,
    solution_transform=adapter,
    monitor=monitor,
    device=device,
)

运行工作流#

运行工作流并见证魔法!

备注

以下代码块大约需要运行20分钟。运行时间可能会因您的硬件而有所不同。

# Set the maximum number of generations
max_generation = 50

# Run the workflow
for i in range(max_generation):
    if i % 10 == 0:
        print(f"Generation {i}")
    workflow.step()

monitor = workflow.get_submodule("monitor")
print(f"Top fitness: {monitor.get_best_fitness()}")
best_params = adapter.to_params(monitor.get_best_solution())
print(f"Best params: {best_params}")
Generation 0
Generation 10
Generation 20
Generation 30
Generation 40
Top fitness: 369.4692077636719
Best params: {'features.0.weight': tensor([[ 2.3992, -1.8511,  4.8109, -4.4597, -1.0910,  1.4677, -4.9631,  5.0000],
        [-5.0000, -2.5050,  2.4442, -3.0992, -0.8043,  3.4015, -5.0000, -4.4697],
        [-4.9733,  3.3274,  2.6283, -1.8122, -4.9979, -5.0000, -4.2314, -1.4714],
        [-4.0897,  5.0000, -5.0000,  4.6735,  5.0000, -5.0000,  5.0000,  1.5789]],
       device='cuda:0'), 'features.0.bias': tensor([ 2.6846, -4.9499,  3.2993,  3.6577], device='cuda:0'), 'features.2.weight': tensor([[-0.5519,  5.0000, -3.8752,  5.0000],
        [-3.2433, -4.9600, -0.1063,  2.1125]], device='cuda:0'), 'features.2.bias': tensor([ 1.4154, -0.5969], device='cuda:0')}
monitor.get_best_fitness()
tensor(369.4692, device='cuda:0')
monitor.plot()