State Class

class evox.State

A class that represents state.

State is immutable: to update a state, use the replace method or the | operator, both of which return a new State. State implements tree_flatten and tree_unflatten and is registered as a pytree node, so it can be used as a pytree with JAX like any other pytree.
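To illustrate the immutable-update pattern, here is a minimal sketch using a hypothetical MiniState class (not EvoX's implementation). Both replace and the | operator return a new object and leave the original unchanged. The sketch does not register itself as a JAX pytree; with JAX installed, that is typically done by decorating the class with jax.tree_util.register_pytree_node_class and defining tree_flatten and tree_unflatten.

```python
from types import MappingProxyType
from typing import Any, Mapping


class MiniState:
    """Hypothetical immutable state (not EvoX's State): read-only named fields."""

    __slots__ = ("_fields",)

    def __init__(self, **fields: Any) -> None:
        # Bypass our own __setattr__ once, storing a read-only view of the fields.
        object.__setattr__(self, "_fields", MappingProxyType(dict(fields)))

    def __setattr__(self, name: str, value: Any) -> None:
        raise AttributeError("MiniState is immutable; use replace() or |")

    def __getitem__(self, key: str) -> Any:
        return self._fields[key]

    def replace(self, **kwargs: Any) -> "MiniState":
        # Build a fresh object with updated fields; self is left untouched.
        return MiniState(**{**self._fields, **kwargs})

    def __or__(self, other: Mapping[str, Any]) -> "MiniState":
        # state | {"y": 3} is shorthand for state.replace(y=3).
        return self.replace(**other)

    def __repr__(self) -> str:
        return f"MiniState({dict(self._fields)})"


state = MiniState(x=1, y=2)
updated = state.replace(y=3)
merged = state | {"y": 4}
print(updated)  # MiniState({'x': 1, 'y': 3})
print(state)    # MiniState({'x': 1, 'y': 2})
```

Returning new objects instead of mutating is what makes this style safe inside pure, traced functions.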

__init__(_dataclass=None, /, **kwargs)

Construct a State from a dataclass instance or from keyword arguments.

Example:

>>> from evox import State
>>> State(x=1, y=2)  # from keyword arguments
State({'x': 1, 'y': 2}, {})
>>> from dataclasses import dataclass
>>> @dataclass
... class Param:
...     x: int
...     y: int
...
>>> param = Param(x=1, y=2)
>>> State(param)  # from a dataclass instance
State(Param(x=1, y=2), {})

Return type:

None

clear_callbacks()

Clear all the callbacks in the state.

Return type:

Self

execute_callbacks(clear_closures=True)

Execute all the callbacks in the state.

Return type:

Self

find_path_to(node_id, hint=None)

Find the state whose state_id matches node_id. Optionally, a hint can be given with the module_name.

Parameters:
  • node_id (int)

  • hint (str | None)

Return type:

Tuple[Tuple | int, Self] | None
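The search that find_path_to performs can be pictured as a depth-first walk over nested sub-states. The minimal sketch below assumes a hypothetical Node type with an integer node_id and a dict of named children, returns the path as a tuple of child names (or None when nothing matches), and ignores hint; EvoX's actual path format, typed Tuple | int above, may differ.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple


@dataclass
class Node:
    """Hypothetical stand-in for a State: an ID plus named sub-states."""
    node_id: int
    children: Dict[str, "Node"] = field(default_factory=dict)


def find_path_to(root: Node, node_id: int) -> Optional[Tuple[Tuple[str, ...], Node]]:
    """Depth-first search returning (path of child names, matching node) or None."""
    if root.node_id == node_id:
        return (), root
    for name, child in root.children.items():
        found = find_path_to(child, node_id)
        if found is not None:
            path, node = found
            # Prefix the child's name onto the path found below it.
            return (name, *path), node
    return None


# Hypothetical tree; the names "algorithm" and "problem" are illustrative only.
tree = Node(0, {"algorithm": Node(1), "problem": Node(2, {"sub": Node(3)})})
print(find_path_to(tree, 3)[0])  # ('problem', 'sub')
print(find_path_to(tree, 9))     # None
```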

classmethod from_dataclass(dataclass)

Construct a State from a dataclass instance.

Example:

>>> from evox import State
>>> from dataclasses import dataclass
>>> @dataclass
... class Param:
...     x: int
...     y: int
...
>>> param = Param(x=1, y=2)
>>> State.from_dataclass(param)
State(Param(x=1, y=2), {})

Return type:

Self

index(index)

PyTree indexing: apply the index to every leaf in the state.

Parameters:

index (str | int)

Return type:

Self
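Conceptually, PyTree indexing maps leaf[index] over every leaf; with JAX this corresponds to jax.tree_util.tree_map(lambda x: x[index], tree). The stdlib-only sketch below assumes a simplified tree in which dicts are interior nodes and every other value (here, lists standing in for arrays) is a leaf. The helper name index_tree is hypothetical.

```python
from typing import Any


def index_tree(tree: Any, idx: Any) -> Any:
    """Apply leaf[idx] to every leaf; dicts are treated as interior nodes."""
    if isinstance(tree, dict):
        return {key: index_tree(value, idx) for key, value in tree.items()}
    return tree[idx]


# Hypothetical batched state: each leaf holds one entry per batch element.
batched = {"mean": [0.0, 1.0, 2.0], "extra": {"step": [10, 11, 12]}}
print(index_tree(batched, 1))  # {'mean': 1.0, 'extra': {'step': 11}}
```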

load(path, orbax=True)

Load a saved state from disk.

Parameters:
  • path (str) – The path to load the state from

  • orbax (bool, default: True) – If True, use orbax to load the state, otherwise use pickle

Return type:

Self

prepend_closure(other)

Prepend the closures stored in other to the current state.

Parameters:

other (Self)

Return type:

Self

query_state(name)

Recursively find a sub-state by a dotted query name. For example, ‘foo.bar’ finds the sub-state named foo, then the sub-state named bar under foo.

Parameters:

name (str)

Return type:

Self
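The dotted lookup can be sketched as a recursion that consumes one name component per call. This assumes a hypothetical layout in which sub-states are stored in nested dicts under their names; EvoX's internal storage may differ.

```python
from typing import Any, Dict


def query_state(states: Dict[str, Any], name: str) -> Any:
    """Resolve 'foo.bar' by finding 'foo', then recursing into it for 'bar'."""
    head, _, rest = name.partition(".")
    child = states[head]  # raises KeyError if a component is missing
    return query_state(child, rest) if rest else child


# Hypothetical nesting: sub-states stored under their names.
tree = {"foo": {"bar": {"x": 1}, "baz": {"x": 2}}}
print(query_state(tree, "foo.bar"))  # {'x': 1}
```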

register_callback(callback, *args, **kwargs)

Add a callback to the state.

Parameters:

callback (Callable)

Return type:

Self
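The callback methods (register_callback, execute_callbacks, clear_callbacks) all return a state, which fits the immutable design. The sketch below is a hypothetical CallbackState, not EvoX's implementation. It assumes that the extra *args and **kwargs are stored and forwarded when the callback runs, and that executing callbacks leaves the state unchanged. It also omits clear_closures, whose exact effect is not described here.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Dict, Tuple

Callback = Tuple[Callable[..., Any], Tuple[Any, ...], Dict[str, Any]]


@dataclass(frozen=True)
class CallbackState:
    """Hypothetical immutable holder for deferred callbacks."""
    callbacks: Tuple[Callback, ...] = ()

    def register_callback(
        self, callback: Callable[..., Any], *args: Any, **kwargs: Any
    ) -> "CallbackState":
        # Assumption: extra arguments are stored and forwarded at execution time.
        return replace(self, callbacks=self.callbacks + ((callback, args, kwargs),))

    def execute_callbacks(self) -> "CallbackState":
        # Run callbacks in registration order; this state is returned unchanged.
        for fn, args, kwargs in self.callbacks:
            fn(*args, **kwargs)
        return self

    def clear_callbacks(self) -> "CallbackState":
        return replace(self, callbacks=())


log = []
state = CallbackState().register_callback(log.append, "first")
state = state.register_callback(log.append, "second")
state = state.execute_callbacks().clear_callbacks()
print(log)                   # ['first', 'second']
print(len(state.callbacks))  # 0
```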

replace(**kwargs)

Update the current State with the given keyword arguments and return a new State. The original State is not modified.

Example:

>>> from evox import State
>>> state = State(x=1, y=2)
>>> state.replace(y=3)  # returns a new State
State ({'x': 1, 'y': 3}, {})
>>> state  # State is immutable, so the original is unchanged
State ({'x': 1, 'y': 2}, {})

Return type:

Self

save(path, orbax=True)

Save the state to the local filesystem.

Parameters:
  • path (str) – The path to save the state

  • orbax (bool, default: True) – If True, use orbax to save the state, otherwise use pickle

Return type:

None
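When orbax is False, save and load fall back to pickle. The round trip below is a minimal stdlib sketch of that pickle path, using hypothetical helpers save_state and load_state and a plain dict in place of a State. It is not EvoX's code.

```python
import os
import pickle
import tempfile
from typing import Any


def save_state(state: Any, path: str) -> None:
    # Serialize the whole object graph to a binary file.
    with open(path, "wb") as f:
        pickle.dump(state, f)


def load_state(path: str) -> Any:
    # Rebuild the object from the file written by save_state.
    with open(path, "rb") as f:
        return pickle.load(f)


state = {"x": 1, "y": [1.0, 2.0]}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "state.pkl")
    save_state(state, path)
    restored = load_state(path)
print(restored == state)  # True
```

Because unpickling can execute arbitrary code, the pickle path should only be used to load files from trusted sources.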