91 lines
2.9 KiB
Python
91 lines
2.9 KiB
Python
import jax
|
|
from jax import Array, numpy as jnp, vmap
|
|
|
|
from .normal import NormalGene
|
|
from .activation import Activation
|
|
from .aggregation import Aggregation
|
|
from algorithm.utils import unflatten_connections
|
|
|
|
|
|
class RecurrentGene(NormalGene):
|
|
|
|
@staticmethod
|
|
def forward_transform(state, nodes, conns):
|
|
u_conns = unflatten_connections(nodes, conns)
|
|
|
|
# remove un-enable connections and remove enable attr
|
|
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
|
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
|
|
|
return nodes, u_conns
|
|
|
|
@staticmethod
|
|
def create_forward(config):
|
|
config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']]
|
|
config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']]
|
|
|
|
def act(idx, z):
|
|
"""
|
|
calculate activation function for each node
|
|
"""
|
|
idx = jnp.asarray(idx, dtype=jnp.int32)
|
|
# change idx from float to int
|
|
res = jax.lax.switch(idx, config['activation_funcs'], z)
|
|
return res
|
|
|
|
def agg(idx, z):
|
|
"""
|
|
calculate activation function for inputs of node
|
|
"""
|
|
idx = jnp.asarray(idx, dtype=jnp.int32)
|
|
|
|
def all_nan():
|
|
return 0.
|
|
|
|
def not_all_nan():
|
|
return jax.lax.switch(idx, config['aggregation_funcs'], z)
|
|
|
|
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
|
|
|
|
batch_act, batch_agg = vmap(act), vmap(agg)
|
|
|
|
def forward(inputs, transform) -> Array:
|
|
"""
|
|
jax forward for single input shaped (input_num, )
|
|
nodes, connections are a single genome
|
|
|
|
:argument inputs: (input_num, )
|
|
:argument cal_seqs: (N, )
|
|
:argument nodes: (N, 5)
|
|
:argument connections: (2, N, N)
|
|
|
|
:return (output_num, )
|
|
"""
|
|
|
|
nodes, cons = transform
|
|
|
|
input_idx = config['input_idx']
|
|
output_idx = config['output_idx']
|
|
|
|
N = nodes.shape[0]
|
|
vals = jnp.full((N,), 0.)
|
|
|
|
weights = cons[0, :]
|
|
|
|
def body_func(i, values):
|
|
values = values.at[input_idx].set(inputs)
|
|
nodes_ins = values * weights.T
|
|
values = batch_agg(nodes[:, 4], nodes_ins) # z = agg(ins)
|
|
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
|
|
values = batch_act(nodes[:, 3], values) # z = act(z)
|
|
return values
|
|
|
|
# for i in range(config['activate_times']):
|
|
# vals = body_func(i, vals)
|
|
#
|
|
# return vals[output_idx]
|
|
vals = jax.lax.fori_loop(0, config['activate_times'], body_func, vals)
|
|
return vals[output_idx]
|
|
|
|
return forward
|