change project structure and using .ini as config file
This commit is contained in:
@@ -1,47 +0,0 @@
|
||||
import jax
|
||||
from jax import Array, numpy as jnp
|
||||
from jax import jit, vmap
|
||||
|
||||
from .aggregations import agg
|
||||
from .activations import act
|
||||
from .utils import I_INT
|
||||
|
||||
# TODO: enabled information doesn't influence forward. That is wrong!
|
||||
@jit
|
||||
def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array,
|
||||
input_idx: Array, output_idx: Array) -> Array:
|
||||
"""
|
||||
jax forward for single input shaped (input_num, )
|
||||
nodes, connections are single genome
|
||||
|
||||
:argument inputs: (input_num, )
|
||||
:argument input_idx: (input_num, )
|
||||
:argument output_idx: (output_num, )
|
||||
:argument cal_seqs: (N, )
|
||||
:argument nodes: (N, 5)
|
||||
:argument connections: (2, N, N)
|
||||
|
||||
:return (output_num, )
|
||||
"""
|
||||
N = nodes.shape[0]
|
||||
ini_vals = jnp.full((N,), jnp.nan)
|
||||
ini_vals = ini_vals.at[input_idx].set(inputs)
|
||||
|
||||
def scan_body(carry, i):
|
||||
def hit():
|
||||
ins = carry * connections[0, :, i]
|
||||
z = agg(nodes[i, 4], ins)
|
||||
z = z * nodes[i, 2] + nodes[i, 1]
|
||||
z = act(nodes[i, 3], z)
|
||||
|
||||
new_vals = carry.at[i].set(z)
|
||||
return new_vals
|
||||
|
||||
def miss():
|
||||
return carry
|
||||
|
||||
return jax.lax.cond((i == I_INT) | (jnp.isin(i, input_idx)), miss, hit), None
|
||||
|
||||
vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs)
|
||||
|
||||
return vals[output_idx]
|
||||
Reference in New Issue
Block a user