# Source code for evox.core.state

import os
from pprint import pformat
from typing import Any, Optional, Tuple, Union, Callable
from typing_extensions import Self
from copy import copy
from pathlib import Path
import pickle
import dataclasses

import orbax.checkpoint as ocp
import warnings
from jax.tree_util import register_pytree_node_class, tree_map
from .distributed import ShardingType


# Type alias for values accepted as a filesystem path: ``str``, ``bytes``
# or any object implementing ``os.PathLike``.
PathLike = Union[str, bytes, os.PathLike]


def is_magic_method(name: str):
    """Tell whether ``name`` both starts and ends with a double underscore."""
    if not name.startswith("__"):
        return False
    return name.endswith("__")


def linkedlist_prepend(lst, item):
    """Return a new linked-list node holding ``item`` in front of ``lst``.

    A node is a 2-tuple ``(head, tail)``; ``lst`` itself is not modified.
    """
    node = item, lst
    return node


def linkedlist_concat(lst1, lst2):
    """Return a linked list holding the items of ``lst1`` followed by ``lst2``.

    ``lst2`` is reused as the tail of the result; the items of ``lst1`` are
    re-linked in front of it in their original order.
    """
    pending = linkedlist_to_list(lst1)
    combined = lst2
    # Take items from the back so that the last item of ``lst1`` ends up
    # directly in front of ``lst2`` and the first item becomes the new head.
    while pending:
        combined = linkedlist_prepend(combined, pending.pop())

    return combined


def linkedlist_to_list(lst):
    """Flatten a linked list of nested ``(head, tail)`` tuples into a python list.

    Walking stops at the first falsy tail (e.g. the empty tuple).
    """
    items = []
    node = lst
    while node:
        head, node = node
        items.append(head)

    return items


