Visualization of the Brax Experiment Results¶
Using this notebook, you can visualize the agent trained by the notebook Brax_Experiments_with_PGPE.ipynb.
Below, we put the values of the center solution into the policy network as parameters:
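A minimal sketch of this step is given below. The policy architecture and the value of the center solution are placeholders here (in practice, both come from Brax_Experiments_with_PGPE.ipynb), and the parameter copy is done via PyTorch's vector_to_parameters:
import torch
from torch import nn
from torch.nn.utils import vector_to_parameters

# Placeholder policy network -- in practice, use the same architecture that was
# trained in Brax_Experiments_with_PGPE.ipynb.
policy = nn.Sequential(nn.Linear(8, 64), nn.Tanh(), nn.Linear(64, 2))

# Placeholder center solution -- in practice, this is the flat parameter vector
# produced by the PGPE search in the training notebook.
num_params = sum(p.numel() for p in policy.parameters())
center = torch.zeros(num_params)

# Write the values of the center solution into the policy network's parameters.
vector_to_parameters(center, policy.parameters())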
Visualizing the trained policy¶
Now that we have our final policy, we manually run and visualize it.
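The imports below branch on a flag named GOT_OLD_BRAX_ENV, which indicates whether the old (v1) Brax API should be used. One possible way of setting this flag, assumed here as a sketch, is to test whether the legacy brax.v1 module is importable:
# Assumed definition: use the legacy API if the brax.v1 module is available.
try:
    import brax.v1  # noqa: F401
    GOT_OLD_BRAX_ENV = True
except ImportError:
    GOT_OLD_BRAX_ENV = False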
import random
from typing import Iterable, Optional

import numpy as np
import torch
from torch import nn
import jax.numpy as jnp

if GOT_OLD_BRAX_ENV:
    import brax.v1 as brax
    import brax.v1.envs as brax_envs
    import brax.v1.jumpy as jp
    from brax.v1.jumpy import random_prngkey
    from brax.v1.io import html, image
else:
    import brax
    import brax.envs as brax_envs
    from jax.random import PRNGKey as random_prngkey
    from brax.io import html, image
Below, we define a utility function named use_policy(...).
The expected arguments of use_policy(...) are as follows:
- torch_module: The policy object, expected as a nn.Module instance.
- x: The observation, as an iterable of real numbers.
- h: The hidden state of the module, if such a state exists and if the module is recurrent. Otherwise, it can be left as None.
The return values of this function are as follows:
- The action recommended by the policy, as a numpy array
- The hidden state of the module, if the module is a recurrent one.
@torch.no_grad()
def use_policy(torch_module: nn.Module, x: Iterable, h: Optional[Iterable] = None) -> tuple:
    # Convert the observation to a float32 torch tensor
    x = torch.as_tensor(np.asarray(x), dtype=torch.float32)

    # Feed the observation (and, if available, the hidden state) to the policy
    if h is None:
        result = torch_module(x)
    else:
        result = torch_module(x, h)

    # A recurrent policy returns (action, hidden state); otherwise only an action
    if isinstance(result, tuple):
        x, h = result
        x = x.numpy()
    else:
        x = result.numpy()
        h = None

    return x, h
We now initialize a new instance of our brax environment and trigger the jit compilation of its reset and step methods.
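A minimal sketch of this step is as follows; the environment name "humanoid" is only a placeholder (use the same environment the policy was trained on):
import jax

# Create a new instance of the environment. The name "humanoid" is a
# placeholder; use the environment the policy was actually trained on.
env = brax_envs.create(env_name="humanoid")

# JIT-compile the reset and step methods so that repeated calls are fast.
reset = jax.jit(env.reset)
step = jax.jit(env.step)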
Below we run our policy and collect the states of the episodes.
# Reset the environment with a randomly generated seed
seed = random.randint(0, (2 ** 32) - 1)
state = reset(rng=random_prngkey(seed=seed))

h = None
states = []
cumulative_reward = 0.0

# Run the policy until the environment signals that the episode is done
while True:
    action, h = use_policy(policy, state.obs, h)
    action = jnp.asarray(action)
    state = step(state, action)
    cumulative_reward += float(state.reward)
    states.append(state)
    if np.abs(np.array(state.done)) > 1e-4:
        break
Length of the episode and the total reward:
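A minimal reporting cell, using the variables collected in the loop above, could be:
print("Episode length:", len(states))
print("Cumulative reward:", cumulative_reward)
Depending on the Brax version, the physical state of the simulation is stored either in state.qp (old API) or in state.pipeline_state (new API). The helper below returns whichever is available: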
def pipeline_state(state):
    if hasattr(state, "qp"):
        return state.qp
    elif hasattr(state, "pipeline_state"):
        return state.pipeline_state
    else:
        assert False
Visualization of the policy:
pipeline_states = [pipeline_state(state) for state in states]

if hasattr(env.sys, "tree_replace"):
    env_sys = env.sys.tree_replace({'opt.timestep': env.dt})
else:
    env_sys = env.sys
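The episode can then be rendered as an interactive animation inside the notebook with Brax's html.render, for example wrapped in IPython's HTML display (a sketch, assuming a Jupyter environment):
from IPython.display import HTML

# Render the collected pipeline states as an interactive HTML animation.
HTML(html.render(env_sys, pipeline_states))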