All function with state will update the state and return it.
Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
@@ -38,7 +38,7 @@ class DefaultGenome(BaseGenome):
|
||||
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||
seqs = topological_sort(nodes, conn_enable)
|
||||
|
||||
return seqs, nodes, u_conns
|
||||
return state, seqs, nodes, u_conns
|
||||
|
||||
def forward(self, state, inputs, transformed):
|
||||
cal_seqs, nodes, conns = transformed
|
||||
@@ -49,34 +49,32 @@ class DefaultGenome(BaseGenome):
|
||||
nodes_attrs = nodes[:, 1:]
|
||||
|
||||
def cond_fun(carry):
|
||||
values, idx = carry
|
||||
state_, values, idx = carry
|
||||
return (idx < N) & (cal_seqs[idx] != I_INT)
|
||||
|
||||
def body_func(carry):
|
||||
values, idx = carry
|
||||
state_, values, idx = carry
|
||||
i = cal_seqs[idx]
|
||||
|
||||
def hit():
|
||||
ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))(conns[:, :, i], values)
|
||||
z = self.node_gene.forward(state, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
|
||||
s, ins = jax.vmap(self.conn_gene.forward,
|
||||
in_axes=(None, 1, 0), out_axes=(None, 0))(state_, conns[:, :, i], values)
|
||||
s, z = self.node_gene.forward(s, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
|
||||
new_values = values.at[i].set(z)
|
||||
return new_values
|
||||
|
||||
def miss():
|
||||
return values
|
||||
return s, new_values
|
||||
|
||||
# the val of input nodes is obtained by the task, not by calculation
|
||||
values = jax.lax.cond(
|
||||
state_, values = jax.lax.cond(
|
||||
jnp.isin(i, self.input_idx),
|
||||
miss,
|
||||
lambda: (state_, values),
|
||||
hit
|
||||
)
|
||||
|
||||
return values, idx + 1
|
||||
return state_, values, idx + 1
|
||||
|
||||
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
|
||||
state, vals, _ = jax.lax.while_loop(cond_fun, body_func, (state, ini_vals, 0))
|
||||
|
||||
if self.output_transform is None:
|
||||
return vals[self.output_idx]
|
||||
return state, vals[self.output_idx]
|
||||
else:
|
||||
return self.output_transform(vals[self.output_idx])
|
||||
return state, self.output_transform(vals[self.output_idx])
|
||||
|
||||
Reference in New Issue
Block a user