All function with state will update the state and return it.

Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
wls2002
2024-05-25 20:45:57 +08:00
parent 5626fddf41
commit 79d53ea7af
12 changed files with 84 additions and 70 deletions

View File

@@ -24,7 +24,7 @@ class BaseGenome:
self.node_gene = node_gene
self.conn_gene = conn_gene
def setup(self, key, state=State()):
def setup(self, state=State()):
return state
def transform(self, state, nodes, conns):

View File

@@ -38,7 +38,7 @@ class DefaultGenome(BaseGenome):
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
seqs = topological_sort(nodes, conn_enable)
return seqs, nodes, u_conns
return state, seqs, nodes, u_conns
def forward(self, state, inputs, transformed):
cal_seqs, nodes, conns = transformed
@@ -49,34 +49,32 @@ class DefaultGenome(BaseGenome):
nodes_attrs = nodes[:, 1:]
def cond_fun(carry):
values, idx = carry
state_, values, idx = carry
return (idx < N) & (cal_seqs[idx] != I_INT)
def body_func(carry):
values, idx = carry
state_, values, idx = carry
i = cal_seqs[idx]
def hit():
ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))(conns[:, :, i], values)
z = self.node_gene.forward(state, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
s, ins = jax.vmap(self.conn_gene.forward,
in_axes=(None, 1, 0), out_axes=(None, 0))(state_, conns[:, :, i], values)
s, z = self.node_gene.forward(s, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
new_values = values.at[i].set(z)
return new_values
def miss():
return values
return s, new_values
# the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(
state_, values = jax.lax.cond(
jnp.isin(i, self.input_idx),
miss,
lambda: (state_, values),
hit
)
return values, idx + 1
return state_, values, idx + 1
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
state, vals, _ = jax.lax.while_loop(cond_fun, body_func, (state, ini_vals, 0))
if self.output_transform is None:
return vals[self.output_idx]
return state, vals[self.output_idx]
else:
return self.output_transform(vals[self.output_idx])
return state, self.output_transform(vals[self.output_idx])

View File

@@ -39,7 +39,7 @@ class RecurrentGenome(BaseGenome):
conn_enable = u_conns[0] == 1
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
return nodes, u_conns
return state, nodes, u_conns
def forward(self, state, inputs, transformed):
nodes, conns = transformed
@@ -48,27 +48,36 @@ class RecurrentGenome(BaseGenome):
vals = jnp.full((N,), jnp.nan)
nodes_attrs = nodes[:, 1:]
def body_func(_, values):
def body_func(_, carry):
state_, values = carry
# set input values
values = values.at[self.input_idx].set(inputs)
# calculate connections
node_ins = jax.vmap(
state_, node_ins = jax.vmap(
jax.vmap(
self.conn_gene.forward,
in_axes=(None, 1, None)
in_axes=(None, 1, None),
out_axes=(None, 0)
),
in_axes=(None, 1, 0)
)(state, conns, values)
in_axes=(None, 1, 0),
out_axes=(None, 0)
)(state_, conns, values)
# calculate nodes
is_output_nodes = jnp.isin(
jnp.arange(N),
self.output_idx
)
values = jax.vmap(self.node_gene.forward, in_axes=(None, 0, 0))(nodes_attrs, node_ins.T, is_output_nodes)
return values
state_, values = jax.vmap(
self.node_gene.forward,
in_axes=(None, 0, 0, 0),
out_axes=(None, 0)
)(state_, nodes_attrs, node_ins.T, is_output_nodes)
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)
return state_, values
return vals[self.output_idx]
state, vals = jax.lax.fori_loop(0, self.activate_time, body_func, (state, vals))
return state, vals[self.output_idx]