use black format all files;
remove "return state" for functions which will be executed in vmap; recover randkey as args in mutation methods
This commit is contained in:
@@ -3,9 +3,8 @@ from jax.tree_util import register_pytree_node_class
|
||||
|
||||
@register_pytree_node_class
|
||||
class State:
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__['state_dict'] = kwargs
|
||||
self.__dict__["state_dict"] = kwargs
|
||||
|
||||
def registered_keys(self):
|
||||
return self.state_dict.keys()
|
||||
|
||||
Reference in New Issue
Block a user