Visualization of the Brax Experiment Results

Using this notebook, you can visualize the behavior of an agent trained by the notebook Brax_Experiments_with_PGPE.ipynb.

# Name of the pickle file saved by PicklingLogger goes here:
FNAME = "..."
# Name of the environment
ENV_NAME = "brax::humanoid"

import pickle
import torch
from torch import nn
with open(FNAME, "rb") as f:
    loaded = pickle.load(f)
# The unpickled object is a dictionary with these keys:
list(loaded.keys())
# Loaded center solution
center = loaded["center"]
center
# Loaded policy network
policy = loaded["policy"]
policy

Below, we put the values of the center solution into the policy network as parameters:

torch.nn.utils.vector_to_parameters(center, policy.parameters())
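
As a quick sanity check (a small optional sketch, not part of the training notebook, assuming center is a flat torch tensor), we can confirm that the number of elements in the center solution matches the number of parameters of the policy network:

# Optional: verify that the flat center vector and the policy have matching sizes
num_params = sum(p.numel() for p in policy.parameters())
assert center.numel() == num_params, (int(center.numel()), num_params)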

Visualizing the trained policy

Now that we have our final policy, we manually run and visualize it.

import jax
import jax.numpy as jnp
import numpy as np
assert ENV_NAME.startswith("brax::"), "This notebook can only work with brax environments"
GOT_OLD_BRAX_ENV = ENV_NAME.startswith("brax::old::")
if GOT_OLD_BRAX_ENV:
    import brax.v1 as brax
    import brax.v1.envs as brax_envs
    import brax.v1.jumpy as jp
    from brax.v1.jumpy import random_prngkey
    from brax.v1.io import html, image
else:
    import brax
    import brax.envs as brax_envs
    from jax.random import PRNGKey as random_prngkey
    from brax.io import html, image
from IPython.display import HTML, Image
from typing import Iterable, Optional
import random

Below, we define a utility function named use_policy(...).

The expected arguments of use_policy(...) are as follows:

  • torch_module: The policy object, expected as an nn.Module instance.
  • x: The observation, as an iterable of real numbers.
  • h: The hidden state of the module, if such a state exists and if the module is recurrent. Otherwise, it can be left as None.

The return values of this function are as follows:

  • The action recommended by the policy, as a numpy array
  • The hidden state of the module, if the module is a recurrent one; otherwise, None.

@torch.no_grad()
def use_policy(torch_module: nn.Module, x: Iterable, h: Optional[Iterable] = None) -> tuple:
    # Convert the observation to a float32 torch tensor
    x = torch.as_tensor(np.asarray(x), dtype=torch.float32)

    # Query the policy, passing the hidden state only if one was given
    if h is None:
        result = torch_module(x)
    else:
        result = torch_module(x, h)

    # A recurrent module returns (action, new_hidden_state); otherwise we only get the action
    if isinstance(result, tuple):
        x, h = result
        x = x.numpy()
    else:
        x = result.numpy()
        h = None

    return x, h
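
As a minimal usage sketch (the tiny linear module below is hypothetical and only illustrates the calling convention), use_policy(...) can be exercised like this:

toy_policy = nn.Linear(4, 2)  # hypothetical stand-in for a real policy network
toy_action, toy_h = use_policy(toy_policy, [0.1, 0.2, 0.3, 0.4])
toy_action.shape, toy_h  # -> (2,) and None, since nn.Linear keeps no hidden state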

We now initialize a new instance of our brax environment (whose name we obtain by stripping the prefix from ENV_NAME) and trigger jit compilation of its reset and step methods.

# Strip the "brax::" (or "brax::old::") prefix to obtain the environment name
env_id = ENV_NAME[len("brax::old::"):] if GOT_OLD_BRAX_ENV else ENV_NAME[len("brax::"):]
env = brax_envs.create(env_name=env_id)
reset = jax.jit(env.reset)
step = jax.jit(env.step)
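
Optionally, we can inspect the environment's observation and action sizes, which should be consistent with the input and output sizes of the loaded policy (a small sanity check, assuming the usual brax environment properties observation_size and action_size):

env.observation_size, env.action_size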

Below, we run our policy for a single episode and collect the visited states.

seed = random.randint(0, (2 ** 32) - 1)
state = reset(rng=random_prngkey(seed=seed))

h = None
states = []
cumulative_reward = 0.0

while True:
    action, h = use_policy(policy, state.obs, h)
    action = jnp.asarray(action)

    state = step(state, action)
    cumulative_reward += float(state.reward)
    states.append(state)
    if np.abs(np.array(state.done)) > 1e-4:
        break

Length of the episode and the total reward:

len(states), cumulative_reward

def pipeline_state(state):
    # Depending on the brax version, the physics state is stored either in
    # `qp` (old brax) or in `pipeline_state` (newer brax).
    if hasattr(state, "qp"):
        return state.qp
    elif hasattr(state, "pipeline_state"):
        return state.pipeline_state
    else:
        assert False, "Cannot find the physics state within the given brax state"

Visualization of the policy:

pipeline_states = [pipeline_state(state) for state in states]

if hasattr(env.sys, "tree_replace"):
    env_sys = env.sys.tree_replace({'opt.timestep': env.dt})
else:
    env_sys = env.sys
HTML(html.render(env_sys, pipeline_states))
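
If desired, the same animation can also be written to a standalone HTML file, since html.render returns the page as a string (the file name below is just an example):

with open("trained_policy.html", "w") as f:
    f.write(html.render(env_sys, pipeline_states))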
