bug down! Here it can solve xor successfully!
This commit is contained in:
@@ -7,12 +7,12 @@ from numpy.typing import NDArray
|
||||
|
||||
from .aggregations import agg
|
||||
from .activations import act
|
||||
from .graph import topological_sort, batch_topological_sort, topological_sort_debug
|
||||
from .graph import topological_sort, batch_topological_sort
|
||||
from .utils import I_INT
|
||||
|
||||
|
||||
def create_forward_function(nodes: NDArray, connections: NDArray,
|
||||
N: int, input_idx: NDArray, output_idx: NDArray, batch: bool, debug: bool = False):
|
||||
N: int, input_idx: NDArray, output_idx: NDArray, batch: bool):
|
||||
"""
|
||||
create forward function for different situations
|
||||
|
||||
@@ -26,11 +26,6 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
|
||||
:return:
|
||||
"""
|
||||
|
||||
if debug:
|
||||
cal_seqs = topological_sort_debug(nodes, connections)
|
||||
return lambda inputs: forward_single_debug(inputs, N, input_idx, output_idx,
|
||||
cal_seqs, nodes, connections)
|
||||
|
||||
if nodes.ndim == 2: # single genome
|
||||
cal_seqs = topological_sort(nodes, connections)
|
||||
if not batch:
|
||||
@@ -51,7 +46,6 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
|
||||
raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}")
|
||||
|
||||
|
||||
# @partial(jit, static_argnames=['N', 'input_idx', 'output_idx'])
|
||||
@partial(jit, static_argnames=['N'])
|
||||
def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
cal_seqs: Array, nodes: Array, connections: Array) -> Array:
|
||||
@@ -79,38 +73,19 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
z = z * nodes[i, 2] + nodes[i, 1]
|
||||
z = act(nodes[i, 3], z)
|
||||
|
||||
# for some nodes (inputs nodes), the output z will be nan, thus we do not update the vals
|
||||
new_vals = jnp.where(jnp.isnan(z), carry, carry.at[i].set(z))
|
||||
new_vals = carry.at[i].set(z)
|
||||
return new_vals
|
||||
|
||||
def miss():
|
||||
return carry
|
||||
|
||||
return jax.lax.cond(i == I_INT, miss, hit), None
|
||||
return jax.lax.cond((i == I_INT) | (jnp.isin(i, input_idx)), miss, hit), None
|
||||
|
||||
vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs)
|
||||
|
||||
return vals[output_idx]
|
||||
|
||||
|
||||
def forward_single_debug(inputs, N, input_idx, output_idx: Array, cal_seqs, nodes, connections):
|
||||
ini_vals = jnp.full((N,), jnp.nan)
|
||||
ini_vals = ini_vals.at[input_idx].set(inputs)
|
||||
vals = ini_vals
|
||||
for i in cal_seqs:
|
||||
if i == I_INT:
|
||||
break
|
||||
ins = vals * connections[0, :, i]
|
||||
z = agg(nodes[i, 4], ins)
|
||||
z = z * nodes[i, 2] + nodes[i, 1]
|
||||
z = act(nodes[i, 3], z)
|
||||
|
||||
# for some nodes (inputs nodes), the output z will be nan, thus we do not update the vals
|
||||
vals = jnp.where(jnp.isnan(z), vals, vals.at[i].set(z))
|
||||
|
||||
return vals[output_idx]
|
||||
|
||||
|
||||
@partial(vmap, in_axes=(0, None, None, None, None, None, None))
|
||||
def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
cal_seqs: Array, nodes: Array, connections: Array) -> Array:
|
||||
|
||||
Reference in New Issue
Block a user