"""Source code for ``evox.core.module``.

State-management decorators, JIT helpers and the ``Stateful`` base class.
"""

import dataclasses
import inspect
import warnings
from collections import namedtuple
from functools import partial, wraps
from typing import Annotated, Any, Callable, Optional, Tuple, TypeVar, get_type_hints

import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node, tree_leaves, tree_map


from .state import State


def use_state(func: Callable, index: Optional[int] = None):
    """Decorator for easy state management.

    This decorator extracts the sub-state that belongs to the module from the
    current state, passes it to ``func``, and merges the returned sub-state
    back into the current state.

    Parameters
    ----------
    func
        The method to be wrapped. If it is a bound method (it has ``__self__``),
        the module is taken from ``func.__self__`` and the wrapper is called as
        ``wrapped(state, *args, **kwargs)``; otherwise the wrapper is called as
        ``wrapped(module, state, *args, **kwargs)``.
    index
        The index of a batch state to use.
        Typically used to handle batched states created from `State.batch`.
        That slice of the sub-state (and, for unbound functions, of the module)
        is passed to ``func``, and the result is written back into that slice.

    Returns
    -------
    Callable
        The wrapped function. It returns the updated full state, or
        ``(*aux, state)`` when ``func`` returns a tuple ``(*aux, sub_state)``.

    Raises
    ------
    ValueError
        If the module has not been initialized with ``init``.
    """

    err_msg = "Expect last return value must be State, got {}"

    def wrapper(self, state: State, *args, **kwargs):
        assert isinstance(
            state, State
        ), f"The first argument must be `State`, got {type(state)}"
        # ``Stateful.__init__`` sets ``_node_id`` to None, so a plain ``hasattr``
        # check cannot detect a module whose ``init`` was never called.
        # ``_module_name`` is legitimately None for the root module, so only its
        # presence is checked.
        if getattr(self, "_node_id", None) is None or not hasattr(
            self, "_module_name"
        ):
            raise ValueError(
                f"{self} is not initialized, did you forget to call `init`?"
            )

        # find the state that matches the current module
        path, matched_state = state.find_path_to(self._node_id, self._module_name)

        if index is not None:
            # select one slice of the batched sub-state and of the module
            extracted_state = tree_map(lambda x: x[index], matched_state)
            this_module = tree_map(lambda x: x[index], self)
        else:
            extracted_state = matched_state
            this_module = self

        if hasattr(func, "__self__"):
            # bound method, don't pass self
            return_value = func(extracted_state, *args, **kwargs)
        else:
            # unbound method, pass the (possibly sliced) module as self
            return_value = func(this_module, extracted_state, *args, **kwargs)

        # single return value, the value must be a State
        if not isinstance(return_value, tuple):
            assert isinstance(return_value, State), err_msg.format(type(return_value))
            aux, new_state = None, return_value
        else:
            # the last element is the new state, everything before it is auxiliary
            assert isinstance(return_value[-1], State), err_msg.format(
                type(return_value[-1])
            )
            aux, new_state = return_value[:-1], return_value[-1]

        # if index is specified, write the updated slice back into the batch
        if index is not None:
            new_state = tree_map(
                lambda batch_arr, new_arr: batch_arr.at[index].set(new_arr),
                matched_state,
                new_state,
            )

        # Merge the sub-state back at its original path. Callbacks are cleared on
        # the stored copy; presumably ``prepend_closure`` carries them over to
        # the full state instead — TODO confirm against ``State``.
        state = state.replace_by_path(
            path, new_state.clear_callbacks()
        ).prepend_closure(new_state)

        if aux is None:
            return state
        else:
            return (*aux, state)

    if hasattr(func, "__self__"):
        return wraps(func)(partial(wrapper, func.__self__))
    else:
        return wraps(func)(wrapper)


def jit_method(method: Callable):
    """Compile a method with ``jax.jit``, treating ``self`` as a static argument.

    Because ``self`` (the first positional argument) is static, it must be
    hashable, as JAX requires for static arguments; only the remaining
    arguments are traced.

    Parameters
    ----------
    method
        A python method (a function whose first parameter is ``self``).

    Returns
    -------
    function
        The jit-compiled version of ``method``.
    """
    self_position = (0,)
    return jax.jit(method, static_argnums=self_position)


def default_jit_func(name: str):
    """Decide whether the attribute called ``name`` should be jitted.

    ``__call__`` is always selected; any other name is selected only when it
    does not start with an underscore.
    """
    is_private = name.startswith("_")
    return name == "__call__" or not is_private


