modify pipeline for "update_by_data";

fix bug in speciate. currently, node_delete and conn_delete can successfully work
This commit is contained in:
wls2002
2024-05-31 15:32:56 +08:00
parent 3ea9986bd4
commit 6aa9011043
12 changed files with 132 additions and 45 deletions

View File

@@ -117,7 +117,9 @@ class DefaultGenome(BaseGenome):
def hit():
batch_ins, new_conn_attrs = jax.vmap(
self.conn_gene.update_by_batch, in_axes=(None, 1, 1), out_axes=(1, 1)
self.conn_gene.update_by_batch,
in_axes=(None, 1, 1),
out_axes=(1, 1),
)(state, u_conns_[:, :, i], batch_values)
batch_z, new_node_attrs = self.node_gene.update_by_batch(
state,
@@ -132,12 +134,12 @@ class DefaultGenome(BaseGenome):
u_conns_.at[:, :, i].set(new_conn_attrs),
)
# the val of input nodes is obtained by the task, not by calculation
(batch_values, nodes_attrs_, u_conns_) = jax.lax.cond(
jnp.isin(i, self.input_idx),
lambda: (batch_values, nodes_attrs_, u_conns_),
hit,
)
# the val of input nodes is obtained by the task, not by calculation
return batch_values, nodes_attrs_, u_conns_, idx + 1