add gene type RNN
This commit is contained in:
@@ -26,7 +26,7 @@ class NEAT:
|
||||
state = State(
|
||||
P=self.config['pop_size'],
|
||||
N=self.config['maximum_nodes'],
|
||||
C=self.config['maximum_connections'],
|
||||
C=self.config['maximum_conns'],
|
||||
S=self.config['maximum_species'],
|
||||
NL=1 + len(self.gene_type.node_attrs), # node length = (key) + attributes
|
||||
CL=3 + len(self.gene_type.conn_attrs), # conn length = (in, out, key) + attributes
|
||||
@@ -64,11 +64,15 @@ class NEAT:
|
||||
idx2species=idx2species,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
generation=generation,
|
||||
next_node_key=next_node_key,
|
||||
next_species_key=next_species_key
|
||||
# avoid jax auto cast from int to float. that would cause re-compilation.
|
||||
generation=jnp.asarray(generation, dtype=jnp.int32),
|
||||
next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32),
|
||||
next_species_key=jnp.asarray(next_species_key)
|
||||
)
|
||||
|
||||
# move to device
|
||||
state = jax.device_put(state)
|
||||
|
||||
return state
|
||||
|
||||
def step(self, state, fitness):
|
||||
|
||||
Reference in New Issue
Block a user