bug down! Here it can solve xor successfully!

This commit is contained in:
wls2002
2023-05-07 16:03:52 +08:00
parent d1f54022bd
commit a3b9bca866
12 changed files with 120 additions and 254 deletions

View File

@@ -7,12 +7,12 @@ from numpy.typing import NDArray
from .aggregations import agg
from .activations import act
from .graph import topological_sort, batch_topological_sort, topological_sort_debug
from .graph import topological_sort, batch_topological_sort
from .utils import I_INT
def create_forward_function(nodes: NDArray, connections: NDArray,
N: int, input_idx: NDArray, output_idx: NDArray, batch: bool, debug: bool = False):
N: int, input_idx: NDArray, output_idx: NDArray, batch: bool):
"""
create forward function for different situations
@@ -26,11 +26,6 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
:return:
"""
if debug:
cal_seqs = topological_sort_debug(nodes, connections)
return lambda inputs: forward_single_debug(inputs, N, input_idx, output_idx,
cal_seqs, nodes, connections)
if nodes.ndim == 2: # single genome
cal_seqs = topological_sort(nodes, connections)
if not batch:
@@ -51,7 +46,6 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}")
# @partial(jit, static_argnames=['N', 'input_idx', 'output_idx'])
@partial(jit, static_argnames=['N'])
def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
cal_seqs: Array, nodes: Array, connections: Array) -> Array:
@@ -79,38 +73,19 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
z = z * nodes[i, 2] + nodes[i, 1]
z = act(nodes[i, 3], z)
# for some nodes (inputs nodes), the output z will be nan, thus we do not update the vals
new_vals = jnp.where(jnp.isnan(z), carry, carry.at[i].set(z))
new_vals = carry.at[i].set(z)
return new_vals
def miss():
return carry
return jax.lax.cond(i == I_INT, miss, hit), None
return jax.lax.cond((i == I_INT) | (jnp.isin(i, input_idx)), miss, hit), None
vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs)
return vals[output_idx]
def forward_single_debug(inputs, N, input_idx, output_idx: Array, cal_seqs, nodes, connections):
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[input_idx].set(inputs)
vals = ini_vals
for i in cal_seqs:
if i == I_INT:
break
ins = vals * connections[0, :, i]
z = agg(nodes[i, 4], ins)
z = z * nodes[i, 2] + nodes[i, 1]
z = act(nodes[i, 3], z)
# for some nodes (inputs nodes), the output z will be nan, thus we do not update the vals
vals = jnp.where(jnp.isnan(z), vals, vals.at[i].set(z))
return vals[output_idx]
@partial(vmap, in_axes=(0, None, None, None, None, None, None))
def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
cal_seqs: Array, nodes: Array, connections: Array) -> Array: