add gene type RNN
This commit is contained in:
@@ -86,10 +86,12 @@ class NormalGene(BaseGene):
|
||||
@staticmethod
|
||||
def forward_transform(nodes, conns):
|
||||
u_conns = unflatten_connections(nodes, conns)
|
||||
u_conns = jnp.where(jnp.isnan(u_conns[0, :]), jnp.nan, u_conns) # enable is false, then the connections is nan
|
||||
u_conns = u_conns[1:, :] # remove enable attr
|
||||
conn_exist = jnp.any(~jnp.isnan(u_conns), axis=0)
|
||||
seqs = topological_sort(nodes, conn_exist)
|
||||
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
||||
|
||||
# remove enable attr
|
||||
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||
seqs = topological_sort(nodes, conn_enable)
|
||||
|
||||
return seqs, nodes, u_conns
|
||||
|
||||
@staticmethod
|
||||
@@ -167,18 +169,8 @@ class NormalGene(BaseGene):
|
||||
# the val of input nodes is obtained by the task, not by calculation
|
||||
values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
|
||||
|
||||
# if jnp.isin(i, input_idx):
|
||||
# values = miss()
|
||||
# else:
|
||||
# values = hit()
|
||||
|
||||
return values, idx + 1
|
||||
|
||||
# carry = (ini_vals, 0)
|
||||
# while cond_fun(carry):
|
||||
# carry = body_func(carry)
|
||||
# vals, _ = carry
|
||||
|
||||
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
|
||||
|
||||
return vals[output_idx]
|
||||
@@ -216,7 +208,3 @@ class NormalGene(BaseGene):
|
||||
)
|
||||
|
||||
return val
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user