def jit_class(cls):
    """Wrap the selected methods of a class with ``jax.jit``, in place.

    Every callable attribute whose name passes ``default_jit_func`` (public
    names plus ``__call__``) is replaced on ``cls``. Dataclasses are wrapped
    with plain ``jax.jit``; other classes use ``jit_method``, which marks
    ``self`` as static.

    ``staticmethod`` and ``classmethod`` attributes are left untouched.
    Attribute access resolves them to a plain function or a class-bound
    method, and storing a jitted version back on the class would turn them
    into ordinary instance methods. For ``jit_method`` it would also make
    their first real argument static.

    Parameters
    ----------
    cls
        The class to decorate.

    Returns
    -------
    class
        ``cls`` itself, with the selected methods wrapped.
    """
    for attr_name in dir(cls):
        if not default_jit_func(attr_name):
            continue

        # Inspect the raw descriptor, not the value produced by attribute access.
        raw_attr = inspect.getattr_static(cls, attr_name)
        if isinstance(raw_attr, (staticmethod, classmethod)):
            continue

        func = getattr(cls, attr_name)
        if not callable(func):
            continue

        if dataclasses.is_dataclass(cls):
            wrapped = jax.jit(func)
        else:
            wrapped = jit_method(func)
        setattr(cls, attr_name, wrapped)
    return cls


