All function with state will update the state and return it.
Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
@@ -24,7 +24,7 @@ class BaseGenome:
|
||||
self.node_gene = node_gene
|
||||
self.conn_gene = conn_gene
|
||||
|
||||
def setup(self, key, state=State()):
|
||||
def setup(self, state=State()):
|
||||
return state
|
||||
|
||||
def transform(self, state, nodes, conns):
|
||||
|
||||
@@ -38,7 +38,7 @@ class DefaultGenome(BaseGenome):
|
||||
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||
seqs = topological_sort(nodes, conn_enable)
|
||||
|
||||
return seqs, nodes, u_conns
|
||||
return state, seqs, nodes, u_conns
|
||||
|
||||
def forward(self, state, inputs, transformed):
|
||||
cal_seqs, nodes, conns = transformed
|
||||
@@ -49,34 +49,32 @@ class DefaultGenome(BaseGenome):
|
||||
nodes_attrs = nodes[:, 1:]
|
||||
|
||||
def cond_fun(carry):
|
||||
values, idx = carry
|
||||
state_, values, idx = carry
|
||||
return (idx < N) & (cal_seqs[idx] != I_INT)
|
||||
|
||||
def body_func(carry):
|
||||
values, idx = carry
|
||||
state_, values, idx = carry
|
||||
i = cal_seqs[idx]
|
||||
|
||||
def hit():
|
||||
ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))(conns[:, :, i], values)
|
||||
z = self.node_gene.forward(state, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
|
||||
s, ins = jax.vmap(self.conn_gene.forward,
|
||||
in_axes=(None, 1, 0), out_axes=(None, 0))(state_, conns[:, :, i], values)
|
||||
s, z = self.node_gene.forward(s, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
|
||||
new_values = values.at[i].set(z)
|
||||
return new_values
|
||||
|
||||
def miss():
|
||||
return values
|
||||
return s, new_values
|
||||
|
||||
# the val of input nodes is obtained by the task, not by calculation
|
||||
values = jax.lax.cond(
|
||||
state_, values = jax.lax.cond(
|
||||
jnp.isin(i, self.input_idx),
|
||||
miss,
|
||||
lambda: (state_, values),
|
||||
hit
|
||||
)
|
||||
|
||||
return values, idx + 1
|
||||
return state_, values, idx + 1
|
||||
|
||||
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
|
||||
state, vals, _ = jax.lax.while_loop(cond_fun, body_func, (state, ini_vals, 0))
|
||||
|
||||
if self.output_transform is None:
|
||||
return vals[self.output_idx]
|
||||
return state, vals[self.output_idx]
|
||||
else:
|
||||
return self.output_transform(vals[self.output_idx])
|
||||
return state, self.output_transform(vals[self.output_idx])
|
||||
|
||||
@@ -39,7 +39,7 @@ class RecurrentGenome(BaseGenome):
|
||||
conn_enable = u_conns[0] == 1
|
||||
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||
|
||||
return nodes, u_conns
|
||||
return state, nodes, u_conns
|
||||
|
||||
def forward(self, state, inputs, transformed):
|
||||
nodes, conns = transformed
|
||||
@@ -48,27 +48,36 @@ class RecurrentGenome(BaseGenome):
|
||||
vals = jnp.full((N,), jnp.nan)
|
||||
nodes_attrs = nodes[:, 1:]
|
||||
|
||||
def body_func(_, values):
|
||||
def body_func(_, carry):
|
||||
state_, values = carry
|
||||
|
||||
# set input values
|
||||
values = values.at[self.input_idx].set(inputs)
|
||||
|
||||
# calculate connections
|
||||
node_ins = jax.vmap(
|
||||
state_, node_ins = jax.vmap(
|
||||
jax.vmap(
|
||||
self.conn_gene.forward,
|
||||
in_axes=(None, 1, None)
|
||||
in_axes=(None, 1, None),
|
||||
out_axes=(None, 0)
|
||||
),
|
||||
in_axes=(None, 1, 0)
|
||||
)(state, conns, values)
|
||||
in_axes=(None, 1, 0),
|
||||
out_axes=(None, 0)
|
||||
)(state_, conns, values)
|
||||
|
||||
# calculate nodes
|
||||
is_output_nodes = jnp.isin(
|
||||
jnp.arange(N),
|
||||
self.output_idx
|
||||
)
|
||||
values = jax.vmap(self.node_gene.forward, in_axes=(None, 0, 0))(nodes_attrs, node_ins.T, is_output_nodes)
|
||||
return values
|
||||
state_, values = jax.vmap(
|
||||
self.node_gene.forward,
|
||||
in_axes=(None, 0, 0, 0),
|
||||
out_axes=(None, 0)
|
||||
)(state_, nodes_attrs, node_ins.T, is_output_nodes)
|
||||
|
||||
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)
|
||||
return state_, values
|
||||
|
||||
return vals[self.output_idx]
|
||||
state, vals = jax.lax.fori_loop(0, self.activate_time, body_func, (state, vals))
|
||||
|
||||
return state, vals[self.output_idx]
|
||||
|
||||
Reference in New Issue
Block a user