update class State. Add method register and update method update.

This commit is contained in:
wls2002
2024-05-25 16:05:47 +08:00
parent 25f66dc2fb
commit 3b2f917aee
2 changed files with 29 additions and 14 deletions

View File

@@ -20,7 +20,8 @@ class DefaultSpecies(BaseSpecies):
survival_threshold: float = 0.2,
min_species_size: int = 1,
compatibility_threshold: float = 3.,
initialize_method: str = 'one_hidden_node', # {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'}
initialize_method: str = 'one_hidden_node',
# {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'}
):
self.genome = genome
self.pop_size = pop_size
@@ -492,8 +493,6 @@ class DefaultSpecies(BaseSpecies):
return jnp.where(max_cnt == 0, 0, val / max_cnt)
def initialize_population(pop_size, genome, randkey, init_method='default'):
rand_keys = jax.random.split(randkey, pop_size)
@@ -510,6 +509,7 @@ def initialize_population(pop_size, genome, randkey, init_method='default'):
return pop_nodes, pop_conns
# one hidden node
def init_one_hidden_node(genome, randkey):
input_idx, output_idx = genome.input_idx, genome.output_idx
@@ -523,7 +523,9 @@ def init_one_hidden_node(genome, randkey):
nodes = nodes.at[new_node_key, 0].set(new_node_key)
rand_keys_nodes = jax.random.split(randkey, num=len(input_idx) + len(output_idx) + 1)
input_keys, output_keys, hidden_key = rand_keys_nodes[:len(input_idx)], rand_keys_nodes[len(input_idx):len(input_idx) + len(output_idx)], rand_keys_nodes[-1]
input_keys, output_keys, hidden_key = rand_keys_nodes[:len(input_idx)], rand_keys_nodes[
len(input_idx):len(input_idx) + len(
output_idx)], rand_keys_nodes[-1]
node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(None, 0))
input_attrs = node_attr_func(input_keys)
@@ -602,6 +604,7 @@ def init_dense_hideen_layer( genome, randkey,hiddens=20):
return nodes, conns
# random sparse connections with no hidden nodes
def init_no_hidden_random(genome, randkey):
k1, k2, k3 = jax.random.split(randkey, num=3)

View File

@@ -7,7 +7,19 @@ class State:
def __init__(self, **kwargs):
self.__dict__['state_dict'] = kwargs
def registered_keys(self):
return self.state_dict.keys()
def register(self, **kwargs):
for key in kwargs:
if key in self.registered_keys():
raise ValueError(f"Key {key} already exists in state")
return State(**{**self.state_dict, **kwargs})
def update(self, **kwargs):
for key in kwargs:
if key not in self.registered_keys():
raise ValueError(f"Key {key} does not exist in state")
return State(**{**self.state_dict, **kwargs})
def __getattr__(self, name):