add gene type RNN

This commit is contained in:
wls2002
2023-07-19 15:43:49 +08:00
parent 0a2a9fd1be
commit a684e6584d
18 changed files with 248 additions and 129 deletions

View File

@@ -26,7 +26,7 @@ class NEAT:
state = State(
P=self.config['pop_size'],
N=self.config['maximum_nodes'],
C=self.config['maximum_connections'],
C=self.config['maximum_conns'],
S=self.config['maximum_species'],
NL=1 + len(self.gene_type.node_attrs), # node length = (key) + attributes
CL=3 + len(self.gene_type.conn_attrs), # conn length = (in, out, key) + attributes
@@ -64,11 +64,15 @@ class NEAT:
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
generation=generation,
next_node_key=next_node_key,
next_species_key=next_species_key
# avoid jax auto cast from int to float. that would cause re-compilation.
generation=jnp.asarray(generation, dtype=jnp.int32),
next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32),
next_species_key=jnp.asarray(next_species_key)
)
# move to device
state = jax.device_put(state)
return state
def step(self, state, fitness):