class Stateful:
    """Base class for all evox modules.

    This class allows easy management of states.

    All the constants (e.g. hyperparameters) are initialized in ``__init__``,
    and mutable states are initialized in the ``setup`` method.

    The ``init`` method automatically calls the ``setup`` of the current module
    and recursively calls the ``setup`` methods of all submodules.

    Currently, two special field metadata keys control the behavior of the
    module initialization:

    - ``stack``: If set to True, the module will be initialized multiple times,
      and the states will be stacked together.
    - ``nested``: If set to True, a list of modules, that is
      ``[module1, module2, ...]``, will be iterated and initialized.
    """

    def __init__(self):
        super().__init__()
        # Filled in by ``init``; they stay None until then.
        # ``object.__setattr__`` bypasses any ``__setattr__`` override
        # (e.g. on frozen dataclass subclasses).
        object.__setattr__(self, "_node_id", None)
        object.__setattr__(self, "_module_name", None)

    def setup(self, key: jax.Array) -> State:
        """Set up the mutable state of this module.

        The state itself is immutable, but it acts as a mutable state because a
        new state is returned each time. Override this method in subclasses; the
        default returns an empty ``State``. Returning a dataclass is also
        accepted; ``init`` converts it with ``State.from_dataclass``.

        Parameters
        ----------
        key
            A PRNGKey.

        Returns
        -------
        State
            The state of this module.
        """
        return State()

    def _recursive_init(
        self,
        key: jax.Array,
        node_id: int,
        module_name: str,
        no_state: bool,
        re_init: bool,
    ) -> Tuple[State, int]:
        """Assign node ids to this module and its submodules and build their states.

        Submodules are visited depth-first in name-sorted order, and each
        visited module receives the next node id.

        Parameters
        ----------
        key
            A PRNGKey, or None. It is split once per submodule, and the
            remainder is passed to ``self.setup``.
        node_id
            The id assigned to this module.
        module_name
            The attribute name of this module in its parent (None for the root).
        no_state
            If True, only assign ids and skip every ``setup`` call.
        re_init
            If True, keep the previously assigned ``_node_id`` / ``_module_name``.

        Returns
        -------
        Tuple[State, int]
            The combined state (None when ``no_state``) and the last node id
            used in this subtree.
        """
        if not re_init:
            object.__setattr__(self, "_node_id", node_id)
            object.__setattr__(self, "_module_name", module_name)

        child_states = {}

        # One entry per submodule: its name, the module itself, and the field
        # metadata (``stack`` / ``nested``) that controls its initialization.
        SubmoduleInfo = namedtuple("SubmoduleInfo", ["name", "module", "metadata"])
        submodules = []
        if dataclasses.is_dataclass(self):
            for field in dataclasses.fields(self):
                attr = getattr(self, field.name)
                if isinstance(attr, Stateful):
                    submodules.append(SubmoduleInfo(field.name, attr, field.metadata))

                # handle "nested" field: each element of the sequence is
                # registered under "<field name><index>"
                if field.metadata.get("nested", False):
                    for idx, nested_module in enumerate(attr):
                        submodules.append(
                            SubmoduleInfo(
                                field.name + str(idx), nested_module, field.metadata
                            )
                        )
        else:
            for attr_name in vars(self):
                attr = getattr(self, attr_name)
                if not attr_name.startswith("_") and isinstance(attr, Stateful):
                    submodules.append(SubmoduleInfo(attr_name, attr, {}))

        # Sort so that node ids are deterministic across runs; otherwise
        # save/load would be impossible. Sorting by name only (stable) avoids
        # comparing module objects, which are not orderable, when names tie.
        submodules.sort(key=lambda info: info.name)

        for attr_name, attr, metadata in submodules:
            if key is None:
                subkey = None
            else:
                key, subkey = jax.random.split(key)

            if metadata.get("stack", False):
                # handle "stack": ``attr`` is a single module whose leaves carry
                # a batch axis (e.g. built with ``Stateful.stack``), and
                # ``len(attr)`` is the number of copies.
                num_copies = len(attr)
                # A None key (e.g. ``init(no_state=True)``) is passed through
                # rather than split, matching the non-stacked branch.
                subkeys = (
                    None if subkey is None else jax.random.split(subkey, num_copies)
                )
                current_node_id = node_id
                # Assign ids on the stacked module itself without building a state.
                _, node_id = attr._recursive_init(
                    None, node_id + 1, attr_name, True, re_init
                )
                # Then build one state per copy with the same ids, vmapped over
                # the batch axis of both the module and the keys.
                submodule_state, _ = jax.vmap(
                    partial(
                        Stateful._recursive_init,
                        node_id=current_node_id + 1,
                        module_name=attr_name,
                        no_state=no_state,
                        re_init=re_init,
                    )
                )(attr, subkeys)
            else:
                submodule_state, node_id = attr._recursive_init(
                    subkey, node_id + 1, attr_name, no_state, re_init
                )

            if not no_state:
                assert isinstance(
                    submodule_state, State
                ), "setup method must return a State"
                child_states[attr_name] = submodule_state

        if no_state:
            return None, node_id

        self_state = self.setup(key)
        if dataclasses.is_dataclass(self_state):
            # if the setup method returns a dataclass, convert it to State first
            self_state = State.from_dataclass(self_state)

        self_state._set_state_id_mut(self._node_id)._set_child_states_mut(
            child_states
        )
        return self_state, node_id

    def init(
        self,
        key: Optional[jax.Array] = None,
        no_state: bool = False,
        re_init: bool = False,
    ) -> Optional[State]:
        """Initialize this module and all submodules.

        This method should not be overwritten.

        Parameters
        ----------
        key
            A PRNGKey. May be None (e.g. with ``no_state``); the ``setup``
            methods then receive None.
        no_state
            If True, only assign node ids to the module tree and return None.
        re_init
            If True, keep previously assigned node ids and module names.

        Returns
        -------
        State
            The state of this module and all submodules combined, or None when
            ``no_state`` is True.
        """
        state, _ = self._recursive_init(key, 0, None, no_state, re_init)
        return state

    def parallel_init(
        self, key: jax.Array, num_copies: int, no_state: bool = False
    ) -> Optional[State]:
        """Initialize multiple copies of this module in parallel.

        ``init`` is vmapped over ``num_copies`` keys split from ``key``.
        This method should not be overwritten.

        Parameters
        ----------
        key
            A PRNGKey.
        num_copies
            The number of copies to be initialized.
        no_state
            Whether to skip the state initialization.

        Returns
        -------
        State
            The states of all copies, batched along a new leading axis
            (None when ``no_state`` is True).
        """
        subkeys = jax.random.split(key, num_copies)
        return jax.vmap(self.init, in_axes=(0, None))(subkeys, no_state)

    @classmethod
    def stack(cls, stateful_objs, axis=0):
        """Stack several dataclass modules leaf-by-leaf into a single module.

        Parameters
        ----------
        stateful_objs
            A non-empty sequence of dataclass modules with the same tree
            structure.
        axis
            The axis along which corresponding leaves are stacked.

        Returns
        -------
        A module with the structure of ``stateful_objs[0]`` whose leaves are
        the stacked leaves.
        """
        for obj in stateful_objs:
            assert dataclasses.is_dataclass(obj), "All objects must be dataclasses"

        def stack_arrays(array, *arrays):
            return jnp.stack((array, *arrays), axis=axis)

        return tree_map(stack_arrays, stateful_objs[0], *stateful_objs[1:])

    def __len__(self) -> int:
        """Return the length of the first leaf of this module.

        Usually paired with ``Stateful.stack`` to read the batch size.
        """
        # NOTE(review): defining ``__len__`` also makes ``bool(module)`` call it,
        # so truth-testing a non-dataclass module raises AssertionError.
        assert dataclasses.is_dataclass(self), "Length is only supported for dataclass"
        return len(tree_leaves(self)[0])