add gene type RNN

This commit is contained in:
wls2002
2023-07-19 15:43:49 +08:00
parent 0a2a9fd1be
commit a684e6584d
18 changed files with 248 additions and 129 deletions

View File

@@ -2,4 +2,5 @@ from .base import BaseGene
from .normal import NormalGene
from .activation import Activation
from .aggregation import Aggregation
from .recurrent import RecurrentGene

View File

@@ -86,10 +86,12 @@ class NormalGene(BaseGene):
# NOTE(review): this span is a diff hunk whose +/- markers were stripped by the
# export, so deleted (pre-change) and added (post-change) lines of
# forward_transform are interleaved below; it is NOT runnable as shown.
# Verify against the repository before relying on it.
@staticmethod
def forward_transform(nodes, conns):
u_conns = unflatten_connections(nodes, conns)
# presumably the pre-change version: nan-out disabled conns, then slice off the enable row
u_conns = jnp.where(jnp.isnan(u_conns[0, :]), jnp.nan, u_conns) # enable is false, then the connections is nan
u_conns = u_conns[1:, :] # remove enable attr
conn_exist = jnp.any(~jnp.isnan(u_conns), axis=0)
seqs = topological_sort(nodes, conn_exist)
# presumably the post-change version: build a boolean enable mask from row 0 instead
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
# remove enable attr
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
seqs = topological_sort(nodes, conn_enable)
return seqs, nodes, u_conns
@staticmethod
@@ -167,18 +169,8 @@ class NormalGene(BaseGene):
# the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
# if jnp.isin(i, input_idx):
# values = miss()
# else:
# values = hit()
return values, idx + 1
# carry = (ini_vals, 0)
# while cond_fun(carry):
# carry = body_func(carry)
# vals, _ = carry
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
return vals[output_idx]
@@ -216,7 +208,3 @@ class NormalGene(BaseGene):
)
return val

View File

@@ -0,0 +1,90 @@
import jax
from jax import Array, numpy as jnp, vmap
from .normal import NormalGene
from .activation import Activation
from .aggregation import Aggregation
from ..utils import unflatten_connections, I_INT
class RecurrentGene(NormalGene):
    """Gene type for recurrent networks.

    Unlike NormalGene, nodes are not evaluated in topological order; instead
    every node is updated simultaneously for a fixed number of iterations
    (``config['activate_times']``), so cycles in the connection graph are
    allowed and values from one iteration feed the next.
    """

    @staticmethod
    def forward_transform(nodes, conns):
        """Prepare per-genome data for ``forward``.

        Masks disabled connections with nan and drops the enable attribute
        row, returning ``(nodes, u_conns)``. No topological sort is needed
        for the recurrent evaluation scheme.
        """
        u_conns = unflatten_connections(nodes, conns)
        # row 0 holds the enable attr; nan there marks a disabled connection
        conn_enable = ~jnp.isnan(u_conns[0])
        # keep only enabled connections and remove the enable attr row
        u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
        return nodes, u_conns

    @staticmethod
    def create_forward(config):
        """Build the forward function from ``config``.

        Reads ``activation_option_names``, ``aggregation_option_names``,
        ``input_idx``, ``output_idx`` and ``activate_times`` from config and
        returns ``forward(inputs, transform)``.
        """
        config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']]
        config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']]

        def act(idx, z):
            """Apply the idx-th activation function to z."""
            # lax.switch needs an int branch index; idx is stored as float in nodes
            idx = jnp.asarray(idx, dtype=jnp.int32)
            return jax.lax.switch(idx, config['activation_funcs'], z)

        def agg(idx, z):
            """Aggregate a node's inputs z with the idx-th aggregation function.

            Returns 0 when the node has no enabled (non-nan) inputs, avoiding
            nan propagation through nodes without incoming connections.
            """
            idx = jnp.asarray(idx, dtype=jnp.int32)

            def all_nan():
                return 0.

            def not_all_nan():
                return jax.lax.switch(idx, config['aggregation_funcs'], z)

            return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)

        batch_act, batch_agg = vmap(act), vmap(agg)

        def forward(inputs, transform) -> Array:
            """Recurrent forward pass for a single genome.

            :argument inputs: (input_num, ) input values supplied by the task
            :argument transform: (nodes, u_conns) from ``forward_transform``;
                nodes is (N, 5), u_conns is (2, N, N) with weights in u_conns[0]
            :return: (output_num, ) values of the output nodes
            """
            nodes, cons = transform
            input_idx = config['input_idx']
            output_idx = config['output_idx']

            N = nodes.shape[0]
            vals = jnp.full((N,), 0.)
            weights = cons[0, :]

            def body_func(_, values):
                # input node values come from the task, never from calculation
                values = values.at[input_idx].set(inputs)
                # row i of weights.T holds node i's incoming weighted values
                nodes_ins = values * weights.T
                values = batch_agg(nodes[:, 4], nodes_ins)    # z = agg(ins)
                values = values * nodes[:, 2] + nodes[:, 1]   # z = z * response + bias
                values = batch_act(nodes[:, 3], values)       # z = act(z)
                return values

            # iterate a fixed number of steps; recurrence arises because each
            # step consumes the node values produced by the previous step
            vals = jax.lax.fori_loop(0, config['activate_times'], body_func, vals)
            return vals[output_idx]

        return forward