49 lines
1.3 KiB
Python
49 lines
1.3 KiB
Python
import jax
|
|
from jax import Array, numpy as jnp
|
|
from jax import jit, vmap
|
|
|
|
from .aggregations import agg
|
|
from .activations import act
|
|
from .utils import I_INT
|
|
|
|
|
|
# TODO: the forward pass ignores whether a connection is enabled (only connections[0] is read); disabled connections should not contribute.
|
|
@jit
def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array,
                   input_idx: Array, output_idx: Array) -> Array:
    """Run one genome on one input vector and return its output-node values.

    Every node starts at NaN, and the input nodes are seeded with ``inputs``.
    Nodes are then visited in the order given by ``cal_seqs``. Entries equal to
    ``I_INT`` and input nodes are left unchanged; every other node ``k`` is set to
    ``act(nodes[k, 3], agg(nodes[k, 4], values * connections[0, :, k]) * nodes[k, 2] + nodes[k, 1])``.

    :param inputs: values for the input nodes, shape (input_num, )
    :param cal_seqs: node indices in evaluation order, shape (N, )
    :param nodes: per-node attributes of a single genome, shape (N, 5)
    :param connections: connection data of a single genome, shape (2, N, N);
        only ``connections[0]`` is used, as the weight matrix
    :param input_idx: node indices of the inputs, shape (input_num, )
    :param output_idx: node indices of the outputs, shape (output_num, )
    :return: values of the output nodes, shape (output_num, )
    """
    # NOTE(review): connections[1] is never read. If it holds the "enabled"
    # flag (unconfirmed), disabled connections still contribute here.
    weights = connections[0]

    # NaN marks "not computed yet"; input slots are filled immediately.
    start = jnp.full((nodes.shape[0],), jnp.nan).at[input_idx].set(inputs)

    def step(values, node):
        # I_INT presumably pads cal_seqs (TODO confirm); input nodes keep their seeded value.
        skip = (node == I_INT) | jnp.isin(node, input_idx)

        def compute():
            attrs = nodes[node]
            incoming = values * weights[:, node]
            # attrs[2] / attrs[1] look like response / bias -- assumption, verify.
            pre_activation = agg(attrs[4], incoming) * attrs[2] + attrs[1]
            return values.at[node].set(act(attrs[3], pre_activation))

        return jax.lax.cond(skip, lambda: values, compute), None

    final_values, _ = jax.lax.scan(step, start, cal_seqs)
    return final_values[output_idx]
|