evox.problems.neuroevolution.brax

Module Contents

Classes

BraxProblem

The Brax problem wrapper.

Functions

Data

API

evox.problems.neuroevolution.brax.__all__

['BraxProblem']

evox.problems.neuroevolution.brax.to_jax_array(x: torch.Tensor) -> jax.Array
evox.problems.neuroevolution.brax.from_jax_array(x: jax.Array, device: Optional[torch.device] = None) -> torch.Tensor
evox.problems.neuroevolution.brax.__brax_data__: Dict[int, Tuple[Callable[[jax.Array], brax.envs.State], Callable[[brax.envs.State, jax.Array], brax.envs.State], Callable[[Dict[str, torch.Tensor], torch.Tensor], Tuple[Dict[str, torch.Tensor], torch.Tensor]], List[str]]]

None

evox.problems.neuroevolution.brax._evaluate_brax_main(env_id: int, pop_size: int, rotate_key: bool, num_episodes: int, max_episode_length: int, key: torch.Tensor, model_state: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]
evox.problems.neuroevolution.brax._evaluate_brax(env_id: int, pop_size: int, rotate_key: bool, num_episodes: int, max_episode_length: int, key: torch.Tensor, model_state: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]
evox.problems.neuroevolution.brax._fake_evaluate_brax(env_id: int, pop_size: int, rotate_key: bool, num_episodes: int, max_episode_length: int, key: torch.Tensor, model_state: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]
evox.problems.neuroevolution.brax._evaluate_brax_vmap_main(batch_size: int, in_dim: List[int], env_id: int, pop_size: int, rotate_key: bool, num_episodes: int, max_episode_length: int, key: torch.Tensor, model_state: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]
evox.problems.neuroevolution.brax._evaluate_brax_vmap(vmap_info: evox.utils.VmapInfo, in_dims: Tuple[int | None | List[int], ...], env_id: int, pop_size: int, rotate_key: bool, num_episodes: int, max_episode_length: int, key: torch.Tensor, model_state: List[torch.Tensor]) -> Tuple[Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor], Tuple[int | None, List[int], int]]
evox.problems.neuroevolution.brax._fake_evaluate_brax_vmap(batch_size: int, in_dim: List[int], env_id: int, pop_size: int, rotate_key: bool, num_episodes: int, max_episode_length: int, key: torch.Tensor, model_state: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]
class evox.problems.neuroevolution.brax.BraxProblem(policy: torch.nn.Module, env_name: str, max_episode_length: int, num_episodes: int, seed: int = None, pop_size: int | None = None, rotate_key: bool = True, reduce_fn: Callable[[torch.Tensor, int], torch.Tensor] = torch.mean, backend: str | None = None, device: torch.device | None = None)

Bases: evox.core.Problem

The Brax problem wrapper.

Initialization

Construct a Brax-based problem. First, define a policy model. Then set the environment name (see <https://github.com/google/brax/tree/main/brax/envs>), the maximum episode length, and the number of episodes to evaluate for each individual. Each individual's policy is run in the environment num_episodes times with different seeds. The per-episode rewards are then reduced to a single fitness value with reduce_fn (the average by default). All individuals share the same set of random keys within an iteration.
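The evaluation rule above can be made concrete with a minimal pure-Python sketch. It is not the evox implementation. It assumes a hypothetical run_episode function that stands in for a Brax rollout, and it uses scalar "parameters". In evox, reduce_fn is typed as Callable[[torch.Tensor, int], torch.Tensor], meaning it receives a tensor and a dimension, like torch.mean. Here it simply receives a list of per-episode rewards.

```python
import random
import statistics
from typing import Callable, List, Sequence


def run_episode(params: float, seed: int, max_episode_length: int) -> float:
    """Hypothetical stand-in for one Brax rollout.

    The total reward depends only on the (scalar) policy parameters and the
    episode seed, so the same (params, seed) pair always gives the same reward.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(max_episode_length):
        total += params + rng.uniform(-0.1, 0.1)
    return total


def evaluate_population(
    population: Sequence[float],
    episode_seeds: Sequence[int],
    max_episode_length: int,
    reduce_fn: Callable[[List[float]], float] = statistics.mean,
) -> List[float]:
    """Roll out every individual once per seed and reduce its rewards.

    All individuals reuse the same episode_seeds (one per episode), mirroring
    the rule that individuals share the same set of random keys.
    """
    fitness = []
    for params in population:
        rewards = [run_episode(params, s, max_episode_length) for s in episode_seeds]
        fitness.append(reduce_fn(rewards))
    return fitness


seeds = [11, 22, 33]  # num_episodes = 3, shared by every individual
mean_fit = evaluate_population([0.0, 1.0, 1.0], seeds, max_episode_length=10)
worst_fit = evaluate_population(
    [0.0, 1.0, 1.0], seeds, max_episode_length=10, reduce_fn=min
)
print(mean_fit, worst_fit)
```

Because the seeds are shared, the two identical individuals receive exactly the same fitness. Swapping the mean for min gives a more pessimistic, worst-episode fitness.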

Parameters:
  • policy – The policy model, whose forward function has the form forward(batched_obs) -> action.

  • env_name – The environment name.

  • max_episode_length – The maximum number of time steps of each episode.

  • num_episodes – The number of episodes to evaluate for each individual.

  • seed – The seed used to create a PRNG key for the Brax environment. When None, a seed is selected randomly. Defaults to None.

  • pop_size – The size of the population to be evaluated. If None, the input is expected to have a population size of 1.

  • rotate_key – Indicates whether to rotate the random key for each iteration (default is True).
    If True, the random key will rotate after each iteration, resulting in non-deterministic and potentially noisy fitness evaluations. This means that identical policy weights may yield different fitness values across iterations.
    If False, the random key remains the same for all iterations, ensuring consistent fitness evaluations.

  • reduce_fn – The function used to reduce the rewards of multiple episodes. Defaults to torch.mean.

  • backend – Brax's backend. If None, the environment's default backend is used. Defaults to None.

  • device – The device to run the computations on. Defaults to the current default device.
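The rotate_key behavior described above can be sketched with Python's random module standing in for JAX-style PRNG keys. The names split_key and KeySchedule are hypothetical and only illustrate the documented semantics. With rotate_key=True, each iteration gets a fresh evaluation key. With rotate_key=False, the same key is reused, so identical weights always see the same episodes.

```python
import random
from typing import Tuple


def split_key(key: int) -> Tuple[int, int]:
    """Hypothetical stand-in for a JAX-style key split: derive two new keys."""
    rng = random.Random(key)
    return rng.getrandbits(32), rng.getrandbits(32)


class KeySchedule:
    """Produces the evaluation key used in each iteration (illustrative only)."""

    def __init__(self, seed: int, rotate_key: bool) -> None:
        self.key = seed
        self.rotate_key = rotate_key

    def next_eval_key(self) -> int:
        if self.rotate_key:
            # Replace the stored key, so the next iteration samples new episodes.
            self.key, eval_key = split_key(self.key)
            return eval_key
        # The stored key never changes: every iteration evaluates the same episodes.
        return split_key(self.key)[1]


fixed = KeySchedule(seed=0, rotate_key=False)
rotating = KeySchedule(seed=0, rotate_key=True)
fixed_keys = [fixed.next_eval_key() for _ in range(3)]
rotating_keys = [rotating.next_eval_key() for _ in range(3)]
print(fixed_keys, rotating_keys)
```

The fixed schedule gives deterministic, comparable fitness values across iterations. The rotating schedule trades that reproducibility for exposure to more episode variations over the course of a run.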

Notice

The initial key is obtained from torch.random.get_rng_state().

Warning

This problem does NOT support the HPO wrapper (problems.hpo_wrapper.HPOProblemWrapper) out of the box, i.e., a workflow containing this problem CANNOT be vmapped. However, by setting pop_size to the product of the inner and outer population sizes, you can still use this problem in an HPO workflow.
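The pop_size workaround amounts to two steps. First, flatten the outer (HPO) and inner populations into a single batch. Second, split the resulting fitness back per outer individual. The sketch below is a hypothetical pure-Python illustration with a stand-in evaluate_flat function. With tensors, the same idea is a reshape to (outer * inner, ...) before evaluation and back to (outer, inner) afterwards.

```python
from typing import Callable, List, Sequence


def evaluate_flat(flat_population: Sequence[float]) -> List[float]:
    """Hypothetical stand-in for BraxProblem.evaluate on pop_size individuals."""
    return [(x - 3.0) ** 2 for x in flat_population]


def evaluate_nested(
    populations: Sequence[Sequence[float]],
    evaluate: Callable[[Sequence[float]], List[float]],
) -> List[List[float]]:
    outer = len(populations)
    inner = len(populations[0])
    if any(len(p) != inner for p in populations):
        raise ValueError("all inner populations must have the same size")
    # Flatten (outer, inner) into one batch of outer * inner individuals;
    # this product is what pop_size must be set to.
    flat = [ind for pop in populations for ind in pop]
    flat_fitness = evaluate(flat)
    # Split the flat fitness back into one list per outer individual.
    return [flat_fitness[i * inner:(i + 1) * inner] for i in range(outer)]


nested = evaluate_nested([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]], evaluate_flat)
print(nested)  # [[4.0, 1.0, 0.0], [0.0, 1.0, 4.0]]
```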

Examples

>>> from evox import problems
>>> problem = problems.neuroevolution.BraxProblem(
...     env_name="swimmer",
...     policy=model,
...     max_episode_length=1000,
...     num_episodes=3,
...     pop_size=100,
...     rotate_key=False,
... )

Here, model is the policy, a torch.nn.Module.

_evaluate_brax_record(model_state: Dict[str, torch.Tensor]) -> Tuple[Dict[str, torch.Tensor], torch.Tensor, List[Any]]
evaluate(pop_params: Dict[str, torch.nn.Parameter]) -> torch.Tensor

Evaluate the final rewards of a population (batch) of model parameters.

Parameters:

pop_params – A dictionary of parameters where each key is a parameter name and each value is a tensor of shape (batch_size, *param_shape) representing the batched parameters of batched models.

Returns:

A tensor of shape (batch_size,) containing the reward of each sample in the population.
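The pop_params shape convention can be illustrated without tensors. This hypothetical sketch uses nested lists in place of torch tensors and made-up parameter names such as fc.weight. With tensors, one would typically torch.stack each parameter across individuals to obtain shape (batch_size, *param_shape).

```python
from typing import Dict, Sequence


def stack_params(individuals: Sequence[Dict[str, list]]) -> Dict[str, list]:
    """Stack per-individual parameter dicts into one batched dict.

    Each value of the result gains a leading batch dimension, mirroring the
    (batch_size, *param_shape) layout that evaluate expects.
    """
    names = individuals[0].keys()
    for ind in individuals:
        if ind.keys() != names:
            raise ValueError("all individuals must share the same parameter names")
    return {name: [ind[name] for ind in individuals] for name in names}


def batch_size_of(pop_params: Dict[str, list]) -> int:
    """Return the shared leading dimension, i.e. how many rewards evaluate returns."""
    sizes = {len(value) for value in pop_params.values()}
    if len(sizes) != 1:
        raise ValueError(f"inconsistent leading dimensions: {sorted(sizes)}")
    return sizes.pop()


# Two individuals of a hypothetical linear policy: 2x3 weight, bias of length 2.
ind_a = {"fc.weight": [[0.0] * 3 for _ in range(2)], "fc.bias": [0.0, 0.0]}
ind_b = {"fc.weight": [[1.0] * 3 for _ in range(2)], "fc.bias": [1.0, 1.0]}
pop_params = stack_params([ind_a, ind_b])
print(batch_size_of(pop_params))  # 2
```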

visualize(weights: Dict[str, torch.nn.Parameter], seed: int = 0, output_type: str = 'HTML', *args, **kwargs) -> str | torch.Tensor

Visualize the brax environment with the given policy and weights.

Parameters:
  • weights – The weights of the policy model, given as a dictionary of parameters.

  • output_type – The output type of the visualization, either "HTML" or "rgb_array". Defaults to "HTML".

Returns:

The visualization output: a string when output_type is "HTML", or a torch.Tensor when output_type is "rgb_array".