diff --git a/examples/tmp.py b/examples/tmp.py index 7033483..0f7ddae 100644 --- a/examples/tmp.py +++ b/examples/tmp.py @@ -4,7 +4,7 @@ from tensorneat.algorithm import NEAT from tensorneat.algorithm.neat import DefaultGenome key = jax.random.key(0) -genome = DefaultGenome(num_inputs=5, num_outputs=3, init_hidden_layers=(1, )) +genome = DefaultGenome(num_inputs=5, num_outputs=3, max_nodes=100, max_conns=500, init_hidden_layers=()) state = genome.setup() nodes, conns = genome.initialize(state, key) print(genome.repr(state, nodes, conns)) diff --git a/tensorneat/algorithm/neat/genome/base.py b/tensorneat/algorithm/neat/genome/base.py index e892f3f..aa716c3 100644 --- a/tensorneat/algorithm/neat/genome/base.py +++ b/tensorneat/algorithm/neat/genome/base.py @@ -64,6 +64,7 @@ class BaseGenome(StatefulBaseClass): all_init_conns_in_idx.append(in_idx) all_init_conns_out_idx.append(out_idx) all_init_nodes.extend(in_layer) + all_init_nodes.extend(layer_indices[-1]) if max_nodes < len(all_init_nodes): raise ValueError( @@ -91,6 +92,7 @@ class BaseGenome(StatefulBaseClass): self.output_idx = np.array(layer_indices[-1]) self.all_init_nodes = np.array(all_init_nodes) self.all_init_conns = np.c_[all_init_conns_in_idx, all_init_conns_out_idx] + print(self.output_idx) def setup(self, state=State()): state = self.node_gene.setup(state)