All function with state will update the state and return it.

Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
wls2002
2024-05-25 20:45:57 +08:00
parent 5626fddf41
commit 79d53ea7af
12 changed files with 84 additions and 70 deletions

View File

@@ -38,7 +38,7 @@ class DefaultGenome(BaseGenome):
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
seqs = topological_sort(nodes, conn_enable)
return seqs, nodes, u_conns
return state, seqs, nodes, u_conns
def forward(self, state, inputs, transformed):
cal_seqs, nodes, conns = transformed
@@ -49,34 +49,32 @@ class DefaultGenome(BaseGenome):
nodes_attrs = nodes[:, 1:]
def cond_fun(carry):
values, idx = carry
state_, values, idx = carry
return (idx < N) & (cal_seqs[idx] != I_INT)
def body_func(carry):
values, idx = carry
state_, values, idx = carry
i = cal_seqs[idx]
def hit():
ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))(conns[:, :, i], values)
z = self.node_gene.forward(state, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
s, ins = jax.vmap(self.conn_gene.forward,
in_axes=(None, 1, 0), out_axes=(None, 0))(state_, conns[:, :, i], values)
s, z = self.node_gene.forward(s, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
new_values = values.at[i].set(z)
return new_values
def miss():
return values
return s, new_values
# the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(
state_, values = jax.lax.cond(
jnp.isin(i, self.input_idx),
miss,
lambda: (state_, values),
hit
)
return values, idx + 1
return state_, values, idx + 1
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
state, vals, _ = jax.lax.while_loop(cond_fun, body_func, (state, ini_vals, 0))
if self.output_transform is None:
return vals[self.output_idx]
return state, vals[self.output_idx]
else:
return self.output_transform(vals[self.output_idx])
return state, self.output_transform(vals[self.output_idx])