Solving a Brax environment using EvoTorch

This notebook demonstrates how the Brax environment named humanoid can be solved using EvoTorch.

EvoTorch provides VecGymNE, a neuroevolution problem type that focuses on solving vectorized environments. If a GPU is available, VecGymNE can use it to boost performance. In this notebook, we use VecGymNE to solve the humanoid task.

For this notebook to work, the libraries JAX and Brax are required. For installing JAX, you might want to look at its official installation instructions. After a successful installation of JAX, Brax can be installed via:

pip install brax

Below, we import the necessary libraries.

from evotorch.algorithms import PGPE
from evotorch.neuroevolution import VecGymNE
from evotorch.logging import StdOutLogger, PicklingLogger

import os
import torch
from torch import nn

We now check if CUDA is available. If it is, we prepare a configuration which will tell VecGymNE to use a single GPU both for the population and for the fitness evaluation operations. If CUDA is not available, we will instead turn to actor-based parallelization on the CPU to boost the performance.

if torch.cuda.is_available():
    # CUDA is available. Here, we prepare GPU-specific settings.

    # We tell XLA (the backend of JAX) to claim only half of the GPU's memory.
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".5"

    # We make only one GPU visible.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    # This is the device on which the population will be stored
    device = "cuda:0"

    # We do not want multi-actor parallelization.
    # For most basic brax tasks, it is enough to use a single GPU both
    # for the population and for the solution evaluations.
    num_actors = 0
else:
    # Since CUDA is not available, the device of the population will be cpu.
    device = "cpu"

    # Use all the CPUs to speed-up the evaluations.
    num_actors = "max"

    # Because we are already using all the CPUs for actor-based parallelization,
    # we tell XLA not to use multiple threads for its operations.
    # (Following the suggestions at
    os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"

    # We also tell OpenBLAS and MKL to use only 1 thread for their operations.
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
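The same branching can be packaged as a small helper, which keeps the settings in one place and makes them easy to inspect. This is only a sketch: the names `select_compute_config` and `apply_env_vars`, as well as the layout of the returned dictionary, are hypothetical and not part of EvoTorch; the values simply mirror the ones used above. Such environment variables are generally expected to be set before JAX initializes its backend.

```python
import os


def select_compute_config(cuda_available: bool) -> dict:
    """Return device, actor count and environment variables for a run.

    Hypothetical helper (not part of EvoTorch); values mirror this notebook.
    """
    if cuda_available:
        # Single GPU for both the population and the evaluations.
        return {
            "device": "cuda:0",
            "num_actors": 0,
            "env_vars": {
                "XLA_PYTHON_CLIENT_MEM_FRACTION": ".5",
                "CUDA_VISIBLE_DEVICES": "0",
            },
        }
    # CPU fallback: one actor per CPU, single-threaded math libraries.
    return {
        "device": "cpu",
        "num_actors": "max",
        "env_vars": {
            "XLA_FLAGS": "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1",
            "OPENBLAS_NUM_THREADS": "1",
            "MKL_NUM_THREADS": "1",
        },
    }


def apply_env_vars(env_vars: dict) -> None:
    """Export the given settings into the current process environment."""
    os.environ.update(env_vars)


config = select_compute_config(cuda_available=False)
print(config["device"], config["num_actors"])
```

In the notebook, one would call `select_compute_config(torch.cuda.is_available())` and pass the resulting `env_vars` to `apply_env_vars` before creating the problem.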

We now define our policy. The policy can be expressed as a string, as an instance of torch.nn.Module, or as a subclass of torch.nn.Module.

# --- A simple linear policy ---
policy = "Linear(obs_length, act_length)"

# --- A feed-forward network ---
# policy = "Linear(obs_length, 64) >> Tanh() >> Linear(64, act_length)"

# --- A feed-forward network with layer normalization ---
# policy = (
#     """
#     Linear(obs_length, 64)
#     >> Tanh()
#     >> LayerNorm(64, elementwise_affine=False)
#     >> Linear(64, act_length)
#     """
# )

# --- A recurrent network with layer normalization ---
# Note: in addition to RNN, LSTM is also supported
# policy = (
#     """
#     RNN(obs_length, 64)
#     >> LayerNorm(64, elementwise_affine=False)
#     >> Linear(64, act_length)
#     """
# )

# --- A manual feed-forward network ---
# class MyManualNetwork(nn.Module):
#     def __init__(self):
#         super().__init__()
#         ...
#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         ...
# policy = MyManualNetwork

# --- A manual recurrent network ---
# class MyManualRecurrentNetwork(nn.Module):
#     def __init__(self):
#         super().__init__()
#         ...
#     def forward(self, x: torch.Tensor, hidden_state = None) -> tuple:
#         ...
#         output_tensor = ...
#         new_hidden_state = ...  # hidden state could be a tensor, or a tuple or dict of tensors
#         return output_tensor, new_hidden_state
# policy = MyManualRecurrentNetwork
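Whichever form is chosen, the evolutionary algorithm works on the flattened parameter vector of the policy, so the size of the network determines the length of each solution. As a rough illustration, the sketch below counts the parameters of the first two string policies above, assuming that each Linear layer has a bias term. The helper names and the observation and action sizes (10 and 3) are hypothetical placeholders; the actual sizes depend on the environment.

```python
def linear_param_count(in_features: int, out_features: int) -> int:
    """Number of parameters of a dense layer: weight matrix plus bias vector."""
    return in_features * out_features + out_features


def mlp_param_count(obs_length: int, hidden: int, act_length: int) -> int:
    """Parameters of Linear(obs, hidden) >> Tanh() >> Linear(hidden, act).

    Tanh has no trainable parameters, so only the two Linear layers count.
    """
    return linear_param_count(obs_length, hidden) + linear_param_count(hidden, act_length)


# Placeholder sizes, chosen only for illustration.
obs_length, act_length = 10, 3

linear_policy_size = linear_param_count(obs_length, act_length)
mlp_policy_size = mlp_param_count(obs_length, 64, act_length)
print(linear_policy_size, mlp_policy_size)
```

Larger policies therefore mean longer solutions and more memory per population, which is worth keeping in mind when choosing between the variants above.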

Below, we instantiate our VecGymNE problem.

ENV_NAME = "brax::humanoid"  # solve the brax task named "humanoid"
# ENV_NAME = "brax::old::humanoid"  # solve the "humanoid" task defined within `brax.v1`

problem = VecGymNE(
    env=ENV_NAME,
    network=policy,
    #
    # Collect observation stats, and use those stats to normalize incoming observations
    observation_normalization=True,
    #
    # In the case of the "humanoid" task, the agent receives an "alive bonus" of 5.0 for each
    # non-terminal state it observes. In this example, we cancel out this fixed amount of
    # alive bonus using the keyword argument `decrease_rewards_by`.
    # The amount of alive bonus changes from task to task (some of them don't have this bonus
    # at all).
    decrease_rewards_by=5.0,
    #
    # As an alternative to giving a fixed amount of alive bonus, we now enable a scheduled
    # alive bonus.
    # From timestep 0 to 400, the agents will receive no alive bonus.
    # From timestep 400 to 700, the agents will receive partial alive bonus.
    # Beginning with timestep 700, the agents will receive full (10.0) alive bonus.
    alive_bonus_schedule=(400, 700, 10.0),
    device=device,
    num_actors=num_actors,
)

problem, problem.solution_length
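To make the alive bonus schedule concrete, the sketch below computes the bonus granted at a given timestep for the tuple (400, 700, 10.0). It assumes that the partial bonus between the two timesteps grows linearly toward the full value. The function name `scheduled_alive_bonus` is hypothetical, and this is an illustration rather than EvoTorch's internal implementation.

```python
def scheduled_alive_bonus(t: int, schedule: tuple) -> float:
    """Alive bonus at timestep t for a schedule (t0, t1, full_bonus).

    Assumption: the bonus ramps up linearly between t0 and t1.
    """
    t0, t1, full_bonus = schedule
    if t < t0:
        return 0.0  # no bonus during the early timesteps
    if t >= t1:
        return float(full_bonus)  # full bonus from t1 onwards
    # Partial bonus: fraction of the ramp that has been covered so far.
    return full_bonus * (t - t0) / (t1 - t0)


schedule = (400, 700, 10.0)
for t in (0, 400, 550, 700, 1000):
    print(t, scheduled_alive_bonus(t, schedule))
```

Delaying the bonus in this way means that an agent is rewarded for staying alive only once it has survived the early part of the episode.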

Note. At the time of writing this (15 June 2023), the arXiv paper of EvoTorch reports results based on the old implementations of the brax tasks (which were the default until brax v0.1.2). In brax version v0.9.0, these old task implementations moved into the namespace brax.v1. If you wish to reproduce the results reported in the arXiv paper of EvoTorch, you might want to specify the environment name as "brax::old::humanoid" (where the substring "old::" causes VecGymNE to instantiate the environment using the namespace brax.v1), so that you will observe scores and execution times compatible with the ones reported in that arXiv paper.

Initialize a PGPE to work on the problem.

Note: If you receive a memory allocation error from the GPU driver, you might want to try again with:

  • a decreased popsize
  • a policy with decreased hidden size and/or number of layers (in case the policy is a neural network)

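RADIUS = 2.25
# Assumed values: MAX_SPEED and the center learning rate follow the usual ClipUp heuristics.
MAX_SPEED = RADIUS / 15.0

# Instantiate a PGPE using the hyperparameters prepared above
searcher = PGPE(
    problem,
    popsize=12000,  # assumed
    radius_init=RADIUS,
    center_learning_rate=MAX_SPEED / 2.0,  # assumed
    stdev_learning_rate=0.1,  # assumed
    optimizer="clipup",
    optimizer_config={"max_speed": MAX_SPEED},
)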
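The value of RADIUS describes the initial search distribution as a whole rather than per parameter. Assuming that the radius is passed to PGPE as `radius_init` and is interpreted as the norm of the standard deviation vector, the per-parameter standard deviation follows from the solution length. The sketch below illustrates this relation; the function name and the example solution length are hypothetical.

```python
import math


def stdev_from_radius(radius: float, solution_length: int) -> float:
    """Per-parameter standard deviation whose vector norm equals `radius`.

    With the same stdev s for all n parameters, the norm is s * sqrt(n).
    """
    return radius / math.sqrt(solution_length)


# Hypothetical solution length, for illustration only.
print(stdev_from_radius(2.25, 900))  # 2.25 / 30
```

Expressing the initial spread as a radius keeps the overall size of the search distribution comparable across policies with different numbers of parameters.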


We register two loggers for our PGPE instance.

  • StdOutLogger: A logger which will print out the status of the optimization.
  • PicklingLogger: A logger which will periodically save the latest result into a pickle file.

_ = StdOutLogger(searcher)
pickler = PicklingLogger(searcher, interval=5, directory="humanoid_results")

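We are now ready to start the evolutionary search.

# Run the search for a number of generations (100 here is an assumed value)
searcher.run(100)
print("The run is finished.")
print("The pickle file that contains the latest result is:")
print(pickler.last_file_name)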

We now obtain our trained policy as a torch module.

center_solution = searcher.status["center"]
policy = problem.to_policy(center_solution)

Visualizing the trained policy

Now that we have our final policy, we manually run and visualize it.

import jax

import brax

if ENV_NAME.startswith("brax::old::"):
    import brax.v1
    import brax.v1.envs
    import brax.v1.jumpy as jp
    from brax.v1.io import html
    from brax.v1.io import image
else:
    try:
        import jumpy as jp
    except ImportError:
        import brax.jumpy as jp
    import brax.envs
    from brax.io import html
    from brax.io import image

from IPython.display import HTML, Image

import numpy as np

from typing import Iterable, Optional

import random

Below, we define a utility function named use_policy(...).

The expected arguments of use_policy(...) are as follows:

  • torch_module: The policy object, expected as a nn.Module instance.
  • x: The observation, as an iterable of real numbers.
  • h: The hidden state of the module, if such a state exists and if the module is recurrent. Otherwise, it can be left as None.

The return values of this function are as follows:

  • The action recommended by the policy, as a numpy array
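  • The hidden state of the module, if the module is a recurrent one; otherwise None.

@torch.no_grad()
def use_policy(torch_module: nn.Module, x: Iterable, h: Optional[Iterable] = None) -> tuple:
    # Convert the observation into a float32 tensor
    x = torch.as_tensor(np.array(x), dtype=torch.float32)
    if h is None:
        result = torch_module(x)
    else:
        result = torch_module(x, h)

    if isinstance(result, tuple):
        # Recurrent module: it returns both the action and the new hidden state
        x, h = result
        x = x.numpy()
    else:
        # Non-recurrent module: it returns only the action
        x = result.numpy()
        h = None

    return x, h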
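As a usage illustration, the sketch below feeds a few observations through a small recurrent module, carrying the hidden state from one call to the next. The class `TinyRecurrentPolicy` and its sizes are hypothetical. A compact variant of `use_policy` (with gradient tracking disabled) is repeated here so that the example is self-contained, and a reasonably recent PyTorch version that accepts unbatched inputs in `nn.RNNCell` is assumed.

```python
from typing import Iterable, Optional

import numpy as np
import torch
from torch import nn


class TinyRecurrentPolicy(nn.Module):
    """Hypothetical recurrent policy: forward(x, hidden) -> (action, new_hidden)."""

    def __init__(self, obs_length: int = 4, hidden_size: int = 8, act_length: int = 2):
        super().__init__()
        self.cell = nn.RNNCell(obs_length, hidden_size)
        self.head = nn.Linear(hidden_size, act_length)

    def forward(self, x: torch.Tensor, hidden_state: Optional[torch.Tensor] = None) -> tuple:
        # RNNCell treats a missing hidden state as zeros.
        new_hidden_state = self.cell(x, hidden_state)
        return torch.tanh(self.head(new_hidden_state)), new_hidden_state


@torch.no_grad()
def use_policy(torch_module: nn.Module, x: Iterable, h: Optional[Iterable] = None) -> tuple:
    x = torch.as_tensor(np.array(x), dtype=torch.float32)
    result = torch_module(x) if h is None else torch_module(x, h)
    if isinstance(result, tuple):
        action, h = result
        return action.numpy(), h
    return result.numpy(), None


policy_net = TinyRecurrentPolicy()
h = None  # no hidden state before the first step
actions = []
for _ in range(3):
    observation = np.random.randn(4)  # stand-in for an environment observation
    action, h = use_policy(policy_net, observation, h)
    actions.append(action)
```

For a non-recurrent module, the same loop works unchanged: the returned hidden state simply stays None.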

We now initialize a new instance of our brax environment, and trigger the jit compilation on its reset and step methods.

if ENV_NAME.startswith("brax::old::"):
    env = brax.v1.envs.create(env_name=ENV_NAME[11:])
else:
    env = brax.envs.create(env_name=ENV_NAME[6:])

reset = jax.jit(env.reset)
step = jax.jit(env.step)
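The slice offsets 11 and 6 used above are simply the lengths of the prefixes "brax::old::" and "brax::". A more self-documenting alternative is sketched below; the helper `split_brax_env_name` is hypothetical and only mirrors the naming convention described in this notebook.

```python
OLD_PREFIX = "brax::old::"
NEW_PREFIX = "brax::"


def split_brax_env_name(env_name: str) -> tuple:
    """Return (task_name, uses_old_api) for a name like "brax::humanoid"."""
    # The longer prefix must be checked first, because it also starts with "brax::".
    if env_name.startswith(OLD_PREFIX):
        return env_name[len(OLD_PREFIX):], True
    if env_name.startswith(NEW_PREFIX):
        return env_name[len(NEW_PREFIX):], False
    raise ValueError(f"Not a brax environment name: {env_name!r}")


print(split_brax_env_name("brax::humanoid"))
print(split_brax_env_name("brax::old::humanoid"))
```

The returned task name can then be passed as `env_name` to the matching `create` function.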

Below, we run our policy and collect the states of the episode.

seed = random.randint(0, (2 ** 32) - 1)

if hasattr(jp, "random_prngkey"):
    state = reset(rng=jp.random_prngkey(seed=seed))
else:
    state = reset(rng=jax.random.PRNGKey(seed=seed))

h = None
states = []
cumulative_reward = 0.0

while True:
    action, h = use_policy(policy, state.obs, h)
    state = step(state, action)
    cumulative_reward += float(state.reward)
    states.append(state)
    # Stop as soon as the environment reports the end of the episode
    if np.abs(np.array(state.done)) > 1e-4:
        break

Length of the episode and the total reward:

len(states), cumulative_reward

Visualization of the policy:

if ENV_NAME.startswith("brax::old::"):
    env_sys = env.sys
    states_to_render = [state.qp for state in states]
else:
    env_sys = env.sys.replace(dt=env.dt)
    states_to_render = [state.pipeline_state for state in states]
HTML(html.render(env_sys, states_to_render))
