create state

This commit is contained in:
wls2002
2023-07-14 17:27:22 +08:00
parent 7265e33c43
commit a0a1ef6c58
41 changed files with 43 additions and 2882 deletions

0
algorithm/__init__.py Normal file
View File

View File

View File

29
algorithm/state.py Normal file
View File

@@ -0,0 +1,29 @@
from jax.tree_util import register_pytree_node_class, tree_map
@register_pytree_node_class
class State:
def __init__(self, **kwargs):
self.__dict__['state_dict'] = kwargs
def update(self, **kwargs):
return State(**{**self.state_dict, **kwargs})
def __getattr__(self, name):
return self.state_dict[name]
def __setattr__(self, name, value):
raise AttributeError("State is immutable")
def __repr__(self):
return f"State ({self.state_dict})"
def tree_flatten(self):
children = list(self.state_dict.values())
aux_data = list(self.state_dict.keys())
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(**dict(zip(aux_data, children)))