From 1d606eb1c3a8b6f9a0dd73c52e4a22a21f6296f1 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 10 Jul 2024 11:30:20 +0800 Subject: [PATCH] fix bug; feat: Add support for max_nodes and max_conns in DefaultGenome initialization --- examples/tmp.py | 2 +- tensorneat/algorithm/neat/genome/base.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/tmp.py b/examples/tmp.py index 7033483..0f7ddae 100644 --- a/examples/tmp.py +++ b/examples/tmp.py @@ -4,7 +4,7 @@ from tensorneat.algorithm import NEAT from tensorneat.algorithm.neat import DefaultGenome key = jax.random.key(0) -genome = DefaultGenome(num_inputs=5, num_outputs=3, init_hidden_layers=(1, )) +genome = DefaultGenome(num_inputs=5, num_outputs=3, max_nodes=100, max_conns=500, init_hidden_layers=()) state = genome.setup() nodes, conns = genome.initialize(state, key) print(genome.repr(state, nodes, conns)) diff --git a/tensorneat/algorithm/neat/genome/base.py b/tensorneat/algorithm/neat/genome/base.py index e892f3f..aa716c3 100644 --- a/tensorneat/algorithm/neat/genome/base.py +++ b/tensorneat/algorithm/neat/genome/base.py @@ -64,6 +64,7 @@ class BaseGenome(StatefulBaseClass): all_init_conns_in_idx.append(in_idx) all_init_conns_out_idx.append(out_idx) all_init_nodes.extend(in_layer) + all_init_nodes.extend(layer_indices[-1]) if max_nodes < len(all_init_nodes): raise ValueError( @@ -91,6 +92,7 @@ class BaseGenome(StatefulBaseClass): self.output_idx = np.array(layer_indices[-1]) self.all_init_nodes = np.array(all_init_nodes) self.all_init_conns = np.c_[all_init_conns_in_idx, all_init_conns_out_idx] + print(self.output_idx) def setup(self, state=State()): state = self.node_gene.setup(state)