diff --git a/tensorneat/algorithm/neat/species/default.py b/tensorneat/algorithm/neat/species/default.py index 323841c..d5b297d 100644 --- a/tensorneat/algorithm/neat/species/default.py +++ b/tensorneat/algorithm/neat/species/default.py @@ -20,7 +20,8 @@ class DefaultSpecies(BaseSpecies): survival_threshold: float = 0.2, min_species_size: int = 1, compatibility_threshold: float = 3., - initialize_method: str = 'one_hidden_node', # {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'} + initialize_method: str = 'one_hidden_node', + # {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'} ): self.genome = genome self.pop_size = pop_size @@ -41,7 +42,7 @@ class DefaultSpecies(BaseSpecies): def setup(self, randkey): k1, k2 = jax.random.split(randkey, 2) - pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome,k1, self.initialize_method) + pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome, k1, self.initialize_method) species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species best_fitness = jnp.full((self.species_size,), jnp.nan) # the best fitness of each species @@ -492,11 +493,9 @@ class DefaultSpecies(BaseSpecies): return jnp.where(max_cnt == 0, 0, val / max_cnt) - - def initialize_population(pop_size, genome, randkey, init_method='default'): rand_keys = jax.random.split(randkey, pop_size) - + if init_method == 'one_hidden_node': init_func = init_one_hidden_node elif init_method == 'dense_hideen_layer': @@ -510,6 +509,7 @@ def initialize_population(pop_size, genome, randkey, init_method='default'): return pop_nodes, pop_conns + # one hidden node def init_one_hidden_node(genome, randkey): input_idx, output_idx = genome.input_idx, genome.output_idx @@ -520,12 +520,14 @@ def init_one_hidden_node(genome, randkey): nodes = nodes.at[input_idx, 0].set(input_idx) nodes = nodes.at[output_idx, 0].set(output_idx) - nodes = nodes.at[new_node_key, 0].set(new_node_key) + nodes = nodes.at[new_node_key, 0].set(new_node_key) rand_keys_nodes = jax.random.split(randkey, num=len(input_idx) + len(output_idx) + 1) - input_keys, output_keys, hidden_key = rand_keys_nodes[:len(input_idx)], rand_keys_nodes[len(input_idx):len(input_idx) + len(output_idx)], rand_keys_nodes[-1] + input_keys, output_keys, hidden_key = rand_keys_nodes[:len(input_idx)], rand_keys_nodes[ + len(input_idx):len(input_idx) + len( + output_idx)], rand_keys_nodes[-1] - node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(None,0)) + node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(None, 0)) input_attrs = node_attr_func(input_keys) output_attrs = node_attr_func(output_keys) hidden_attrs = genome.node_gene.new_custom_attrs(hidden_key) @@ -540,12 +542,12 @@ def init_one_hidden_node(genome, randkey): output_conns = jnp.c_[jnp.full_like(output_idx, new_node_key), output_idx] conns = conns.at[output_idx, 0:2].set(output_conns) - conns = conns.at[output_idx, 2].set(True) + conns = conns.at[output_idx, 2].set(True) rand_keys_conns = jax.random.split(randkey, num=len(input_idx) + len(output_idx)) input_conn_keys, output_conn_keys = rand_keys_conns[:len(input_idx)], rand_keys_conns[len(input_idx):] - conn_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(None,0)) + conn_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(None, 0)) input_conn_attrs = conn_attr_func(input_conn_keys) output_conn_attrs = conn_attr_func(output_conn_keys) @@ -553,10 +555,10 @@ def init_one_hidden_node(genome, randkey): conns = conns.at[output_idx, 3:].set(output_conn_attrs) return nodes, conns - -#random dense connections with 1 hidden layer -def init_dense_hideen_layer( genome, randkey,hiddens=20): + +# random dense connections with 1 hidden layer +def init_dense_hideen_layer(genome, randkey, hiddens=20): k1, k2, k3 = jax.random.split(randkey, num=3) input_idx, output_idx = genome.input_idx, genome.output_idx input_size = len(input_idx) @@ -602,6 +604,7 @@ def init_dense_hideen_layer( genome, randkey,hiddens=20): return nodes, conns + # random sparse connections with no hidden nodes def init_no_hidden_random(genome, randkey): k1, k2, k3 = jax.random.split(randkey, num=3) diff --git a/tensorneat/utils/state.py b/tensorneat/utils/state.py index 9e3932e..8212129 100644 --- a/tensorneat/utils/state.py +++ b/tensorneat/utils/state.py @@ -7,7 +7,19 @@ class State: def __init__(self, **kwargs): self.__dict__['state_dict'] = kwargs + def registered_keys(self): + return self.state_dict.keys() + + def register(self, **kwargs): + for key in kwargs: + if key in self.registered_keys(): + raise ValueError(f"Key {key} already exists in state") + return State(**{**self.state_dict, **kwargs}) + def update(self, **kwargs): + for key in kwargs: + if key not in self.registered_keys(): + raise ValueError(f"Key {key} does not exist in state") return State(**{**self.state_dict, **kwargs}) def __getattr__(self, name): @@ -26,4 +38,4 @@ class State: @classmethod def tree_unflatten(cls, aux_data, children): - return cls(**dict(zip(aux_data, children))) \ No newline at end of file + return cls(**dict(zip(aux_data, children)))