Files
tensorneat-mend/algorithms/neat/genome/forward.py
2023-05-06 18:33:30 +08:00

172 lines
6.3 KiB
Python

from functools import partial
import jax
from jax import Array, numpy as jnp
from jax import jit, vmap
from numpy.typing import NDArray
from .aggregations import agg
from .activations import act
from .graph import topological_sort, batch_topological_sort, topological_sort_debug
from .utils import I_INT
def create_forward_function(nodes: NDArray, connections: NDArray,
                            N: int, input_idx: NDArray, output_idx: NDArray, batch: bool, debug: bool = False):
    """
    Pick the forward routine matching the genome layout and bind the genome to it.

    The computation order is obtained once here (via the topological sort helpers)
    and captured, together with the genome arrays, by the returned callable.

    :param nodes: shape (N, 5) for one genome or (pop_size, N, 5) for a population
    :param connections: shape (2, N, N) for one genome or (pop_size, 2, N, N) for a population
    :param N: number of node slots, forwarded unchanged to the forward routine
    :param input_idx: indices of the input nodes
    :param output_idx: indices of the output nodes
    :param batch: when true, the returned callable expects a batch of inputs
    :param debug: when true, always return the eager ``forward_single_debug`` path;
        this branch returns before ``batch`` or ``nodes.ndim`` is looked at
    :return: a one-argument callable taking ``inputs`` (or ``batch_inputs``)
    :raises ValueError: if ``debug`` is false and ``nodes.ndim`` is neither 2 nor 3
    """
    # Positional arguments that every forward variant takes right after the inputs.
    io_spec = (N, input_idx, output_idx)

    if debug:
        debug_order = topological_sort_debug(nodes, connections)
        return lambda inputs: forward_single_debug(inputs, *io_spec, debug_order, nodes, connections)

    rank = nodes.ndim
    if rank != 2 and rank != 3:
        raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}")

    if rank == 2:  # single genome
        order = topological_sort(nodes, connections)
        if batch:
            return lambda batch_inputs: forward_batch(batch_inputs, *io_spec, order, nodes, connections)
        return lambda inputs: forward_single(inputs, *io_spec, order, nodes, connections)

    # rank == 3: population of genomes, one computation order per genome
    pop_order = batch_topological_sort(nodes, connections)
    if batch:
        return lambda batch_inputs: pop_forward_batch(batch_inputs, *io_spec, pop_order, nodes, connections)
    return lambda inputs: pop_forward_single(inputs, *io_spec, pop_order, nodes, connections)
@partial(jit, static_argnames=['N'])
def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
                   cal_seqs: Array, nodes: Array, connections: Array) -> Array:
    """
    Evaluate one genome on one input vector (jit-compiled).

    Node values live in a length-N buffer that starts as NaN everywhere except the
    input slots. The buffer is then updated node by node, in the order given by
    cal_seqs, using a scan.

    :param inputs: (input_num, )
    :param N: int, size of the value buffer; static under jit because it is used as a shape
    :param input_idx: (input_num, )
    :param output_idx: (output_num, )
    :param cal_seqs: (N, ) node indices in evaluation order; entries equal to I_INT are skipped
    :param nodes: (N, 5)
    :param connections: (2, N, N)
    :return: (output_num, )
    """
    start_vals = jnp.full((N,), jnp.nan).at[input_idx].set(inputs)

    def step(vals, node):
        def evaluate_node():
            # Scale the current values by column `node` of connections[0]
            # (presumably the weights of the connections into this node).
            weighted = vals * connections[0, :, node]
            out = agg(nodes[node, 4], weighted)
            # Column 2 multiplies the aggregate, column 1 is added, column 3 selects the activation.
            out = act(nodes[node, 3], out * nodes[node, 2] + nodes[node, 1])
            # Some nodes (the input nodes) produce NaN here; in that case the buffer is kept as is.
            return jnp.where(jnp.isnan(out), vals, vals.at[node].set(out))

        def skip_node():
            return vals

        return jax.lax.cond(node == I_INT, skip_node, evaluate_node), None

    final_vals, _ = jax.lax.scan(step, start_vals, cal_seqs)
    return final_vals[output_idx]
