
Evolving objects using the functional operators API of EvoTorch

In this notebook, we show how to use the functional operators API of EvoTorch for tackling a problem with non-numeric solutions.

In the problem we consider, the goal is to evolve parameter tensors of a feed-forward neural network to make a simulated Ant-v4 MuJoCo robot walk forward. The feed-forward neural network policy has the following modules:

  • module 0: linear transformation (torch.nn.Linear) with a weight matrix and with a bias vector
  • module 1: tanh (torch.nn.Tanh)
  • module 2: linear transformation (torch.nn.Linear) with a weight matrix and with a bias vector

In this problem, instead of a fixed-length vector consisting of real numbers, a solution is represented by a dictionary structured like this:

{
    "0.weight": [ ... list of seeds ... ],
    "0.bias": [ ... list of seeds ... ],
    "2.weight": [ ... list of seeds ... ],
    "2.bias": [ ... list of seeds ... ],
}

where each key is a name referring to a parameter tensor. Associated with each key is a list of integers, each integer being a random seed. When a solution is decoded, each parameter tensor (e.g. "0.weight") is constructed by sampling Gaussian noise from each seed and summing the resulting noise tensors (as was done in [1] and [2]). In this notebook, each noise term is scaled by a mutation power that decays with the position of its seed in the list.
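To make the encoding concrete, here is a minimal plain-Python sketch of decoding one parameter tensor, flattened to a list of floats. It assumes Python's `random.Random` as a stand-in for NumPy's `RandomState` (so the numbers differ from those the notebook produces), and `decode_seeds` is a hypothetical name. The multiplier formula follows `make_tensor_from_seeds`, defined in the implementation section, with the default values of `MyRLProblem`.

```python
import random

def decode_seeds(seeds, size, mutation_power=0.5, mutation_decay=0.9, min_mutation_power=0.05):
    """Hypothetical decoder: build a flat list of `size` floats from a list of seeds."""
    result = [0.0] * size
    for i_seed, seed in enumerate(seeds):
        # The noise scale decays with the seed's position but never drops below the floor
        multiplier = max(mutation_power * (mutation_decay ** i_seed), min_mutation_power)
        rng = random.Random(seed)  # the seed alone determines this noise vector
        for j in range(size):
            result[j] += multiplier * rng.gauss(0.0, 1.0)
    return result

parent = decode_seeds([123], size=4)
child = decode_seeds([123, 456], size=4)  # a mutation appends one seed
# The child differs from its parent only by the (scaled) noise of the new seed
delta = [c - p for c, p in zip(child, parent)]
```

Because the seed list fully determines the tensor, a solution only needs to store a few integers per parameter tensor, no matter how large that tensor is.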

Note 1: Although this example is inspired by the studies [1] and [2], it is not a faithful implementation of any of them. Instead, this notebook focuses on demonstrating various features of the functional operators API of EvoTorch.

Note 2: For the sake of simplicity, the action space of Ant-v4 is binned. With this simplification and the default hyperparameters, the evolutionary algorithm in this example can find gaits for the ant robot with a relatively small population, although the evolved gaits are not very efficient (i.e. their cumulative rewards are not competitive).


[1] Felipe Petroski Such, Vashisht Madhavan, Edoardo Conti, Joel Lehman, Kenneth O. Stanley, Jeff Clune (2017). "Deep neuroevolution: Genetic algorithms are a competitive alternative for training deep neural networks for reinforcement learning." arXiv preprint arXiv:1712.06567.

[2] Sebastian Risi, Kenneth O. Stanley (2019). "Deep neuroevolution of recurrent and discrete world models." Proceedings of the Genetic and Evolutionary Computation Conference (GECCO).


Summary of the evolutionary algorithm

We implement a simple, elitist genetic algorithm with tournament selection, cross-over, and mutation operators. The main ideas of this genetic algorithm are as follows.

Generation of a new solution:

  • Make a new dictionary.
  • For each key (parameter name) within the dictionary, make a single-element list of seeds, the seed within it being a random integer.

Cross-over between two solutions:

  • Make two child solutions (dictionaries).
  • For each key (parameter name):
      ◦ Sample a real number \(p\) between 0 and 1.
      ◦ If \(p < 0.5\), the first child receives its list of seeds from the first parent, and the second child receives its list of seeds from the second parent.
      ◦ Otherwise, the first child receives its list of seeds from the second parent, and the second child receives its list of seeds from the first parent.

Mutation of an existing solution:

  • Randomly pick a key (parameter name) within the solution (dictionary).
  • Randomly sample a new integer, and append it to the list of seeds associated with the picked key.
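The three operators can be sketched in plain Python on toy dictionaries as below. The function names `new_solution`, `swap_crossover`, and `append_mutation` are hypothetical, and this sketch returns ordinary mutable dictionaries; the notebook's own `mutate` and `cross_over` functions, defined in the implementation section, return immutable objects via `evotorch.tools.as_immutable`.

```python
import random

# Parameter names of the policy described above
PARAM_NAMES = ["0.weight", "0.bias", "2.weight", "2.bias"]

def new_solution(rng):
    # One single-seed list per parameter tensor
    return {name: [rng.randint(0, 2 ** 32 - 1)] for name in PARAM_NAMES}

def swap_crossover(parent1, parent2, rng):
    child1, child2 = {}, {}
    for key in parent1:
        # Each child inherits whole seed lists, never a mix within one key
        if rng.random() < 0.5:
            child1[key], child2[key] = list(parent1[key]), list(parent2[key])
        else:
            child1[key], child2[key] = list(parent2[key]), list(parent1[key])
    return child1, child2

def append_mutation(solution, rng):
    child = {key: list(seeds) for key, seeds in solution.items()}  # copy; parent stays intact
    chosen_key = rng.choice(list(child))
    child[chosen_key].append(rng.randint(0, 2 ** 32 - 1))
    return child

rng = random.Random(0)
a, b = new_solution(rng), new_solution(rng)
c1, c2 = swap_crossover(a, b, rng)
m = append_mutation(c1, rng)
```

Cross-over never creates new seeds: it only reassigns whole seed lists between the children, so the seeds of both parents are conserved. Only mutation makes a solution grow.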

Implementation

from evotorch import Problem, Solution
from evotorch.tools import make_tensor, ObjectArray
import evotorch.operators.functional as func_ops

import gymnasium as gym
import numpy as np

import torch
from torch import nn
from torch.func import functional_call

from typing import Iterable, Mapping, Optional, Union
import random
import os
from datetime import datetime
import pickle

The function below takes a series of seeds and makes a tensor of real numbers out of them. We will use this function when decoding a solution.

def make_tensor_from_seeds(
    like: torch.Tensor,
    seeds: Iterable,
    *,
    mutation_power: float,
    mutation_decay: float,
    min_mutation_power: float,
) -> torch.Tensor:
    """
    Take a series of seeds and compute a tensor out of them.

    Args:
        like: A source tensor. The resulting tensor will have the same shape,
            dtype, and device as this source tensor.
        seeds: An iterable in which each item is an integer, each integer
            being a random seed.
        mutation_power: A multiplier for the Gaussian noise generated out of
            the first random seed.
        mutation_decay: For each subsequent seed, the mutation power is
            multiplied by this factor. For example, if this factor is 0.9,
            the noise of the i-th seed (counting from 0) is scaled by
            mutation_power * (0.9 ** i).
        min_mutation_power: A lower bound for the multiplier, preventing the
            mutation power from getting too close to 0.
    Returns:
        The tensor generated from the given seeds.
    """
    from numpy.random import RandomState

    result = torch.zeros_like(like)
    for i_seed, seed in enumerate(seeds):
        multiplier = max(mutation_power * (mutation_decay ** i_seed), min_mutation_power)
        result += (
            multiplier * torch.as_tensor(RandomState(seed).randn(*(like.shape)), dtype=like.dtype, device=like.device)
        )

    return result
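To see how `mutation_decay` and `min_mutation_power` interact, the following sketch (a hypothetical helper, not part of the notebook) lists the multiplier applied to each seed position, using the default values of `MyRLProblem` (0.5, 0.9, and 0.05).

```python
def mutation_schedule(num_seeds, mutation_power=0.5, mutation_decay=0.9, min_mutation_power=0.05):
    """Multiplier applied to the noise of each seed position, as in make_tensor_from_seeds."""
    return [
        max(mutation_power * (mutation_decay ** i_seed), min_mutation_power)
        for i_seed in range(num_seeds)
    ]

schedule = mutation_schedule(30)
# Index of the first seed whose multiplier is clamped to the floor
first_clamped = next(i for i, m in enumerate(schedule) if m == 0.05)
```

With these defaults, the multiplier shrinks from 0.5 at position 0 to about 0.055 at position 21, and every seed from position 22 onward is scaled by the floor value 0.05. If `min_mutation_power` were set equal to `mutation_power`, the floor would dominate every term and the decay would have no effect.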

Helper function to generate a random seed integer:

def sample_seed() -> int:
    return random.randint(0, (2 ** 32) - 1)

Observation normalization. Below are helper functions that collect observation statistics (the elementwise mean and standard deviation) for the reinforcement learning environment at hand. These statistics are used for normalizing the observations before they are passed to the policy neural network.

def env_name_to_file_name(env_name: str) -> str:
    """
    Convert the gymnasium environment ID to a more file-name-friendly counterpart.

    The character ':' in the input string will be replaced with '__colon__'.
    Similarly, the character '/' in the input string will be replaced with '__slash__'.

    Args:
        env_name: gymnasium environment ID
    Returns:
        File-name-friendly counterpart of the input string.
    """
    result = env_name
    result = result.replace(":", "__colon__")
    result = result.replace("/", "__slash__")
    return result
def create_obs_data(
    *,
    env_name: str,
    num_timesteps: int,
    report_interval: Union[int, float] = 5,
    seed: int = 0,
) -> tuple:
    """
    Create observation normalization data with the help of random actions.

    This function creates a gymnasium environment from the given `env_name`.
    Then, it keeps sending random actions to this environment, and collects stats from the observations.

    Args:
        env_name: ID of the gymnasium environment
        num_timesteps: For how many timesteps will the function operate on the environment
        report_interval: Time interval, in seconds, for reporting the status
        seed: A seed that will be used for regulating the randomness of both the environment
            and of the random actions.
    Returns:
        A tuple of the form `(mean, stdev)`, where `mean` is the elementwise mean of the observation vectors,
        and `stdev` is the elementwise standard deviation of the observation vectors.
    """
    print("Creating observation data for", env_name)

    class accumulated:
        sum: Optional[np.ndarray] = None
        sum_of_squares: Optional[np.ndarray] = None
        count: int = 0

    def accumulate(obs: np.ndarray):
        if accumulated.sum is None:
            accumulated.sum = obs.copy()
        else:
            accumulated.sum += obs

        squared = obs ** 2
        if accumulated.sum_of_squares is None:
            accumulated.sum_of_squares = squared
        else:
            accumulated.sum_of_squares += squared

        accumulated.count += 1

    rndgen = np.random.RandomState(seed)

    env = gym.make(env_name)
    assert isinstance(env.action_space, gym.spaces.Box), "Can only work with Box action spaces"

    def reset_env() -> tuple:
        return env.reset(seed=rndgen.randint(2 ** 32))

    action_gap = env.action_space.high - env.action_space.low
    def sample_action() -> np.ndarray:
        return (rndgen.rand(*(env.action_space.shape)) * action_gap) + env.action_space.low

    observation, _ = reset_env()
    accumulate(observation)

    last_report_time = datetime.now()

    for t in range(num_timesteps):
        action = sample_action()
        observation, _, terminated, truncated, _ = env.step(action)
        accumulate(observation)

        done = terminated | truncated
        if done:
            observation, info = reset_env()
            accumulate(observation)

        tnow = datetime.now()
        if (tnow - last_report_time).total_seconds() > report_interval:
            print("Number of timesteps:", t, "/", num_timesteps)
            last_report_time = tnow

    E_x = accumulated.sum / accumulated.count
    E_x2 = accumulated.sum_of_squares / accumulated.count

    mean = E_x
    variance = np.maximum(E_x2 - ((E_x) ** 2), 1e-2)
    stdev = np.sqrt(variance)

    print("Done.")

    return mean, stdev
def get_obs_data(env_name: str, num_timesteps: int = 50000, seed: int = 0) -> tuple:
    """
    Generate observation normalization data for the gymnasium environment whose name is given.

    If such normalization data was already generated and saved into a pickle file, that pickle file will be loaded.
    Otherwise, new normalization data will be generated and saved into a new pickle file.

    Args:
        env_name: ID of the gymnasium environment
        num_timesteps: For how many timesteps will the observation collector operate on the environment
        seed: A seed that will be used for regulating the randomness of both the environment
            and of the random actions.
    Returns:
        A tuple of the form `(mean, stdev)`, where `mean` is the elementwise mean of the observation vectors,
        and `stdev` is the elementwise standard deviation of the observation vectors.
    """
    num_timesteps = int(num_timesteps)
    envfname = env_name_to_file_name(env_name)
    fname = f"obs_seed{seed}_t{num_timesteps}_{envfname}.pickle"
    if os.path.isfile(fname):
        with open(fname, "rb") as f:
            return pickle.load(f)
    else:
        obsdata = create_obs_data(env_name=env_name, num_timesteps=num_timesteps, seed=seed)
        with open(fname, "wb") as f:
            pickle.dump(obsdata, f)
        return obsdata
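The statistics computed by `create_obs_data` rely on the identity \(\mathrm{Var}[x] = E[x^2] - E[x]^2\), applied elementwise to running sums. The following stdlib-only sketch (with hypothetical helper names `obs_stats` and `normalize`) reproduces that computation on a tiny hand-made dataset, including the variance floor of 1e-2.

```python
import math

def obs_stats(observations, min_variance=1e-2):
    """Elementwise mean and stdev from running sums, mirroring create_obs_data."""
    count = 0
    total = None
    total_sq = None
    for obs in observations:
        if total is None:
            total = [0.0] * len(obs)
            total_sq = [0.0] * len(obs)
        for j, x in enumerate(obs):
            total[j] += x
            total_sq[j] += x * x
        count += 1
    mean = [s / count for s in total]
    # Var[x] = E[x^2] - E[x]^2, floored so that near-constant dimensions stay usable
    variance = [max(sq / count - m * m, min_variance) for sq, m in zip(total_sq, mean)]
    return mean, [math.sqrt(v) for v in variance]

def normalize(obs, mean, stdev):
    return [(x - m) / s for x, m, s in zip(obs, mean, stdev)]

mean, stdev = obs_stats([[1.0, 5.0], [2.0, 5.0], [3.0, 5.0]])
```

The second observation dimension is constant in this toy data, so its raw variance is 0; the floor turns its standard deviation into 0.1 instead of causing a division by zero during normalization.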

Problem definition. Below is the problem definition for the considered reinforcement learning task. We define the problem as a subclass of evotorch.Problem, so that we can use the Ray-based parallelization capabilities of the base Problem class.

class MyRLProblem(Problem):
    def __init__(
        self,
        *,
        env_name: str,
        obs_mean: Optional[np.ndarray] = None,
        obs_stdev: Optional[np.ndarray] = None,
        mutation_power: float = 0.5,
        mutation_decay: float = 0.9,
        min_mutation_power: float = 0.05,
        hidden_sizes: tuple = (64,),
        bin: Optional[float] = 0.2,
        num_episodes: int = 4,
        episode_length: Optional[int] = None,
        decrease_rewards_by: Optional[float] = 1.0,
        num_actors: Optional[Union[int, str]] = "max"
    ):
        super().__init__(
            objective_sense="max",
            dtype=object,
            num_actors=num_actors,
        )
        self._env_name = str(env_name)
        self._env = None
        self._hidden_sizes = [int(hidden_size) for hidden_size in hidden_sizes]
        self._policy = None

        self._obs_mean = None if obs_mean is None else np.asarray(obs_mean).astype("float32")
        self._obs_stdev = None if obs_stdev is None else np.asarray(obs_stdev).astype("float32")
        self._mutation_power = float(mutation_power)
        self._mutation_decay = float(mutation_decay)
        self._min_mutation_power = float(min_mutation_power)
        self._bin = None if bin is None else float(bin)
        self._num_episodes = int(num_episodes)
        self._episode_length = None if episode_length is None else int(episode_length)
        self._decrease_rewards_by = None if decrease_rewards_by is None else float(decrease_rewards_by)

    def _get_policy(self) -> nn.Module:
        env = self._get_env()

        if not isinstance(env.observation_space, gym.spaces.Box):
            raise TypeError(
                f"Only Box-typed environments are supported. Encountered observation space is {env.observation_space}"
            )

        [obslen] = env.observation_space.shape
        if isinstance(env.action_space, gym.spaces.Box):
            [actlen] = env.action_space.shape
        elif isinstance(env.action_space, gym.spaces.Discrete):
            actlen = env.action_space.n
        else:
            raise TypeError(f"Unrecognized action space: {env.action_space}")

        all_sizes = [obslen]
        all_sizes.extend(self._hidden_sizes)
        all_sizes.append(actlen)

        last_size_index = len(all_sizes) - 1

        modules = []
        for i in range(1, len(all_sizes)):
            modules.append(nn.Linear(all_sizes[i - 1], all_sizes[i]))
            if i < last_size_index:
                modules.append(nn.Tanh())

        return nn.Sequential(*modules)

    def _get_env(self, visualize: bool = False) -> gym.Env:
        if visualize:
            return gym.make(self._env_name, render_mode="human")

        if self._env is None:
            self._env = gym.make(self._env_name)
        return self._env

    def _generate_single_solution(self) -> dict:
        policy = self._get_policy()
        result = {}
        for param_name, params in policy.named_parameters():
            result[param_name] = [sample_seed()]
        return result

    def generate_values(self, n: int) -> ObjectArray:
        return make_tensor([self._generate_single_solution() for _ in range(n)], dtype=object)

    def run_solution(
        self,
        x: Union[Mapping, Solution],
        *,
        num_episodes: Optional[int] = None,
        visualize: bool = False
    ) -> float:
        if num_episodes is None:
            num_episodes = self._num_episodes

        if isinstance(x, Mapping):
            sln = x
        elif isinstance(x, Solution):
            sln = x.values
        else:
            raise TypeError(f"Expected a Mapping or a Solution, but got {repr(x)}")

        policy = self._get_policy()

        params = {}
        for param_name, param_values in policy.named_parameters():
            param_seeds = sln[param_name]
            params[param_name] = make_tensor_from_seeds(
                param_values,
                param_seeds,
                mutation_power=self._mutation_power,
                mutation_decay=self._mutation_decay,
                min_mutation_power=self._min_mutation_power,
            )

        env = self._get_env(visualize=visualize)

        def use_policy(policy_input: np.ndarray) -> Union[int, np.ndarray]:
            if (self._obs_mean is not None) and (self._obs_stdev is not None):
                policy_input = policy_input - self._obs_mean
                policy_input = policy_input / self._obs_stdev

            result = functional_call(policy, params, torch.as_tensor(policy_input, dtype=torch.float32)).numpy()

            if isinstance(env.action_space, gym.spaces.Box):
                if self._bin is not None:
                    result = np.sign(result) * self._bin
                result = np.clip(result, env.action_space.low, env.action_space.high)
            elif isinstance(env.action_space, gym.spaces.Discrete):
                result = int(np.argmax(result))
            else:
                raise TypeError(f"Unrecognized action space: {repr(env.action_space)}")

            return result

        cumulative_reward = 0.0

        for _ in range(num_episodes):
            timestep = 0
            observation, info = env.reset()
            while True:
                action = use_policy(observation)
                observation, reward, done1, done2, _ = env.step(action)
                timestep += 1
                if (self._decrease_rewards_by is not None) and (not visualize):
                    reward = reward - self._decrease_rewards_by
                cumulative_reward += reward
                if (
                    done1
                    or done2
                    or (
                        (not visualize)
                        and (self._episode_length is not None)
                        and (timestep >= self._episode_length)
                    )
                ):
                    break

        return cumulative_reward / num_episodes

    def visualize(self, x: Union[Solution, Mapping]) -> float:
        return self.run_solution(x, num_episodes=1, visualize=True)

    def _evaluate(self, x: Solution):
        x.set_evaluation(self.run_solution(x))

We now define our mutation and cross-over operators, via the functions mutate and cross_over. Since the solutions are expressed via dictionary-like objects, we use Mapping for type annotations.

def mutate(solution: Mapping) -> Mapping:
    from evotorch.tools import as_immutable

    solution = {k: list(v) for k, v in solution.items()}

    keys = list(solution.keys())
    chosen_key = random.choice(keys)
    solution[chosen_key].append(sample_seed())

    return as_immutable(solution)
def cross_over(parent1: Mapping, parent2: Mapping) -> tuple:
    from evotorch.tools import as_immutable

    keys = list(parent1.keys())

    child1 = {}
    child2 = {}
    for k in keys:
        p = random.random()
        if p < 0.5:
            child1[k] = parent1[k]
            child2[k] = parent2[k]
        else:
            child1[k] = parent2[k]
            child2[k] = parent1[k]

    return as_immutable(child1), as_immutable(child2)

ID of the considered reinforcement learning task:

ENV_NAME = "Ant-v4"

Generate or load observation data for the considered reinforcement learning environment:

env_obs_mean, env_obs_stdev = get_obs_data(ENV_NAME)
env_obs_mean, env_obs_stdev

Instantiate the problem object:

problem = MyRLProblem(
    env_name=ENV_NAME,
    decrease_rewards_by=1.0,
    episode_length=250,
    bin=0.15,
    obs_mean=env_obs_mean,
    obs_stdev=env_obs_stdev,
)

problem

From the instantiated problem object, we make a callable evaluator named f, which can be used as a fitness function.

f = problem.make_callable_evaluator()
f

Helper function for converting a real number to a string. We will use this while reporting the status of the evolution.

def number_to_str(x) -> str:
    return "%.2f" % float(x)

Hyperparameters and constants for the evolutionary algorithm:

popsize = 16
tournament_size = 4
objective_sense = problem.objective_sense
num_generations = 100
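A quick back-of-the-envelope count of the evaluation budget implied by these settings follows, assuming that `split_results=True` in the tournament call divides the `popsize` tournament winners into two equal halves of parents. The episode settings are the `MyRLProblem` default (`num_episodes=4`) and the `episode_length=250` passed above.

```python
popsize = 16
num_generations = 100
num_episodes = 4       # default of MyRLProblem
episode_length = 250   # value passed when instantiating the problem

num_pairs = popsize // 2                       # popsize tournament winners, split into two halves
num_children = 2 * num_pairs                   # each cross-over produces two children
evals_per_generation = popsize + num_children  # parents are re-evaluated along with the children
max_steps_per_generation = evals_per_generation * num_episodes * episode_length
# Initial population evaluation plus all generations
max_steps_total = popsize * num_episodes * episode_length + num_generations * max_steps_per_generation
print(evals_per_generation, max_steps_per_generation, max_steps_total)  # 32 32000 3216000
```

Every generation evaluates 32 solutions, since `f(extended_population)` re-evaluates the 16 surviving parents along with the 16 children. Episodes can end early, so the step counts are upper bounds.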

We now prepare the initial population. When we are dealing with non-numeric solutions, a population is represented via evotorch.tools.ObjectArray, instead of torch.Tensor.

population = problem.generate_values(popsize)
population

Evaluate the fitnesses of the solutions within the initial population:

evals = f(population)
evals

Main loop of the evolutionary search:

for generation in range(1, 1 + num_generations):
    t_begin = datetime.now()

    # Apply tournament selection on the population
    parent1_indices, parent2_indices = func_ops.tournament(
        population,
        evals,
        tournament_size=tournament_size,
        num_tournaments=popsize,
        split_results=True,
        return_indices=True,
        objective_sense=objective_sense,
    )

    # The results of the tournament selection are stored within the integer
    # tensors `parent1_indices` and `parent2_indices`.
    # The pairs of solutions for the cross-over operator are:
    # - `population[parent1_indices[0]]` and `population[parent2_indices[0]]`,
    # - `population[parent1_indices[1]]` and `population[parent2_indices[1]]`,
    # - `population[parent1_indices[2]]` and `population[parent2_indices[2]]`,
    # - and so on...
    num_pairs = len(parent1_indices)
    children = []
    for i in range(num_pairs):
        parent1_index = int(parent1_indices[i])
        parent2_index = int(parent2_indices[i])
        child1, child2 = cross_over(population[parent1_index], population[parent2_index])
        child1 = mutate(child1)
        child2 = mutate(child2)
        children.extend([child1, child2])

    # With the help of the function `evotorch.tools.make_tensor(...)`,
    # we convert the list of child solutions to an ObjectArray, so that
    # `children` can be treated as a population of solutions.
    children = make_tensor(children, dtype=object)

    # Combine the original population with the population of children,
    # forming an extended population.
    extended_population = func_ops.combine(population, children)

    # Evaluate all the solutions within the extended population.
    extended_evals = f(extended_population)

    # Take the best `popsize` number of solutions from the extended population.
    population, evals = func_ops.take_best(
        extended_population, extended_evals, popsize, objective_sense=objective_sense
    )

    t_end = datetime.now()
    time_taken = (t_end - t_begin).total_seconds()

    # Report the status of the evolutionary search.
    print(
        "Generation:", generation,
        "  Mean eval:", number_to_str(evals.mean()),
        "  Pop best:", number_to_str(evals.max()),
        "  Time:", number_to_str(time_taken)
    )
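The main loop is a \((\mu + \lambda)\)-style elitist scheme: parents and children compete together, and only the best `popsize` solutions survive. The stdlib-only sketch below mimics one generation on a toy problem where each individual is a number whose fitness is its own value. It is not EvoTorch's implementation: `tournament_indices` and `take_best` here are hypothetical re-implementations, contestants are drawn with replacement (EvoTorch's sampling details may differ), and Gaussian perturbation stands in for cross-over and mutation.

```python
import random

def tournament_indices(fitnesses, num_tournaments, tournament_size, rng):
    """Return two lists of winner indices; position i of each list forms one parent pair."""
    winners = []
    for _ in range(num_tournaments):
        # Contestants are drawn with replacement; the fittest one wins (maximization)
        contestants = [rng.randrange(len(fitnesses)) for _ in range(tournament_size)]
        winners.append(max(contestants, key=lambda i: fitnesses[i]))
    half = num_tournaments // 2
    return winners[:half], winners[half:]

def take_best(population, fitnesses, n):
    """Elitist truncation: keep the n fittest members of the given population."""
    order = sorted(range(len(population)), key=lambda i: fitnesses[i], reverse=True)[:n]
    return [population[i] for i in order], [fitnesses[i] for i in order]

# Toy problem: each individual is a number, and its fitness is the number itself
rng = random.Random(1)
popsize = 8
population = [rng.uniform(0.0, 10.0) for _ in range(popsize)]
fitnesses = list(population)

parents1, parents2 = tournament_indices(fitnesses, popsize, tournament_size=4, rng=rng)
children = []
for i, j in zip(parents1, parents2):
    # Stand-in for cross-over + mutation: perturb each parent with Gaussian noise
    children.extend([population[i] + rng.gauss(0.0, 1.0), population[j] + rng.gauss(0.0, 1.0)])

extended = population + children
new_population, new_fitnesses = take_best(extended, list(extended), popsize)
```

Because the previous population is part of the extended population, the best fitness in this toy version can never decrease from one generation to the next. In the notebook, parents are re-evaluated on freshly reset environments, so evaluation noise can still make the reported best fluctuate.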

Take the index of the best solution within the last population:

best_index = torch.argmax(evals)
best_index

Take the best solution within the last population:

best_params = population[best_index]
best_params

Visualize the gait of the population's best solution:

problem.visualize(best_params)
