finish all refactoring
This commit is contained in:
@@ -92,7 +92,7 @@ class DefaultMutation(BaseMutation):
|
||||
return nodes_, conns_
|
||||
|
||||
def successful():
|
||||
return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conns.new_custom_attrs())
|
||||
return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs())
|
||||
|
||||
def already_exist():
|
||||
return nodes_, conns_.at[conn_pos, 2].set(True)
|
||||
@@ -105,11 +105,12 @@ class DefaultMutation(BaseMutation):
|
||||
return jax.lax.cond(
|
||||
is_already_exist,
|
||||
already_exist,
|
||||
jax.lax.cond(
|
||||
is_cycle,
|
||||
nothing,
|
||||
successful
|
||||
)
|
||||
lambda:
|
||||
jax.lax.cond(
|
||||
is_cycle,
|
||||
nothing,
|
||||
successful
|
||||
)
|
||||
)
|
||||
|
||||
elif genome.network_type == 'recurrent':
|
||||
@@ -138,23 +139,23 @@ class DefaultMutation(BaseMutation):
|
||||
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
|
||||
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
|
||||
|
||||
def no(k, g):
|
||||
return g
|
||||
def no(key_, nodes_, conns_):
|
||||
return nodes_, conns_
|
||||
|
||||
genome = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns)
|
||||
genome = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns)
|
||||
genome = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns)
|
||||
genome = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns)
|
||||
nodes, conns = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns)
|
||||
nodes, conns = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns)
|
||||
nodes, conns = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns)
|
||||
nodes, conns = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns)
|
||||
|
||||
return genome
|
||||
return nodes, conns
|
||||
|
||||
def mutate_values(self, randkey, genome, nodes, conns):
|
||||
k1, k2 = jax.random.split(randkey, num=2)
|
||||
nodes_keys = jax.random.split(k1, num=genome.nodes.shape[0])
|
||||
conns_keys = jax.random.split(k2, num=genome.conns.shape[0])
|
||||
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
|
||||
conns_keys = jax.random.split(k2, num=conns.shape[0])
|
||||
|
||||
new_nodes = jax.vmap(genome.nodes.mutate, in_axes=(0, 0))(nodes_keys, nodes)
|
||||
new_conns = jax.vmap(genome.conns.mutate, in_axes=(0, 0))(conns_keys, conns)
|
||||
new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes)
|
||||
new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns)
|
||||
|
||||
# nan nodes not changed
|
||||
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
|
||||
|
||||
Reference in New Issue
Block a user