fix bug;
feat: Add support for max_nodes and max_conns in DefaultGenome initialization
This commit is contained in:
@@ -4,7 +4,7 @@ from tensorneat.algorithm import NEAT
|
||||
from tensorneat.algorithm.neat import DefaultGenome
|
||||
|
||||
key = jax.random.key(0)
|
||||
genome = DefaultGenome(num_inputs=5, num_outputs=3, init_hidden_layers=(1, ))
|
||||
genome = DefaultGenome(num_inputs=5, num_outputs=3, max_nodes=100, max_conns=500, init_hidden_layers=())
|
||||
state = genome.setup()
|
||||
nodes, conns = genome.initialize(state, key)
|
||||
print(genome.repr(state, nodes, conns))
|
||||
|
||||
@@ -64,6 +64,7 @@ class BaseGenome(StatefulBaseClass):
|
||||
all_init_conns_in_idx.append(in_idx)
|
||||
all_init_conns_out_idx.append(out_idx)
|
||||
all_init_nodes.extend(in_layer)
|
||||
all_init_nodes.extend(layer_indices[-1])
|
||||
|
||||
if max_nodes < len(all_init_nodes):
|
||||
raise ValueError(
|
||||
@@ -91,6 +92,7 @@ class BaseGenome(StatefulBaseClass):
|
||||
self.output_idx = np.array(layer_indices[-1])
|
||||
self.all_init_nodes = np.array(all_init_nodes)
|
||||
self.all_init_conns = np.c_[all_init_conns_in_idx, all_init_conns_out_idx]
|
||||
print(self.output_idx)
|
||||
|
||||
def setup(self, state=State()):
|
||||
state = self.node_gene.setup(state)
|
||||
|
||||
Reference in New Issue
Block a user