evotorch.decorators

Module defining decorators for EvoTorch.

pass_info(fn_or_class)

Decorates a callable so that a neuroevolution problem class (e.g. GymNE) passes it information about the task at hand, in the form of keyword arguments.

For example, with GymNE or VecGymNE, the passed information includes the dimensions of the observation and action spaces.

Examples:

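import torch
from torch import nn
from evotorch.decorators import pass_info
from evotorch.neuroevolution import GymNE


@pass_info
class MyModule(nn.Module):
    def __init__(self, obs_length: int, act_length: int, **kwargs):
        # Because MyModule is decorated with @pass_info, it receives
        # keyword arguments related to the environment "CartPole-v0",
        # including obs_length and act_length.
        super().__init__()
        # A single linear layer, for illustration only.
        self.linear = nn.Linear(obs_length, act_length)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


problem = GymNE(
    "CartPole-v0",
    network=MyModule,
    # further GymNE keyword arguments go here, as needed
)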
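This page does not show how a problem class consumes the decorator's mark. The following is a minimal, self-contained sketch under the assumption that the problem checks for the `__evotorch_pass_info__` attribute (which `pass_info` sets, per its source code on this page) and, when it is present, forwards the task information as keyword arguments. `instantiate_network`, `make_policy`, `make_fixed_policy`, and the contents of `task_info` are hypothetical; this is not EvoTorch's actual implementation, and it avoids importing evotorch or torch so that it runs on its own.

```python
from typing import Any, Callable, Dict


def pass_info(fn_or_class: Callable) -> Callable:
    # Same body as evotorch.decorators.pass_info: mark the object, return it unchanged.
    setattr(fn_or_class, "__evotorch_pass_info__", True)
    return fn_or_class


def instantiate_network(network: Callable, task_info: Dict[str, Any]) -> Any:
    # Hypothetical problem-side logic (an assumption, not EvoTorch's code):
    # only callables marked by @pass_info receive the task information.
    if getattr(network, "__evotorch_pass_info__", False):
        return network(**task_info)
    return network()


@pass_info
def make_policy(obs_length: int, act_length: int, **kwargs) -> Dict[str, int]:
    # Stand-in for a network factory; a real one would build e.g. a torch.nn.Module.
    # **kwargs absorbs any task information this factory does not use.
    return {"in": obs_length, "out": act_length}


def make_fixed_policy() -> Dict[str, int]:
    # Undecorated: called without arguments, so its sizes are hard-coded.
    return {"in": 4, "out": 2}


# Hypothetical task information; "extra_key" stands for anything else a problem might pass.
task_info = {"obs_length": 3, "act_length": 1, "extra_key": "unused"}

print(instantiate_network(make_policy, task_info))        # {'in': 3, 'out': 1}
print(instantiate_network(make_fixed_policy, task_info))  # {'in': 4, 'out': 2}
```

The `**kwargs` parameter is what keeps this robust: if a problem passes more keys than the callable names explicitly, the extra keys land in `kwargs` instead of raising a `TypeError`.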

Parameters:

Name         Type      Description                    Default
fn_or_class  Callable  Function or class to decorate  required

Returns:

Type      Description
Callable  Decorated function or class

Source code in evotorch/decorators.py
def pass_info(fn_or_class: Callable) -> Callable:
    """
    Decorates a callable so that the neuroevolution problem class (e.g. GymNE) will
    pass information regarding the task at hand, in the form of keyword arguments.

    For example, in the case of [GymNE][evotorch.neuroevolution.GymNE] or
    [VecGymNE][evotorch.neuroevolution.VecGymNE], the passed information would
    include dimensions of the observation and action spaces.

    Example:
        ```python
        @pass_info
        class MyModule(nn.Module):
            def __init__(self, obs_length: int, act_length: int, **kwargs):
                # Because MyModule is decorated with @pass_info, it receives
                # keyword arguments related to the environment "CartPole-v0",
                # including obs_length and act_length.
                ...


        problem = GymNE(
            "CartPole-v0",
            network=MyModule,
            ...,
        )
        ```

    Args:
        fn_or_class (Callable): Function or class to decorate

    Returns:
        Callable: Decorated function or class
    """
    setattr(fn_or_class, "__evotorch_pass_info__", True)
    return fn_or_class
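Because `pass_info` only sets an attribute and returns its argument, no wrapper is created: the decorated object keeps its identity and signature. One consequence, shown below with Python's ordinary attribute lookup, is that subclasses of a decorated class also carry the marker. Whether EvoTorch's problem classes honor an inherited marker depends on how they look it up, which this page does not show; `Base`, `Child`, and `Unrelated` are hypothetical classes for illustration.

```python
from typing import Callable


def pass_info(fn_or_class: Callable) -> Callable:
    # Copy of the two-line body shown in the source listing above.
    setattr(fn_or_class, "__evotorch_pass_info__", True)
    return fn_or_class


# Hypothetical classes for illustration.
class Base:
    def __init__(self, obs_length: int, act_length: int, **kwargs):
        self.obs_length = obs_length
        self.act_length = act_length


Decorated = pass_info(Base)

# No wrapper is created: the decorator returns the very same class object.
print(Decorated is Base)  # True


class Child(Base):
    # Not decorated itself, but getattr finds the class attribute set on Base.
    pass


class Unrelated:
    pass


print(getattr(Child, "__evotorch_pass_info__", False))      # True
print(getattr(Unrelated, "__evotorch_pass_info__", False))  # False
```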