modify pipeline for "update_by_data";
fix bug in speciate. currently, node_delete and conn_delete can successfully work
This commit is contained in:
@@ -117,7 +117,9 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
def hit():
|
||||
batch_ins, new_conn_attrs = jax.vmap(
|
||||
self.conn_gene.update_by_batch, in_axes=(None, 1, 1), out_axes=(1, 1)
|
||||
self.conn_gene.update_by_batch,
|
||||
in_axes=(None, 1, 1),
|
||||
out_axes=(1, 1),
|
||||
)(state, u_conns_[:, :, i], batch_values)
|
||||
batch_z, new_node_attrs = self.node_gene.update_by_batch(
|
||||
state,
|
||||
@@ -132,12 +134,12 @@ class DefaultGenome(BaseGenome):
|
||||
u_conns_.at[:, :, i].set(new_conn_attrs),
|
||||
)
|
||||
|
||||
# the val of input nodes is obtained by the task, not by calculation
|
||||
(batch_values, nodes_attrs_, u_conns_) = jax.lax.cond(
|
||||
jnp.isin(i, self.input_idx),
|
||||
lambda: (batch_values, nodes_attrs_, u_conns_),
|
||||
hit,
|
||||
)
|
||||
# the val of input nodes is obtained by the task, not by calculation
|
||||
|
||||
return batch_values, nodes_attrs_, u_conns_, idx + 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user