Maintaining a batch of populations using the functional EvoTorch API
Motivation
EvoTorch already implements mechanisms for accelerating the evaluation of solutions in a population using Ray actors and/or PyTorch vectorization, which is essential for obtaining results in reasonable time.
However, one usually also needs to run search algorithms multiple times to evaluate the effect of various hyperparameters (such as learning rate, mutation probability, etc.). This can be time-consuming and requires additional effort to parallelize (e.g. on a cluster). The functional API of EvoTorch attempts to address this need by leveraging PyTorch's vmap() transform.
The main idea is that the search algorithms in evotorch.algorithms.functional are pure functional implementations, and can therefore easily be transformed (using vmap()) to operate on multiple populations stored as batches of populations (or batches of batches of populations, and so on). With such implementations, one can run multiple searches in parallel, starting from different initial populations that cover different regions of the search space. As another example, one can run multiple searches that each have their own initial population and their own learning rate.
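Why purity makes batching easy can be illustrated without PyTorch. The sketch below is a plain-Python stand-in for vmap(), assuming each search is described by a state (here just a center vector) plus a per-search hyperparameter. `sgd_step` and `batched` are hypothetical helpers for illustration only, not part of EvoTorch. Because `sgd_step` reads only its arguments and returns a new value, mapping it over a batch gives exactly the same results as calling it once per search.

```python
from typing import Callable, List, Sequence


def sgd_step(center: List[float], grad: List[float], lr: float) -> List[float]:
    # A pure update: it reads only its arguments and returns a NEW list,
    # so one search can never affect another.
    return [c - lr * g for c, g in zip(center, grad)]


def batched(fn: Callable) -> Callable:
    # Hypothetical stand-in for vmap(): apply `fn` to the i-th item of every
    # argument and collect the results along a new leading batch dimension.
    def wrapper(*batched_args: Sequence) -> list:
        sizes = {len(arg) for arg in batched_args}
        if len(sizes) != 1:
            raise ValueError("all arguments must have the same batch size")
        return [fn(*items) for items in zip(*batched_args)]

    return wrapper


# Three searches, each with its own center, gradient and learning rate.
centers = [[1.0, 2.0], [0.0, -1.0], [4.0, 4.0]]
grads = [[1.0, 1.0], [2.0, 0.0], [-1.0, 1.0]]
learning_rates = [0.1, 0.5, 1.0]

batched_result = batched(sgd_step)(centers, grads, learning_rates)
separate_result = [sgd_step(c, g, lr) for c, g, lr in zip(centers, grads, learning_rates)]
print(batched_result == separate_result)
```

The real vmap() works on tensor dimensions rather than Python lists, but the contract is the same: a function written for one item is lifted to a batch of items.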
In this notebook, we demonstrate how a batch of populations, each originating from a different starting point, can be maintained so that different regions of the search space can be explored simultaneously. The key concepts covered are:
- the ask-and-tell interface provided by the search algorithms in the evotorch.algorithms.functional namespace;
- the @rowwise decorator, which makes it easier to write evaluation functions that can be automatically transformed using vmap();
- running multiple searches using vectorization, simply by adding a batch dimension to the arguments of the ask-and-tell interface and letting EvoTorch handle the rest.
We begin by importing the necessary libraries and defining some useful variables:
from evotorch.algorithms.functional import cem, cem_ask, cem_tell, pgpe, pgpe_ask, pgpe_tell
from evotorch.decorators import rowwise
from datetime import datetime
import torch
from math import pi
# Use a GPU to achieve speedups from vectorization if possible
device = "cuda" if torch.cuda.is_available() else "cpu"
# We will search for 1000-dimensional solution vectors
solution_length = 1000
Ask-and-tell
Next, we implement a simple optimization loop for the commonly used Rastrigin function using the ask-and-tell interface of the Cross Entropy Method (CEM). Note that we evaluate the full population at once with the evaluation function rastrigin(), so it must be implemented to operate on a population represented as a 2D tensor.
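For reference, the function computed by rastrigin() is the standard Rastrigin function with A = 10. For an n-dimensional solution x:

$$
f(\mathbf{x}) = A\,n + \sum_{i=1}^{n} \left( x_i^2 - A \cos(2\pi x_i) \right), \qquad A = 10
$$

Each term x_i^2 - A cos(2 pi x_i) is at least -A, with equality only at x_i = 0, so the global minimum is f(0) = 0. The sum runs over the last dimension of the input, which is why reducing with dim=-1 lets the same code return one fitness value per row of a population.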
def rastrigin(x: torch.Tensor) -> torch.Tensor:
    # Reduce over the last dimension only, so leading (population/batch) dimensions are kept
    n = x.shape[-1]
    A = 10.0
    return A * n + torch.sum((x ** 2) - (A * torch.cos(2 * pi * x)), dim=-1)
# Uniformly sample center_init in [-5.12, 5.12], the typical domain of the Rastrigin function
center_init = (torch.rand(solution_length) * 2 - 1) * 5.12
# Set stdev_max_change to 0.1 for all solution dimensions
stdev_max_change = 0.1 * torch.ones(solution_length)
cem_state = cem(
    # We want to minimize the evaluation results
    objective_sense="min",
    # `center_init` is the center point of the initial search distribution.
    center_init=center_init.to(device),
    # The standard deviation of the initial search distribution.
    stdev_init=10.0,
    # Maximum allowed relative change of the standard deviation in one generation
    # (one value per solution dimension), placed on the same device as the center.
    stdev_max_change=stdev_max_change.to(device),
    # Solutions belonging to the top half (top 50%) of the population
    # will be chosen as parents.
    parenthood_ratio=0.5,
)
# We will run the evolutionary search for this many generations:
num_generations = 1500
# Interval (in seconds) for printing the status:
report_interval = 3
start_time = last_report_time = datetime.now()
for generation in range(1, 1 + num_generations):
    # Get a population from the evolutionary algorithm
    population = cem_ask(cem_state, popsize=500)
    # Compute the fitnesses
    fitnesses = rastrigin(population)
    # Inform the evolutionary algorithm of the fitnesses and get its next state
    cem_state = cem_tell(cem_state, population, fitnesses)
    # If it is time to report, print the status
    tnow = datetime.now()
    if ((tnow - last_report_time).total_seconds() > report_interval) or (generation == num_generations):
        print("generation:", generation, "mean fitnesses:", torch.mean(fitnesses, dim=-1))
        last_report_time = tnow
print("time taken: ", (last_report_time - start_time).total_seconds(), "secs")
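To make the ask-and-tell pattern concrete, here is a minimal plain-Python sketch of a simplified CEM on a small sphere function. It is not EvoTorch's implementation: `ToyCEMState`, `toy_cem_ask` and `toy_cem_tell` are hypothetical names, and the relative clamping of the standard deviation is an assumption that mirrors the role of stdev_max_change described above. The key property is the same as in the loop above: `tell` never mutates its input state, it returns a new one.

```python
import math
import random
from typing import List, NamedTuple


class ToyCEMState(NamedTuple):
    # Hypothetical, simplified state (not EvoTorch's own state type).
    center: List[float]
    stdev: List[float]
    stdev_max_change: float  # maximum relative change of each stdev per generation
    parenthood_ratio: float  # fraction of the population chosen as parents


def toy_cem_ask(state: ToyCEMState, popsize: int, rng: random.Random) -> List[List[float]]:
    # Sample `popsize` solutions from a diagonal Gaussian around the center.
    return [[rng.gauss(c, s) for c, s in zip(state.center, state.stdev)] for _ in range(popsize)]


def toy_cem_tell(state: ToyCEMState, population: List[List[float]], fitnesses: List[float]) -> ToyCEMState:
    # Minimization: the solutions with the lowest fitnesses become the parents.
    num_parents = max(1, int(len(population) * state.parenthood_ratio))
    order = sorted(range(len(population)), key=lambda i: fitnesses[i])
    parents = [population[i] for i in order[:num_parents]]
    new_center, new_stdev = [], []
    for j, old_s in enumerate(state.stdev):
        mean = sum(p[j] for p in parents) / num_parents
        s = math.sqrt(sum((p[j] - mean) ** 2 for p in parents) / num_parents)
        # Keep the new stdev within +/- stdev_max_change (relative) of the old one.
        low, high = old_s * (1 - state.stdev_max_change), old_s * (1 + state.stdev_max_change)
        new_center.append(mean)
        new_stdev.append(min(max(s, low), high))
    # Return a NEW state and leave the old one untouched (pure-functional style).
    return state._replace(center=new_center, stdev=new_stdev)


def sphere(x: List[float]) -> float:
    return sum(v * v for v in x)


rng = random.Random(0)
state = ToyCEMState(center=[3.0] * 5, stdev=[1.0] * 5, stdev_max_change=0.1, parenthood_ratio=0.5)
initial_state = state
for generation in range(100):
    population = toy_cem_ask(state, popsize=50, rng=rng)  # ask
    fitnesses = [sphere(x) for x in population]  # evaluate
    state = toy_cem_tell(state, population, fitnesses)  # tell
print("initial:", sphere(initial_state.center), "final:", sphere(state.center))
```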
The @rowwise decorator
Next, we modify the code above so that multiple searches can be executed simultaneously, taking advantage of PyTorch's vectorization capabilities. We modify the fitness function as follows:
@rowwise
def rastrigin(x: torch.Tensor) -> torch.Tensor:
    # Written for a single solution vector; @rowwise takes care of any batch dimensions
    [n] = x.shape
    A = 10.0
    return A * n + torch.sum((x ** 2) - (A * torch.cos(2 * pi * x)))
Notice how the fitness function above is decorated with @rowwise. This decorator tells EvoTorch that the user has defined the function to operate on its argument x as a vector (i.e. a 1-dimensional tensor). This makes the function conceptually easier to implement and helps EvoTorch safely vmap() it in order to apply it to populations or batches of populations as needed. @rowwise ensures that:
- if the argument x is received as a 1-dimensional tensor, the function works as defined;
- if the argument x is received as a matrix (i.e. a 2-dimensional tensor), the operations of the function are applied to each row of the matrix;
- if the argument x is received as a tensor with 3 or more dimensions, the operations of the function are applied to each row of each matrix.
Thanks to this, the fitness function rastrigin can be used as it is to evaluate a single solution (represented by a 1-dimensional tensor), a single population (represented by a 2-dimensional tensor), or a batch of populations (represented by a tensor with 3 or more dimensions).
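The dispatch rule described above can be imitated on nested Python lists. The sketch below is a hypothetical, simplified `rowwise_sketch` decorator, not EvoTorch's @rowwise (which relies on vmap() over tensor dimensions): a list of numbers is treated as one solution vector, and any other nesting level is peeled off as a batch dimension and handled recursively.

```python
import math
from functools import wraps
from numbers import Number
from typing import Callable


def rowwise_sketch(fn: Callable) -> Callable:
    # Hypothetical, simplified imitation of @rowwise for nested Python lists.
    @wraps(fn)
    def wrapper(x):
        # A 1-dimensional input (a list of numbers) goes straight to `fn`.
        if len(x) == 0 or isinstance(x[0], Number):
            return fn(x)
        # Otherwise peel off one leading (batch) dimension and recurse, so a
        # matrix is handled row by row and deeper nestings matrix by matrix.
        return [wrapper(item) for item in x]

    return wrapper


@rowwise_sketch
def rastrigin(x):
    # Written for a single vector, just like the @rowwise version above.
    n = len(x)
    A = 10.0
    return A * n + sum(v * v - A * math.cos(2 * math.pi * v) for v in x)


single = rastrigin([0.0, 0.0])  # one solution -> one number
population = rastrigin([[0.0, 0.0], [1.0, 0.0]])  # 2 solutions -> 2 numbers
batch = rastrigin([[[0.0, 0.0]], [[1.0, 0.0]], [[0.5, 0.5]]])  # 3 populations of 1 solution each
print(single, population, batch)
```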
Note: We don't have to use @rowwise to implement our fitness function. Indeed, since our previous definition of rastrigin() happens to handle any number of batch dimensions and return a fitness value for each vector, we could use it as-is for running a batch of multiple searches. However, writing the fitness function in such a general way can often be difficult and error-prone, so it is much more convenient to use @rowwise.
Batched (vectorized) searches
Using the modified rastrigin function above, we are almost ready to run a batch of searches that is additionally vectorized over the number of searches. We will run 4 searches in a batch:
batch_size = 4
For both the functional cem and the functional pgpe, the hyperparameter stdev_max_change can be given as a scalar (which will then be expanded to a vector), as a vector (which will then be used as it is), or as a batch of vectors (in which case, for each batch item i, the i-th stdev_max_change vector will be used).
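The three accepted forms of stdev_max_change can be summarized by a small normalization helper. The sketch below is a hypothetical, plain-Python `per_search_vectors` function (not an EvoTorch API) that turns a scalar, a vector, or a batch of vectors into one vector per search, following the rule stated above.

```python
from numbers import Number
from typing import List, Sequence, Union

Hyperparameter = Union[float, Sequence[float], Sequence[Sequence[float]]]


def per_search_vectors(value: Hyperparameter, solution_length: int, batch_size: int) -> List[List[float]]:
    # Hypothetical helper: one vector of length `solution_length` per search.
    if isinstance(value, Number):
        # Scalar: expand to a vector, shared by every search.
        return [[float(value)] * solution_length for _ in range(batch_size)]
    if len(value) > 0 and isinstance(value[0], Number):
        # Vector: use it as it is, shared by every search.
        if len(value) != solution_length:
            raise ValueError("vector length must equal solution_length")
        return [[float(v) for v in value] for _ in range(batch_size)]
    # Batch of vectors: the i-th vector belongs to the i-th search.
    if len(value) != batch_size or any(len(row) != solution_length for row in value):
        raise ValueError("expected batch_size vectors of length solution_length")
    return [[float(v) for v in row] for row in value]


print(per_search_vectors(0.1, solution_length=3, batch_size=2))
print(per_search_vectors([[0.01] * 3, [0.2] * 3], solution_length=3, batch_size=2))
```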
Since we consider a batch of populations in this example, let us make a batch of starting points and a batch of stdev_max_change vectors, so that each population has its own starting point and its own stdev_max_change hyperparameter.
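# One starting point per search, sampled uniformly in [-5.12, 5.12]
center_inits = ((torch.rand(batch_size, solution_length) * 2) - 1) * 5.12
# Evenly spaced stdev_max_change values between 0.01 and 0.2 (one per search),
# each repeated across all solution dimensions
stdev_max_changes = torch.linspace(0.01, 0.2, batch_size)[:, None].expand(-1, solution_length)
print(center_inits.shape, stdev_max_changes.shape)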
Next, we simply provide these to the CEM state initializer and execute CEM using the ask-and-tell interface exactly as before. Internally, EvoTorch will recognize the new batch dimension and appropriately vmap() the fitness function for us!
cem_state = cem(
    objective_sense="min",
    # The batch of vectors `center_inits` is given as our `center_init`,
    # that is, the center points of the initial search distributions.
    center_init=center_inits.to(device),
    # The standard deviation of the initial search distributions.
    stdev_init=10.0,
    # We provide our batch of hyperparameter vectors as `stdev_max_change`.
    stdev_max_change=stdev_max_changes.to(device),
    parenthood_ratio=0.5,
)
start_time = last_report_time = datetime.now()
for generation in range(1, 1 + num_generations):
    # Get a population from the evolutionary algorithm
    population = cem_ask(cem_state, popsize=500)
    # Compute the fitnesses
    fitnesses = rastrigin(population)
    # Inform the evolutionary algorithm of the fitnesses and get its next state
    cem_state = cem_tell(cem_state, population, fitnesses)
    # If it is time to report, print the status
    tnow = datetime.now()
    if ((tnow - last_report_time).total_seconds() > report_interval) or (generation == num_generations):
        print("generation:", generation, "mean fitnesses:", torch.mean(fitnesses, dim=-1))
        last_report_time = tnow
print("time taken: ", (last_report_time - start_time).total_seconds(), "secs")
If this notebook is executed on a GPU, the batched search above should take less time than batch_size times the time taken by the single search, particularly for larger values of batch_size. Here are the center points found by CEM:
cem_state.center
Another example
As another example, let us consider the functional pgpe algorithm.
For pgpe, center_learning_rate is a hyperparameter that is expected to be a scalar in the non-batched case. If it is provided as a vector, then for each batch item i, the i-th value of the center_learning_rate vector will be used.
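Let us build a center_learning_rate vector with one value per search (the range used here is an illustrative choice):
# Illustrative (assumed) range: one center learning rate per search, evenly spaced
center_learning_rates = torch.linspace(0.001, 0.4, batch_size).to(device)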
Now we prepare the first state of our pgpe search:
pgpe_state = pgpe(
    # We want to minimize the evaluation results.
    objective_sense="min",
    # The batch of vectors `center_inits` is given as our `center_init`,
    # that is, the center points of the initial search distributions.
    center_init=center_inits.to(device),
    # Standard deviation for the initial search distributions:
    stdev_init=10.0,
    # We provide our `center_learning_rate` batch here:
    center_learning_rate=center_learning_rates,
    # Learning rate for the standard deviations of the search distributions:
    stdev_learning_rate=0.1,
    # We use the "centered" ranking where the worst solution is ranked -0.5,
    # and the best solution is ranked +0.5:
    ranking_method="centered",
    # We use the ClipUp optimizer.
    optimizer="clipup",
    # Just like how we provide a batch of `center_learning_rate` values,
    # we provide a batch of `max_speed` values for ClipUp:
    optimizer_config={"max_speed": center_learning_rates * 2},
    # Maximum relative change allowed for the standard deviations of the
    # search distributions:
    stdev_max_change=0.2,
)
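The "centered" ranking chosen above replaces raw fitnesses with ranks spread evenly between -0.5 (worst) and +0.5 (best). The sketch below is a hypothetical plain-Python `centered_ranks` helper illustrating that mapping, not EvoTorch's implementation; in particular, it breaks ties simply by the original order of the solutions.

```python
from typing import List


def centered_ranks(fitnesses: List[float], minimize: bool = True) -> List[float]:
    # Hypothetical helper: map fitnesses to ranks spread evenly over [-0.5, +0.5],
    # with the worst solution at -0.5 and the best solution at +0.5.
    n = len(fitnesses)
    if n == 1:
        return [0.0]
    # Sort indices from worst to best. When minimizing, the worst is the largest value.
    worst_to_best = sorted(range(n), key=lambda i: fitnesses[i], reverse=minimize)
    ranks = [0.0] * n
    for position, index in enumerate(worst_to_best):
        ranks[index] = position / (n - 1) - 0.5
    return ranks


print(centered_ranks([3.0, 1.0, 2.0]))  # minimization -> [-0.5, 0.5, 0.0]
```

Because only the ordering of the fitnesses matters, the scale of the Rastrigin values does not affect the size of the update.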
Below is the main loop of the evolutionary search.
# We will run the evolutionary search for this many generations:
num_generations = 1500
start_time = last_report_time = datetime.now()
for generation in range(1, 1 + num_generations):
    # Get a population from the evolutionary algorithm
    population = pgpe_ask(pgpe_state, popsize=500)
    # Compute the fitnesses
    fitnesses = rastrigin(population)
    # Inform the evolutionary algorithm of the fitnesses and get its next state
    pgpe_state = pgpe_tell(pgpe_state, population, fitnesses)
    # If it is time to report, print the status
    tnow = datetime.now()
    if ((tnow - last_report_time).total_seconds() > report_interval) or (generation == num_generations):
        print("generation:", generation, "mean fitnesses:", torch.mean(fitnesses, dim=-1))
        last_report_time = tnow
print("time taken: ", (last_report_time - start_time).total_seconds(), "secs")
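The loop above relies on ClipUp to turn each PGPE gradient estimate into a center update. The sketch below shows the general shape of a ClipUp-style step in plain Python: normalize the gradient, accumulate it into a momentum velocity, then clip the velocity's norm to max_speed. `clipup_step` is a hypothetical helper, the default momentum of 0.9 is an assumption, and the gradient is assumed to already point toward better solutions. Since the notebook passes max_speed = 2 * center_learning_rate, each search's speed limit scales with its own learning rate.

```python
import math
from typing import List, Tuple


def clipup_step(
    center: List[float],
    velocity: List[float],
    gradient: List[float],
    center_learning_rate: float,
    max_speed: float,
    momentum: float = 0.9,  # assumed default, for illustration only
) -> Tuple[List[float], List[float]]:
    # 1. Normalize the gradient so that only its direction matters.
    grad_norm = math.sqrt(sum(g * g for g in gradient))
    direction = [0.0] * len(gradient) if grad_norm == 0.0 else [g / grad_norm for g in gradient]
    # 2. Momentum update of the velocity.
    new_velocity = [momentum * v + center_learning_rate * d for v, d in zip(velocity, direction)]
    # 3. Clip the velocity so that its norm never exceeds max_speed.
    speed = math.sqrt(sum(v * v for v in new_velocity))
    if speed > max_speed:
        new_velocity = [v * (max_speed / speed) for v in new_velocity]
    # 4. Move the center along the (clipped) velocity.
    new_center = [c + v for c, v in zip(center, new_velocity)]
    return new_center, new_velocity


center, velocity = [0.0, 0.0], [0.0, 0.0]
for _ in range(10):
    center, velocity = clipup_step(center, velocity, [3.0, 4.0], center_learning_rate=0.1, max_speed=0.2)
print(center, math.sqrt(sum(v * v for v in velocity)))
```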
Here are the center points found by pgpe:
pgpe_state.optimizer_state.center