@register_pytree_node_class
class State:
    """An immutable, tree-structured container of state values.

    A ``State`` stores its own values (either a ``dict`` built from keyword
    arguments or a dataclass instance), a mapping of named child states, an
    optional state id and a linked list of pending callbacks.

    ``State`` is immutable: attribute and item assignment raise ``TypeError``.
    To obtain a modified state use ``replace`` (own values) or
    ``replace_child`` (child states); both return a new ``State``.

    ``State`` implements ``tree_flatten`` and ``tree_unflatten`` and is
    registered as a pytree node class, so it can be used as a pytree with JAX.
    """

    # Shared default for ``_child_states``. Every fresh State points at this
    # same dict, so it must never be mutated in place.
    EMPTY: dict = {}

    def __init__(self, _dataclass=None, /, **kwargs) -> None:
        """Construct a ``State`` from dataclass instance or keyword arguments

        When the positional argument is given it must be a dataclass and the
        keyword arguments are not used.

        Example::

            >>> from evox import State
            >>> State(x=1, y=2)  # from keyword arguments
            State({'x': 1, 'y': 2}, {})
            >>> from dataclasses import dataclass
            >>> @dataclass
            ... class Param:
            ...     x: int
            ...     y: int
            ...
            >>> param = Param(x=1, y=2)
            >>> State(param)  # from dataclass instance
            State(Param(x=1, y=2), {})
        """
        # Attributes are written through ``__dict__`` because ``__setattr__``
        # is overridden to reject every assignment.
        if _dataclass is not None:
            assert dataclasses.is_dataclass(
                _dataclass
            ), "when using the positional argument, it must be a dataclass"
            self.__dict__["_state_dict"] = _dataclass
        else:
            self.__dict__["_state_dict"] = kwargs

        self.__dict__["_child_states"] = State.EMPTY
        self.__dict__["_state_id"] = None
        # store closures in the state
        # and restore the value separately so that they are compatible with jax's transformation
        # it's stored as a linked list to satisfy the functional programming paradigm
        self.__dict__["_callbacks"] = ()
        self.__dict__["_closure_values"] = ()

    @classmethod
    def from_dataclass(cls, dataclass) -> Self:
        """Construct a ``State`` from dataclass instance

        Example::

            >>> from evox import State
            >>> from dataclasses import dataclass
            >>> @dataclass
            ... class Param:
            ...     x: int
            ...     y: int
            ...
            >>> param = Param(x=1, y=2)
            >>> State.from_dataclass(param)
            State(Param(x=1, y=2), {})
        """
        return cls(dataclass)

    def _set_state_dict_mut(self, state_dict: dict) -> Self:
        """Force set the state dict and return self

        This method mutates the structure itself and is not pure.
        Use with caution.
        """
        self.__dict__["_state_dict"] = state_dict
        return self

    def _set_child_states_mut(self, child_states: dict) -> Self:
        """Force set child states and return self

        This method mutates the structure itself and is not pure.
        Use with caution.
        """
        self.__dict__["_child_states"] = child_states
        return self

    def _set_state_id_mut(self, state_id) -> Self:
        """Force set the state id and return self

        This method mutates the structure itself and is not pure.
        Use with caution.
        """
        self.__dict__["_state_id"] = state_id
        return self

    def _set_closures_mut(self, callbacks, closure_values) -> Self:
        """Force set the callback and closure-value linked lists and return self

        This method mutates the structure itself and is not pure.
        Use with caution.
        """
        self.__dict__["_callbacks"] = callbacks
        self.__dict__["_closure_values"] = closure_values
        return self

    def update(self, **kwargs) -> Self:
        """Deprecated alias of ``replace``; emits a ``DeprecationWarning``."""
        warnings.warn(
            "update() is deprecated, use replace() instead",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller
        )
        return self.replace(**kwargs)

    def replace(self, **kwargs) -> Self:
        """Return a new State whose own values are updated with ``kwargs``.

        Child states, state id and callbacks are carried over unchanged.
        For a dataclass-backed state the values are replaced with
        ``dataclasses.replace``; for a dict-backed state the keyword
        arguments are merged over the existing dict.

        Example::

            >>> from evox import State
            >>> state = State(x=1, y=2)
            >>> state.replace(y=3)
            State({'x': 1, 'y': 3}, {})
            >>> state  # note that State is immutable, so state isn't modified
            State({'x': 1, 'y': 2}, {})
        """
        if dataclasses.is_dataclass(self._state_dict):
            return copy(self)._set_state_dict_mut(
                dataclasses.replace(self._state_dict, **kwargs)
            )
        else:
            return copy(self)._set_state_dict_mut({**self._state_dict, **kwargs})

    def has_child(self, name: str) -> bool:
        """Return whether a direct child state called ``name`` exists."""
        return name in self._child_states

    def get_child_state(self, name: str) -> Self:
        """Return the direct child state called ``name``.

        Raises ``KeyError`` if there is no such child.
        """
        return self._child_states[name]

    def query_state(self, name: str) -> Self:
        """
        Recursively find a sub-state by a query name.
        eg: `'foo.bar'` will find a sub state named foo, then find `bar` under sub-states of `foo`
        """
        child_state = self
        for child_state_name in name.split("."):
            child_state = child_state.get_child_state(child_state_name)

        return child_state

    def update_child(self, name: str, child_state: Self) -> Self:
        """Deprecated alias of ``replace_child``; emits a ``DeprecationWarning``."""
        warnings.warn(
            "update_child() is deprecated, use replace_child() instead",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller
        )
        return self.replace_child(name, child_state)

    def replace_child(self, name: str, child_state: Self) -> Self:
        """Return a new State in which the child ``name`` is ``child_state``.

        The child mapping is rebuilt, so the shared ``EMPTY`` default and the
        original state's mapping are left untouched.
        """
        return copy(self)._set_child_states_mut(
            {**self._child_states, name: child_state}
        )

    def find_path_to(
        self, node_id: int, hint: Optional[str] = None
    ) -> Optional[Tuple[Union[Tuple, int], Self]]:
        """Find the state with node_id matching the state_id

        A hint can be given with the module_name; it is only checked against
        the direct children of this state and is not forwarded to the
        recursive search.

        Returns ``(path, state)`` where ``path`` is either the id itself
        (match on this state) or nested ``(child_name, sub_path)`` tuples,
        or ``None`` when no state with that id is found.
        """
        if node_id == self._state_id:
            return node_id, self

        if hint in self._child_states and node_id == self._child_states[hint]._state_id:
            return (hint, node_id), self._child_states[hint]

        for child_id, child_state in self._child_states.items():
            result = child_state.find_path_to(node_id)
            if result is not None:
                path, state = result
                return (child_id, path), state

        return None

    def replace_by_path(self, path, new_state):
        """Return a copy of this tree with the state at ``path`` replaced.

        ``path`` has the shape produced by ``find_path_to``: an ``int`` means
        this state is the target and ``new_state`` is returned directly; a
        ``(child_name, sub_path)`` tuple descends into that child.

        Raises ``ValueError`` if ``path`` is neither a tuple nor an int.
        """
        if isinstance(path, int):
            assert path == self._state_id
            return new_state
        elif isinstance(path, tuple):
            child_id, path = path
            return self.replace_child(
                child_id,
                self._child_states[child_id].replace_by_path(path, new_state),
            )
        else:
            raise ValueError("Path must be either tuple or int")

    def __getattr__(self, key: str) -> Any:
        """Look up ``key`` in the state values when normal lookup fails.

        Raises ``AttributeError`` for dunder names and for an instance whose
        values have not been set yet, and ``KeyError`` when the state values
        do not contain ``key``.
        """
        if is_magic_method(key):
            # ``object`` has no ``__getattr__`` to delegate to. Raising
            # AttributeError directly keeps ``hasattr`` and the copy/pickle
            # protocol probes working and yields a meaningful message.
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{key}'"
            )

        # Read through ``__dict__`` so that a State whose values are not set
        # yet cannot recurse into ``__getattr__`` forever.
        try:
            state_dict = self.__dict__["_state_dict"]
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{key}'"
            ) from None

        try:
            if dataclasses.is_dataclass(state_dict):
                return getattr(state_dict, key)
            else:
                return state_dict[key]
        except (AttributeError, KeyError) as e:
            raise KeyError(
                f"State has no attribute '{key}'. "
                "This may be due to a mismatch between the state and the module. "
                "If you're trying to fit the state to a submodule, please use the `use_state` wrapper."
            ) from e

    def __getitem__(self, key: str) -> Any:
        """Item access is an alias of attribute access."""
        return getattr(self, key)

    def index(self, index: Union[str, int]) -> Self:
        """
        PyTree index, apply the index to every element in the state.
        """
        return tree_map(lambda x: x[index], self)

    def register_callback(self, callback: Callable, *args, **kwargs) -> Self:
        """
        Add a callback to the state

        Returns a new State; ``callback`` and its ``(args, kwargs)`` are
        prepended to the two parallel linked lists.
        """
        callbacks = (callback, self._callbacks)
        closure_values = ((args, kwargs), self._closure_values)
        return copy(self)._set_closures_mut(callbacks, closure_values)

    def clear_callbacks(self) -> Self:
        """
        Clear all the callbacks in the state
        """
        return copy(self)._set_closures_mut((), ())

    def execute_callbacks(self, clear_closures=True) -> Self:
        """
        Execute all the callbacks in the state

        Callbacks run in the order they were registered. When
        ``clear_closures`` is true a new State without callbacks is returned,
        otherwise ``self`` is returned.
        """
        closures = []
        iter_callback = self._callbacks
        iter_values = self._closure_values
        # The two linked lists are walked in lockstep.
        while iter_callback:
            callback, iter_callback = iter_callback
            (args, kwargs), iter_values = iter_values
            closures.append((callback, args, kwargs))

        # Newest callbacks sit at the head of the list, so reverse to run
        # them oldest-first.
        closures.reverse()
        for callback, args, kwargs in closures:
            callback(*args, **kwargs)

        if clear_closures:
            return self.clear_callbacks()
        else:
            return self

    def prepend_closure(self, other: Self) -> Self:
        """Prepend closures stored in others to the current state"""
        callbacks = linkedlist_concat(other._callbacks, self._callbacks)
        closure_values = linkedlist_concat(other._closure_values, self._closure_values)
        return copy(self)._set_closures_mut(callbacks, closure_values)

    def __setattr__(self, _key: str, _value: Any) -> None:
        raise TypeError("State is immutable")

    def __setitem__(self, _key: str, _value: Any) -> None:
        raise TypeError("State is immutable")

    def __repr__(self) -> str:
        # NOTE(review): ``EMPTY`` is a dict while ``self`` is a State, so this
        # branch looks unreachable unless ``State.EMPTY`` is reassigned.
        if self is State.EMPTY:
            return "State.empty"
        str_children = [
            f"{repr(key)}: {repr(child_state)}"
            for key, child_state in self._child_states.items()
        ]
        str_children = "{" + ",".join(str_children) + "}"
        return f"State({repr(self._state_dict)}, {str_children})"

    def __str__(self) -> str:
        return f"State{pformat(self.sprint_tree())}"

    def sprint_tree(self) -> Union[dict, str]:
        """Return ``(own values, {child name: child tree})`` for pretty printing."""
        # NOTE(review): same seemingly unreachable comparison as in ``__repr__``.
        if self is State.EMPTY:
            return "State.empty"
        children = {
            key: child_state.sprint_tree()
            for key, child_state in self._child_states.items()
        }
        return self._state_dict, children

    def tree_flatten(self) -> Tuple[Tuple[Any, dict, tuple], Tuple[Any, tuple]]:
        """Split the state into pytree children and auxiliary data.

        The state values, child states and closure values are children; the
        state id and the callbacks travel as auxiliary data.
        """
        children = (self._state_dict, self._child_states, self._closure_values)
        aux_data = (self._state_id, self._callbacks)
        return (children, aux_data)

    @classmethod
    def tree_unflatten(
        cls, aux_data: Tuple[Any, tuple], children: Tuple[Any, dict, tuple]
    ) -> Self:
        """Rebuild a state from the output of ``tree_flatten``."""
        state_dict, child_states, closure_values = children
        state_id, callbacks = aux_data
        return (
            cls()
            ._set_state_id_mut(state_id)
            ._set_state_dict_mut(state_dict)
            ._set_child_states_mut(child_states)
            ._set_closures_mut(callbacks, closure_values)
        )

    def __eq__(self, other: object):
        """Compare own values and child states with another ``State``.

        State id and callbacks are not compared. ``NotImplemented`` is
        returned for any other type so Python can fall back to its default
        handling. Because ``__eq__`` is defined without ``__hash__``,
        instances are unhashable.
        """
        if not isinstance(other, State):
            return NotImplemented
        # TODO: verify the correctness of the comparison
        if self._state_dict != other._state_dict:
            return False

        return self._child_states == other._child_states

    def save(self, path: PathLike, orbax: bool = True) -> None:
        """Save the state to local filesystem

        Parameters
        ----------
        path: str
            The path to save the state
        orbax: bool, default: True
            If True, use orbax to save the state, otherwise use pickle
        """
        path = Path(path).resolve()
        if orbax:
            ckpt = ocp.StandardCheckpointer()
            ckpt.save(path, args=ocp.args.StandardSave(self))
        else:
            with path.open("wb") as f:
                pickle.dump(self, f)

    def load(self, path: PathLike, orbax: bool = True) -> Self:
        """Load the saved state from disk

        With orbax, ``self`` is handed to ``StandardRestore``; with pickle
        ``self`` is not used and the unpickled object is returned as is.

        Parameters
        ----------
        path: str
            The path to load the state
        orbax: bool, default: True
            If True, use orbax to load the state, otherwise use pickle
        """
        path = Path(path).resolve()
        if orbax:
            ckpt = ocp.StandardCheckpointer()
            state = ckpt.restore(path, args=ocp.args.StandardRestore(self))
        else:
            # SECURITY: unpickling can execute arbitrary code. Only load
            # files that come from a trusted source.
            with path.open("rb") as f:
                state = pickle.load(f)

        return state

    def get_sharding(self, devices=None):
        """Return a State of the same tree shape holding shardings as values.

        The own values are converted by ``_get_state_sharding`` and every
        child state is converted recursively with the same ``devices``.
        """
        state_dict = _get_state_sharding(self._state_dict, devices)
        child_states = {
            key: child_state.get_sharding(devices)
            for key, child_state in self._child_states.items()
        }
        sharding_plan = (
            copy(self)
            ._set_state_dict_mut(state_dict)
            ._set_child_states_mut(child_states)
        )
        return sharding_plan
def _get_state_sharding(obj, devices=None): """ Apply DFS like tree_flatten """ if isinstance(obj, dict): # dict type does not have metadata, so always return replicated sharding return { key: ShardingType.REPLICATED.get_sharding(devices) for key in obj.keys() } elif dataclasses.is_dataclass(obj): # dataclass type has metadata, so we can get the sharding type from the metadata sharding_plan = {} for field in dataclasses.fields(obj): sharding = field.metadata.get("sharding", ShardingType.REPLICATED) static = field.metadata.get("static", False) if not static: # static field does not need to be sharded sharding_plan[field.name] = sharding.get_sharding(devices) return obj.replace(**sharding_plan) else: raise ValueError(f"Unsupported type: {type(obj)}")