Skip to content

Genetic Programming using EvoTorch

In this example, we perform genetic programming where the fitness function is a stack-based expression interpreter. While this notebook does not represent the state-of-the-art in the field of genetic programming, it can be a useful example as it demonstrates the following:

  • How to use the vectorized data structures provided by EvoTorch
  • How to use GeneticAlgorithm to solve discrete optimization problems
  • How to define problem-specific mutation operators
from evotorch import Problem, SolutionBatch
from evotorch.operators import TwoPointCrossOver
from evotorch.algorithms import GeneticAlgorithm
from evotorch.logging import StdOutLogger
from evotorch.tools.structures import CList
from typing import Callable, Iterable, Optional, Union
from collections import namedtuple
import torch
import math
import os

Below are some additional functions that we wish to use in our genetic programming example.

class AdditionalTorchFunctions:
    """
    Division primitives which are protected against division-by-zero.

    These functions are meant to be used as unary/binary operators in
    addition to the ones readily provided by PyTorch.
    """

    @staticmethod
    def _forbid_zero(x: torch.Tensor) -> torch.Tensor:
        """
        Return a copy of x whose near-zero elements are moved away from 0.

        Elements within [0, 1e-4) become 1e-4, and elements within (-1e-4, 0)
        become -1e-4. All the other elements are left as they are.
        The given tensor itself is not modified.
        """
        tolerance = 1e-4
        small_non_negative = (x >= 0) & (x < tolerance)
        small_negative = (x < 0) & (x > -tolerance)
        nudged = torch.where(small_non_negative, torch.full_like(x, tolerance), x)
        return torch.where(small_negative, torch.full_like(x, -tolerance), nudged)

    @classmethod
    def unary_div(cls, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the reciprocal 1/x, protected against division-by-zero.

        Elements of x which are near zero are first moved away from zero,
        and the reciprocal is computed afterwards.
        """
        safe_x = cls._forbid_zero(x)
        return 1 / safe_x

    @classmethod
    def binary_div(cls, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """
        Compute a/b, protected against division-by-zero.

        Elements of b which are near zero are first moved away from zero,
        and the division is performed afterwards.
        """
        safe_b = cls._forbid_zero(b)
        return a / safe_b

Now, we present the definition of an Instruction. An Instruction is a callable object which has access to a read-only input memory and to a runtime stack. Depending on how it was initialized, an Instruction can pull its arguments from the runtime stack or from the input memory. After processing its arguments, the Instruction will push its result onto the runtime stack.

# Return types of the pop helpers of `Instruction`. In both of them,
# `pop_mask` is the boolean mask which was actually used for popping, that is,
# the requested mask restricted to the stacks which had enough elements.
_PopResult = namedtuple("_PopResult", ["tensor_a", "pop_mask"])
_PopPairResult = namedtuple("_PopPairResult", ["tensor_a", "tensor_b", "pop_mask"])


class Instruction:
    """
    A single, batched instruction of the stack-based interpreter.

    An Instruction is defined by exactly one of the following:

    - `function`: a callable which receives `arity` tensors popped from the
      runtime stack, and whose result is pushed back onto the stack;
    - `input_slot`: index of the input column to be pushed onto the stack;
    - `operation`: one of the stack operations "pass", "swap", "duplicate".

    Calling the Instruction with a boolean mask applies it only on the batch
    items selected by that mask.
    """

    def __init__(
        self,
        *,
        inputs: torch.Tensor,
        stack: CList,
        arity: int,
        function: Optional[Callable] = None,
        input_slot: Optional[int] = None,
        operation: Optional[str] = None,
    ):
        """
        Initialize the Instruction.

        Args:
            inputs: The input memory, as a 2-dimensional tensor of shape
                (batch_size, input_size).
            stack: The runtime stack. Its batch shape is expected to be
                1-dimensional, and equal to the batch size of `inputs`.
            arity: Number of arguments to pop from the stack (0, 1, or 2).
                Must be 0 when `input_slot` or `operation` is given.
            function: The function to apply on the popped arguments.
            input_slot: Index of the input column to push onto the stack.
            operation: "pass", "swap", or "duplicate".
        Raises:
            AssertionError: if the batch sizes do not match, if the arity is
                invalid, or if the number of given definitions
                (`function`, `input_slot`, `operation`) is not exactly one.
        """
        [batch_size] = stack.batch_shape
        inputs_batch_size, input_size = inputs.shape
        assert inputs_batch_size == batch_size

        self.input_size = input_size
        self.stack = stack
        self.inputs = inputs
        self.arity = int(arity)

        self.function = None
        self.input_slot = None
        self.operation = None

        # Count how many of the mutually exclusive definitions were given.
        instr_definitions = 0

        if function is not None:
            self.function = function
            instr_definitions += 1

        if input_slot is not None:
            assert self.arity == 0
            self.input_slot = input_slot
            instr_definitions += 1

        if operation is not None:
            assert self.arity == 0
            assert operation in ("pass", "swap", "duplicate")
            self.operation = operation
            instr_definitions += 1

        assert instr_definitions == 1, "Please specify only one of these: `function`, `input_slot`, or `operation`."
        assert self.arity in (0, 1, 2)

    def _pop(self, where: torch.Tensor) -> _PopResult:
        """
        Pop one element from the stacks selected by `where`.

        Stacks which do not have at least one element are excluded from the
        mask. The narrowed mask is returned together with the popped tensor.
        """
        suitable = self.stack.length >= 1
        where = where & suitable
        return _PopResult(tensor_a=self.stack.pop_(where=where), pop_mask=where)

    def _pop_pair(self, where: torch.Tensor) -> _PopPairResult:
        """
        Pop two elements from the stacks selected by `where`.

        Stacks which do not have at least two elements are excluded from the
        mask. In the result, `tensor_b` is the element which was on the top,
        and `tensor_a` is the one which was right below it.
        """
        suitable = self.stack.length >= 2
        where = where & suitable
        b = self.stack.pop_(where=where)
        a = self.stack.pop_(where=where)
        return _PopPairResult(tensor_a=a, tensor_b=b, pop_mask=where)

    def _push(self, x: torch.Tensor, where: torch.Tensor):
        """Push `x` onto the stacks selected by `where`."""
        self.stack.push_(x, where=where)

    def _push_input(self, input_slot: int, where: torch.Tensor):
        """Push the input column `input_slot` onto the stacks selected by `where`."""
        input_values = self.inputs[:, input_slot]
        self.stack.push_(input_values, where=where)

    def __call__(self, where: torch.Tensor):
        """
        Execute this instruction on the batch items selected by `where`.

        For instructions which need arguments, the batch items whose stacks
        do not contain enough elements are left untouched.
        """
        if self.function is not None:
            fn = self.function
            arity = self.arity

            if arity == 0:
                self._push(fn(), where=where)
            elif arity == 1:
                # The mask is narrowed by the pop, so that the result is
                # pushed only where an argument was actually popped.
                a, where = self._pop(where=where)
                self._push(fn(a), where=where)
            elif arity == 2:
                a, b, where = self._pop_pair(where=where)
                self._push(fn(a, b), where=where)
            else:
                assert False

        if self.input_slot is not None:
            self._push_input(self.input_slot, where=where)

        if self.operation is not None:
            if self.operation == "pass":
                pass
            elif self.operation == "swap":
                a, b, where = self._pop_pair(where=where)
                self._push(b, where=where)
                self._push(a, where=where)
            elif self.operation == "duplicate":
                a, where = self._pop(where=where)
                self._push(a, where=where)
                self._push(a, where=where)
            else:
                # `self.operation` (not a bare `operation`, which is not
                # defined in this scope) must be used in the message.
                assert False, f"unknown operation: {self.operation}"

    def __repr__(self) -> str:
        """Return a short description, e.g. `Instruction(function=sin, arity=1)`."""
        definition = ""

        if self.function is not None:
            if hasattr(self.function, "__name__"):
                fn_name = self.function.__name__
            else:
                fn_name = repr(self.function)
            definition += f"function={fn_name}"

        if self.input_slot is not None:
            definition += f"input_slot={self.input_slot}"

        if self.operation is not None:
            definition += f"operation={self.operation!r}"

        return f"{type(self).__name__}({definition}, arity={self.arity})"

Now we define a stack-based Interpreter. This Interpreter supports batching, and can work in a vectorized manner. According to the batching scheme of this Interpreter, each program in the batch has its own input, and produces its own output.

class Interpreter:
    """
    A batched, stack-based interpreter.

    Each program of the batch has its own runtime stack and its own input
    row. A program is a sequence of integer codes, each code being an index
    into the instruction list, which is laid out in this order:
    "pass", "swap", "duplicate", one push-instruction per input slot,
    the unary operators, and then the binary operators.
    """

    def __init__(
        self,
        *,
        max_stack_length: int,
        batch_size: int,
        input_size: int,
        unary_ops: Iterable,
        binary_ops: Iterable,
        pass_means_terminate: bool = True,
        device: Optional[Union[str, torch.device]] = None,
    ):
        """
        Initialize the Interpreter.

        Args:
            max_stack_length: Maximum length of each runtime stack.
            batch_size: Number of programs to be executed in parallel.
            input_size: Number of input slots available to each program.
            unary_ops: Functions to be registered as 1-argument instructions.
            binary_ops: Functions to be registered as 2-argument instructions.
            pass_means_terminate: If True, a program stops at the first
                "pass" instruction (code 0) it encounters.
            device: The device of the stack and of the input memory.
                If omitted, the cpu is used.
        """
        device = torch.device("cpu") if device is None else torch.device(device)

        self._batch_size = int(batch_size)
        self._input_size = int(input_size)
        self._max_stack_length = int(max_stack_length)

        self._stack = CList(
            max_length=self._max_stack_length,
            batch_size=self._batch_size,
            dtype=torch.float32,
            device=device,
            verify=False,
        )
        self._inputs = torch.zeros(self._batch_size, self._input_size, dtype=torch.float32, device=device)
        self._pass_means_terminate = bool(pass_means_terminate)

        # The position of a spec within this list is its instruction code.
        specs = [dict(arity=0, operation=name) for name in ("pass", "swap", "duplicate")]
        specs += [dict(arity=0, input_slot=slot) for slot in range(self._input_size)]
        specs += [dict(arity=1, function=op) for op in unary_ops]
        specs += [dict(arity=2, function=op) for op in binary_ops]

        # All instructions share the same input memory and the same stack.
        self._instructions = [Instruction(inputs=self._inputs, stack=self._stack, **spec) for spec in specs]

    @property
    def instructions(self) -> list:
        """The instruction list, indexed by instruction code."""
        return self._instructions

    @property
    def stack(self) -> CList:
        """The runtime stack shared by all the instructions."""
        return self._stack

    def run(self, program_batch: torch.Tensor, input_batch: torch.Tensor) -> torch.Tensor:
        """
        Execute a batch of programs, each on its own input.

        Args:
            program_batch: Instruction codes, of shape
                (batch_size, program_length).
            input_batch: Values to be copied into the input memory, whose
                shape is (batch_size, input_size).
        Returns:
            The top element of each stack after the execution, or 0.0 for
            the stacks which ended up empty.
        """
        stack = self._stack
        stack.clear()

        codes = torch.as_tensor(program_batch, dtype=torch.int64, device=stack.device)
        num_programs, num_steps = codes.shape
        assert num_programs == self._batch_size

        stop_at_pass = self._pass_means_terminate
        still_running = torch.ones(num_programs, dtype=torch.bool, device=stack.device) if stop_at_pass else None

        self._inputs[:] = input_batch

        # Code 0 is "pass", which never touches the stack; so, it is skipped.
        executable = list(enumerate(self._instructions))[1:]

        for step in range(num_steps):
            step_codes = codes[:, step]

            if stop_at_pass:
                # Once a program meets a "pass", it stays terminated.
                still_running = still_running & (step_codes != 0)

            for code, instruction in executable:
                selected = step_codes == code
                if stop_at_pass:
                    selected = selected & still_running
                instruction(where=selected)

        return stack.get(-1, default=0.0)

Above, we have defined a batched interpreter where each program works on its own input and produces its own output. However, when doing symbolic regression, the most common scheme is to have a fixed batch of inputs that is to be used by each program within the program batch. To be compatible with this scheme, we now define an InterpreterWithInputBatch, which, upon receiving a batch of inputs and a separate batch of programs, arranges them in this manner:

input0 -> program0 -> output0,0
input1 -> program0 -> output1,0
input2 -> program0 -> output2,0
  :         :           :
inputN -> program0 -> outputN,0
input0 -> program1 -> output0,1
input1 -> program1 -> output1,1
input2 -> program1 -> output2,1
  :         :           :
inputN -> program1 -> outputN,1
  :         :           :
  :         :           :
input0 -> programM -> output0,M
input1 -> programM -> output1,M
input2 -> programM -> output2,M
  :         :           :
inputN -> programM -> outputN,M

After creating this flattened batch, InterpreterWithInputBatch passes it to its own internal Interpreter. InterpreterWithInputBatch also defines a method for computing the mean squared error which compares the programs' outputs against the desired outputs.

class InterpreterWithInputBatch:
    """
    An interpreter which runs every program on every row of an input batch.

    The programs and the inputs are arranged into a flattened batch of size
    `program_batch_size * input_batch_size`, which is then executed by an
    internal `Interpreter`.
    """

    def __init__(
        self,
        *,
        max_stack_length: int,
        program_batch_size: int,
        input_size: int,
        input_batch_size: int,
        unary_ops: Iterable,
        binary_ops: Iterable,
        pass_means_terminate: bool = True,
        device: Optional[Union[str, torch.device]] = None,
    ):
        """
        Initialize the InterpreterWithInputBatch.

        Args:
            max_stack_length: Maximum length of each runtime stack.
            program_batch_size: Number of programs within a program batch.
            input_size: Number of input slots available to each program.
            input_batch_size: Number of rows within the input batch.
            unary_ops: Functions to be registered as 1-argument instructions.
            binary_ops: Functions to be registered as 2-argument instructions.
            pass_means_terminate: Passed on to the internal `Interpreter`.
            device: Passed on to the internal `Interpreter`.
        """
        self._program_batch_size = int(program_batch_size)
        self._input_batch_size = int(input_batch_size)
        self._input_size = int(input_size)

        # One execution slot for each (program, input row) pair.
        self._batch_size = self._program_batch_size * self._input_batch_size

        self._interpreter = Interpreter(
            max_stack_length=max_stack_length,
            batch_size=self._batch_size,
            input_size=self._input_size,
            unary_ops=unary_ops,
            binary_ops=binary_ops,
            pass_means_terminate=pass_means_terminate,
            device=device,
        )

    def run(self, program_batch: torch.Tensor, input_batch: torch.Tensor) -> torch.Tensor:
        """
        Run each program on each row of the input batch.

        Returns:
            A tensor of shape (program_batch_size, input_batch_size), where
            the i-th row stores the outputs of the i-th program.
        """
        num_programs = self._program_batch_size
        num_inputs = self._input_batch_size

        # Each program is repeated consecutively, once per input row...
        flat_programs = torch.repeat_interleave(program_batch, num_inputs, dim=0)

        # ...and the entire input batch is tiled once per program.
        tiled_inputs = input_batch.expand(num_programs, num_inputs, self._input_size)
        flat_inputs = tiled_inputs.reshape(self._batch_size, self._input_size)

        flat_outputs = self._interpreter.run(flat_programs, flat_inputs)
        return flat_outputs.reshape(num_programs, num_inputs)

    def compute_mean_squared_error(
        self,
        program_batch: torch.Tensor,
        input_batch: torch.Tensor,
        desired_output_batch: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the mean squared error of each program.

        The outputs of each program are compared against
        `desired_output_batch`, and the squared differences are averaged
        along the last dimension (i.e. over the input rows).
        """
        error = self.run(program_batch, input_batch) - desired_output_batch
        return torch.pow(error, 2).mean(dim=-1)

    @property
    def stack(self) -> CList:
        """The runtime stack of the internal interpreter."""
        return self._interpreter.stack

    @property
    def instructions(self) -> list:
        """The instruction list of the internal interpreter."""
        return self._interpreter.instructions

    @property
    def program_batch_size(self) -> int:
        """Number of programs this interpreter executes at once."""
        return self._program_batch_size

Now that we have our InterpreterWithInputBatch, we can define a Problem class where the goal is to minimize this mean squared error.

class ProgramSynthesisProblem(Problem):
    """
    Problem of finding a program which reproduces the desired outputs.

    A solution is a fixed-length sequence of integer instruction codes.
    Its fitness is the mean squared error between the outputs it produces on
    the given inputs and the desired outputs. This error is to be minimized.
    """

    def __init__(
        self,
        unary_ops: Iterable,
        binary_ops: Iterable,
        inputs: Iterable,
        outputs: Iterable,
        program_length: int,
        pass_means_terminate: bool = True,
        device: Optional[Union[str, torch.device]] = None,
        num_actors: Optional[Union[str, int]] = None,
    ):
        """
        Initialize the ProgramSynthesisProblem.

        Args:
            unary_ops: Functions to be used as 1-argument instructions.
            binary_ops: Functions to be used as 2-argument instructions.
            inputs: The inputs, convertible to a 2-dimensional tensor of
                shape (input_batch_size, input_size).
            outputs: The desired outputs, convertible to a 1-dimensional
                tensor of length input_batch_size.
            program_length: Length of each program. This is also used as the
                maximum length of the interpreter's stack.
            pass_means_terminate: If True, a program stops at the first
                "pass" instruction it encounters.
            device: The device on which the programs are to be evaluated.
                If omitted, the cpu is used.
            num_actors: Passed on to the initializer of `Problem`.
        Raises:
            AssertionError: if the numbers of inputs and outputs differ.
        """
        if device is None:
            device = torch.device("cpu")
        else:
            device = torch.device(device)

        self._program_length = int(program_length)
        self._inputs = torch.as_tensor(inputs, dtype=torch.float32, device=device)
        self._outputs = torch.as_tensor(outputs, dtype=torch.float32, device=device)

        self._input_batch_size, self._input_size = self._inputs.shape
        [output_batch_size] = self._outputs.shape
        assert output_batch_size == self._input_batch_size

        self._unary_ops = list(unary_ops)
        self._binary_ops = list(binary_ops)
        self._pass_means_terminate = pass_means_terminate

        # A throw-away interpreter is enough for counting the instructions,
        # which in turn determines the upper bound of the decision values.
        self._interpreter: Optional[InterpreterWithInputBatch] = None
        num_instructions = len(self._get_interpreter(1).instructions)

        super().__init__(
            objective_sense="min",
            solution_length=self._program_length,
            dtype=torch.int64,
            bounds=(0, num_instructions - 1),
            device=device,
            # Previously, this argument was accepted but silently dropped.
            num_actors=num_actors,
            store_solution_stats=True,
        )

    def _get_interpreter(self, num_programs: int) -> InterpreterWithInputBatch:
        """
        Get an interpreter which can execute at least `num_programs` programs.

        The cached interpreter is re-used whenever it is large enough;
        otherwise, it is replaced by a new one of the requested size.
        """
        if (self._interpreter is None) or (num_programs > self._interpreter.program_batch_size):
            self._interpreter = InterpreterWithInputBatch(
                max_stack_length=self._program_length,
                program_batch_size=num_programs,
                input_size=self._input_size,
                input_batch_size=self._input_batch_size,
                unary_ops=self._unary_ops,
                binary_ops=self._binary_ops,
                pass_means_terminate=self._pass_means_terminate,
                device=self._inputs.device,
            )
        return self._interpreter

    def _evaluate_batch(self, batch: SolutionBatch):
        """Evaluate all the programs of the batch by their mean squared errors."""
        num_programs = len(batch)
        interpreter = self._get_interpreter(num_programs)

        if num_programs < interpreter.program_batch_size:
            # The cached interpreter is larger than this batch: fill the
            # remaining rows with zeros (0 is the code of "pass"). The errors
            # of these filler rows are discarded below.
            programs = torch.zeros(
                (interpreter.program_batch_size, self.solution_length),
                dtype=torch.int64,
                device=interpreter.stack.device
            )
            programs[:num_programs, :] = batch.values
        else:
            programs = batch.values

        errors = interpreter.compute_mean_squared_error(programs, self._inputs, self._outputs)
        batch.set_evals(errors[:num_programs])

    @property
    def instructions(self) -> list:
        """The instruction list, indexed by instruction code."""
        interpreter = self._get_interpreter(1)
        return interpreter.instructions

    @property
    def instruction_dict(self) -> dict:
        """A dictionary which maps each instruction code to its instruction."""
        return dict(enumerate(self.instructions))

We now define a target function (the function whose definition will be searched for by our evolutionary algorithm). In the case of our example, we are searching for this function:

\[ \frac{x + y}{\cos(x)} + \sin(y) \]
def target_function(inputs: torch.Tensor) -> torch.Tensor:
    """
    Compute (x + y) / cos(x) + sin(y) for each row of the given inputs.

    The columns 0 and 1 of `inputs` are taken as x and y, respectively.
    The division is the protected one, so that a near-zero cos(x) does not
    cause a division by zero.
    """
    x, y = inputs[:, 0], inputs[:, 1]
    ratio = AdditionalTorchFunctions.binary_div(x + y, torch.cos(x))
    return ratio + torch.sin(y)

Below, we produce a deterministic input set, and then, using the target function, we obtain our target outputs.

# Build the input set as the full grid of (x, y) pairs over `input_values`:
# 6 x 6 = 36 rows, each having 2 columns.
inputs = []
input_values = [-5, -3, -1, 1, 3, 5]
for x in input_values:
    for y in input_values:
        inputs.append([x, y])
inputs = torch.as_tensor(inputs, dtype=torch.float32)
# The desired outputs are the values of the target function on these inputs.
outputs = target_function(inputs)

# Notebook cell outputs: show the tensors themselves, and then their shapes.
inputs, outputs
inputs.shape, outputs.shape
device = "cpu"  # change this to e.g. "cuda:0" for exploiting the GPU

# The problem of finding a program of 20 instructions which maps `inputs` to
# `outputs`, using the operators listed below as its building blocks.
problem = ProgramSynthesisProblem(
    inputs=inputs,
    outputs=outputs,
    unary_ops=[torch.neg, torch.sin, torch.cos, AdditionalTorchFunctions.unary_div],
    binary_ops=[torch.add, torch.sub, torch.mul, AdditionalTorchFunctions.binary_div],
    program_length=20,
    device=device,
)

# Notebook cell output: show the problem object.
problem

Below is a simple mutation function which changes each symbol with a probability of 10%

def mutate_programs(programs: torch.Tensor) -> torch.Tensor:
    """
    Return a mutated copy of the given programs.

    Each symbol is selected for mutation with a probability of 10%.
    A selected symbol is replaced by an instruction code sampled uniformly
    among all the instructions of `problem` (which means that the new code
    can happen to be the same as the old one).
    The given tensor itself is not modified.
    """
    num_symbols = len(problem.instructions)

    selected = torch.rand(programs.shape, device=programs.device) < 0.1
    replacements = torch.randint(
        0,
        num_symbols,
        (int(torch.count_nonzero(selected)),),
        device=programs.device,
    )

    mutated_programs = programs.clone()
    mutated_programs[selected] = replacements
    return mutated_programs

Now we instantiate our genetic algorithm.

# Genetic algorithm with a population of 5000 programs. Its operators are a
# two-point cross-over (with tournament size 4), followed by the custom
# mutation function `mutate_programs`.
ga = GeneticAlgorithm(
    problem,
    operators=[TwoPointCrossOver(problem, tournament_size=4), mutate_programs],
    re_evaluate=False,
    popsize=5000,
)
# Notebook cell output: show the algorithm object.
ga
# Attach a logger to the algorithm (presumably reporting its status to the
# standard output at each step; see the EvoTorch logging documentation).
StdOutLogger(ga)
# Run the search for 50 generations.
ga.run(50)

Below is the best solution encountered so far, hopefully with its evaluation result expressing a near-zero error value.

# Take the best solution from the status dictionary of the algorithm.
best_solution = ga.status["best"]
# Notebook cell output: show this solution.
best_solution

The program reported above can be analyzed with the help of this instruction set:

# Notebook cell output: the mapping from each integer instruction code to its
# instruction, for decoding the program shown above.
problem.instruction_dict

See this notebook on GitHub