modify pipeline for "update_by_data";

fix bug in speciate. currently, node_delete and conn_delete can successfully work
This commit is contained in:
wls2002
2024-05-31 15:32:56 +08:00
parent 3ea9986bd4
commit 6aa9011043
12 changed files with 132 additions and 45 deletions

View File

@@ -178,18 +178,25 @@ class DefaultMutation(BaseMutation):
def no(key_, nodes_, conns_):
return nodes_, conns_
nodes, conns = jax.lax.cond(
r1 < self.node_add, mutate_add_node, no, k1, nodes, conns
)
nodes, conns = jax.lax.cond(
r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns
)
nodes, conns = jax.lax.cond(
r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns
)
nodes, conns = jax.lax.cond(
r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns
)
if self.node_add > 0:
nodes, conns = jax.lax.cond(
r1 < self.node_add, mutate_add_node, no, k1, nodes, conns
)
if self.node_delete > 0:
nodes, conns = jax.lax.cond(
r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns
)
if self.conn_add > 0:
nodes, conns = jax.lax.cond(
r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns
)
if self.conn_delete > 0:
nodes, conns = jax.lax.cond(
r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns
)
return nodes, conns