add gene type RNN

This commit is contained in:
wls2002
2023-07-19 15:43:49 +08:00
parent 0a2a9fd1be
commit a684e6584d
18 changed files with 248 additions and 129 deletions

View File

@@ -86,10 +86,12 @@ class NormalGene(BaseGene):
@staticmethod
def forward_transform(nodes, conns):
u_conns = unflatten_connections(nodes, conns)
u_conns = jnp.where(jnp.isnan(u_conns[0, :]), jnp.nan, u_conns) # enable is false, then the connections is nan
u_conns = u_conns[1:, :] # remove enable attr
conn_exist = jnp.any(~jnp.isnan(u_conns), axis=0)
seqs = topological_sort(nodes, conn_exist)
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
# remove enable attr
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
seqs = topological_sort(nodes, conn_enable)
return seqs, nodes, u_conns
@staticmethod
@@ -167,18 +169,8 @@ class NormalGene(BaseGene):
# the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
# if jnp.isin(i, input_idx):
# values = miss()
# else:
# values = hit()
return values, idx + 1
# carry = (ini_vals, 0)
# while cond_fun(carry):
# carry = body_func(carry)
# vals, _ = carry
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
return vals[output_idx]
@@ -216,7 +208,3 @@ class NormalGene(BaseGene):
)
return val