def forward_single_debug(inputs, N, input_idx, output_idx: Array, cal_seqs, nodes, connections):
    """
    Eager (non-jit) counterpart of forward_single, written as a plain Python loop.

    :param inputs: values written into the input slots
    :param N: size of the node value buffer
    :param input_idx: indices of the input nodes
    :param output_idx: indices of the output nodes
    :param cal_seqs: node indices in evaluation order; iteration stops at the first I_INT
    :param nodes: node attribute table, indexed as nodes[i, 1..4]
    :param connections: connection tensor, read as connections[0, :, i]
    :return: the buffer entries selected by output_idx
    """
    # Everything starts as NaN; only the input slots are filled up front.
    node_vals = jnp.full((N,), jnp.nan).at[input_idx].set(inputs)

    for node in cal_seqs:
        if node == I_INT:
            break
        weighted = node_vals * connections[0, :, node]
        out = agg(nodes[node, 4], weighted)
        out = act(nodes[node, 3], out * nodes[node, 2] + nodes[node, 1])
        # Some nodes (the input nodes) produce NaN here; in that case the buffer is kept as is.
        node_vals = jnp.where(jnp.isnan(out), node_vals, node_vals.at[node].set(out))

    return node_vals[output_idx]
@partial(vmap, in_axes=(0, None, None, None, None, None, None))
def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
                  cal_seqs: Array, nodes: Array, connections: Array) -> Array:
    """
    Evaluate one genome on a batch of input vectors.

    vmap maps over axis 0 of batch_inputs only; every other argument is shared
    across the batch. Shapes below are those seen by the caller.

    :param batch_inputs: (batch_size, input_num)
    :param N: int
    :param input_idx: (input_num, )
    :param output_idx: (output_num, )
    :param cal_seqs: (N, )
    :param nodes: (N, 5)
    :param connections: (2, N, N)
    :return: (batch_size, output_num)
    """
    # Inside the vmapped body, `batch_inputs` is a single row of the batch.
    row_outputs = forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
    return row_outputs
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
                       pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
    """
    Evaluate every genome of a population on the same single input vector.

    vmap maps over axis 0 of the three population arrays; the input and the
    index arguments are shared by all genomes. Shapes below are those seen by the caller.

    :param inputs: (input_num, )
    :param N: int
    :param input_idx: (input_num, )
    :param output_idx: (output_num, )
    :param pop_cal_seqs: (pop_size, N)
    :param pop_nodes: (pop_size, N, 5)
    :param pop_connections: (pop_size, 2, N, N)
    :return: (pop_size, output_num)
    """
    # Inside the vmapped body, the pop_* arguments hold one genome each.
    genome_outputs = forward_single(inputs, N, input_idx, output_idx,
                                    pop_cal_seqs, pop_nodes, pop_connections)
    return genome_outputs
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
                      pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
    """
    Evaluate every genome of a population on the same batch of input vectors.

    This vmap maps over axis 0 of the three population arrays; the batch axis of
    the inputs is handled by forward_batch, which is called for each genome.
    Shapes below are those seen by the caller.

    :param batch_inputs: (batch_size, input_num)
    :param N: int
    :param input_idx: (input_num, )
    :param output_idx: (output_num, )
    :param pop_cal_seqs: (pop_size, N)
    :param pop_nodes: (pop_size, N, 5)
    :param pop_connections: (pop_size, 2, N, N)
    :return: (pop_size, batch_size, output_num)
    """
    # Inside the vmapped body, the pop_* arguments hold one genome each; the full batch is passed on.
    genome_batch_outputs = forward_batch(batch_inputs, N, input_idx, output_idx,
                                         pop_cal_seqs, pop_nodes, pop_connections)
    return genome_batch_outputs