fix bug in crossover: the child from two normal networks should always be normal.

This commit is contained in:
wls2002
2024-05-22 10:27:32 +08:00
parent d1559317d1
commit 6a37563696
11 changed files with 46 additions and 43 deletions

View File

@@ -154,8 +154,8 @@ class DefaultMutation(BaseMutation):
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
conns_keys = jax.random.split(k2, num=conns.shape[0])
new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes)
new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns)
new_nodes = jax.vmap(genome.node_gene.mutate)(nodes_keys, nodes)
new_conns = jax.vmap(genome.conn_gene.mutate)(conns_keys, conns)
# nan nodes not changed
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)