complete normal neat algorithm
This commit is contained in:
@@ -37,9 +37,17 @@ def initialize_genomes(state: State, gene_type: Type[BaseGene]):
|
||||
pop_nodes = np.tile(o_nodes, (state.P, 1, 1))
|
||||
pop_conns = np.tile(o_conns, (state.P, 1, 1))
|
||||
|
||||
return pop_nodes, pop_conns
|
||||
return jax.device_put([pop_nodes, pop_conns])
|
||||
|
||||
|
||||
def count(nodes: Array, cons: Array):
|
||||
"""
|
||||
Count how many nodes and connections are in the genome.
|
||||
"""
|
||||
node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0]))
|
||||
cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0]))
|
||||
return node_cnt, cons_cnt
|
||||
|
||||
def add_node(nodes: Array, cons: Array, new_key: int, attrs: Array) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Add a new node to the genome.
|
||||
|
||||
Reference in New Issue
Block a user