Skip to content

statefulmodule

StatefulModule (Module)

A wrapper that provides a stateful interface for recurrent torch modules.

If the torch module to be wrapped is non-recurrent and its forward method has a single input (the input tensor) and a single output (the output tensor), then this wrapper module acts as a no-op wrapper.

If the torch module to be wrapped is recurrent and its forward method has two inputs (the input tensor and an optional second argument for the hidden state) and two outputs (the output tensor and the new hidden state), then this wrapper brings a new forward-passing interface. In this new interface, the forward method has a single input (the input tensor) and a single output (the output tensor). The hidden states, instead of being explicitly requested via a second argument and returned as a second result, are stored and used by the wrapper. When a new series of inputs is to be used, one has to call the reset() method of this wrapper.

Source code in evotorch/neuroevolution/net/statefulmodule.py
class StatefulModule(nn.Module):
    """
    A wrapper that provides a stateful interface for recurrent torch modules.

    If the torch module to be wrapped is non-recurrent and its forward method
    has a single input (the input tensor) and a single output (the output
    tensor), then this wrapper module acts as a no-op wrapper.

    If the torch module to be wrapped is recurrent and its forward method has
    two inputs (the input tensor and an optional second argument for the hidden
    state) and two outputs (the output tensor and the new hidden state), then
    this wrapper brings a new forward-passing interface. In this new interface,
    the forward method has a single input (the input tensor) and a single
    output (the output tensor). The hidden states, instead of being
    explicitly requested via a second argument and returned as a second
    result, are stored and used by the wrapper.
    When a new series of inputs is to be used, one has to call the `reset()`
    method of this wrapper.
    """

    def __init__(self, wrapped_module: nn.Module):
        """
        `__init__(...)`: Initialize the StatefulModule.

        Args:
            wrapped_module: The `torch.nn.Module` instance to wrap.
        """
        super().__init__()

        # Declare the variable that will store the hidden state of wrapped_module, if any.
        self._hidden: Any = None

        # Store the module that is wrapped.
        self.wrapped_module = wrapped_module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._hidden is None:
            # If there is no stored hidden state, then only pass the input tensor to the wrapped module.
            out = self.wrapped_module(x)
        else:
            # If there is a hidden state saved from the previous call to this `forward(...)` method, then pass the
            # input tensor and this stored hidden state.
            out = self.wrapped_module(x, self._hidden)

        if isinstance(out, tuple):
            # If the result of the wrapped module is a tuple, then we assume that the wrapped module returned an
            # output tensor and a hidden state. We assume the first element of this tuple as the output tensor,
            # and the second element as the new hidden state.
            # We set the variable y to the output tensor, and we store the new hidden state via the attribute
            # `_hidden`.
            y, self._hidden = out
        else:
            # If the result of the wrapped module is not a tuple, then we assume that the wrapped module returned
            # only the output tensor. We set the variable y to the output tensor, and set the attribute `_hidden`
            # as None to indicate that there was no hidden state received.
            y = out
            self._hidden = None

        # We return y, which stores the output received by the wrapped module.
        return y

    def reset(self):
        """
        Reset the hidden state, if any.
        """
        self._hidden = None

__init__(self, wrapped_module) special

__init__(...): Initialize the StatefulModule.

Parameters:

Name Type Description Default
wrapped_module Module

The torch.nn.Module instance to wrap.

required
Source code in evotorch/neuroevolution/net/statefulmodule.py
def __init__(self, wrapped_module: nn.Module):
    """
    `__init__(...)`: Initialize the StatefulModule.

    Args:
        wrapped_module: The `torch.nn.Module` instance to wrap.
    """
    super().__init__()

    # Declare the variable that will store the hidden state of wrapped_module, if any.
    self._hidden: Any = None

    # Store the module that is wrapped.
    self.wrapped_module = wrapped_module

forward(self, x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in evotorch/neuroevolution/net/statefulmodule.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    if self._hidden is None:
        # If there is no stored hidden state, then only pass the input tensor to the wrapped module.
        out = self.wrapped_module(x)
    else:
        # If there is a hidden state saved from the previous call to this `forward(...)` method, then pass the
        # input tensor and this stored hidden state.
        out = self.wrapped_module(x, self._hidden)

    if isinstance(out, tuple):
        # If the result of the wrapped module is a tuple, then we assume that the wrapped module returned an
        # output tensor and a hidden state. We assume the first element of this tuple as the output tensor,
        # and the second element as the new hidden state.
        # We set the variable y to the output tensor, and we store the new hidden state via the attribute
        # `_hidden`.
        y, self._hidden = out
    else:
        # If the result of the wrapped module is not a tuple, then we assume that the wrapped module returned
        # only the output tensor. We set the variable y to the output tensor, and set the attribute `_hidden`
        # as None to indicate that there was no hidden state received.
        y = out
        self._hidden = None

    # We return y, which stores the output received by the wrapped module.
    return y

reset(self)

Reset the hidden state, if any.

Source code in evotorch/neuroevolution/net/statefulmodule.py
def reset(self):
    """
    Reset the hidden state, if any.
    """
    self._hidden = None

ensure_stateful(net)

Ensure that a module is wrapped by StatefulModule.

If the given module is already wrapped by StatefulModule, then the module itself is returned. If the given module is not wrapped by StatefulModule, then this function first wraps the module via a new StatefulModule instance, and then this new wrapper is returned.

Parameters:

Name Type Description Default
net Module

The torch.nn.Module to be wrapped by StatefulModule (if it is not already wrapped by it).

required

Returns:

Type Description

The module net, wrapped by StatefulModule.

Source code in evotorch/neuroevolution/net/statefulmodule.py
def ensure_stateful(net: nn.Module):
    """
    Ensure that a module is wrapped by StatefulModule.

    If the given module is already wrapped by StatefulModule, then the
    module itself is returned.
    If the given module is not wrapped by StatefulModule, then this function
    first wraps the module via a new StatefulModule instance, and then this
    new wrapper is returned.

    Args:
        net: The `torch.nn.Module` to be wrapped by StatefulModule (if it is
            not already wrapped by it).
    Returns:
        The module `net`, wrapped by StatefulModule.
    """
    if not isinstance(net, StatefulModule):
        return StatefulModule(net)