refactor genome.py use (C, 4) to replace (2, N, N) to represent connections

faster, faster and faster!
This commit is contained in:
wls2002
2023-05-12 00:57:55 +08:00
parent e5fc1167d9
commit 47b1a1dbb2
16 changed files with 363 additions and 419 deletions

View File

@@ -1,51 +1,12 @@
from functools import partial
import jax
from jax import Array, numpy as jnp
from jax import jit, vmap
from numpy.typing import NDArray
from .aggregations import agg
from .activations import act
from .graph import topological_sort, batch_topological_sort
from .utils import I_INT
def create_forward_function(nodes: NDArray, connections: NDArray,
N: int, input_idx: NDArray, output_idx: NDArray, batch: bool):
"""
create forward function for different situations
:param nodes: shape (N, 5) or (pop_size, N, 5)
:param connections: shape (2, N, N) or (pop_size, 2, N, N)
:param N:
:param input_idx:
:param output_idx:
:param batch: using batch or not
:param debug: debug mode
:return:
"""
if nodes.ndim == 2: # single genome
cal_seqs = topological_sort(nodes, connections)
if not batch:
return lambda inputs: forward_single(inputs, N, input_idx, output_idx,
cal_seqs, nodes, connections)
else:
return lambda batch_inputs: forward_batch(batch_inputs, N, input_idx, output_idx,
cal_seqs, nodes, connections)
elif nodes.ndim == 3: # pop genome
pop_cal_seqs = batch_topological_sort(nodes, connections)
if not batch:
return lambda inputs: pop_forward_single(inputs, N, input_idx, output_idx,
pop_cal_seqs, nodes, connections)
else:
return lambda batch_inputs: pop_forward_batch(batch_inputs, N, input_idx, output_idx,
pop_cal_seqs, nodes, connections)
else:
raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}")
# TODO: enabled information doesn't influence forward. That is wrong!
@jit
def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array,
input_idx: Array, output_idx: Array) -> Array:
@@ -84,66 +45,3 @@ def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Ar
vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs)
return vals[output_idx]
# @partial(jit, static_argnames=['N'])
# @partial(vmap, in_axes=(0, None, None, None, None, None, None))
# def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
# cal_seqs: Array, nodes: Array, connections: Array) -> Array:
# """
# jax forward for batch_inputs shaped (batch_size, input_num)
# nodes, connections are single genome
#
# :argument batch_inputs: (batch_size, input_num)
# :argument N: int
# :argument input_idx: (input_num, )
# :argument output_idx: (output_num, )
# :argument cal_seqs: (N, )
# :argument nodes: (N, 5)
# :argument connections: (2, N, N)
#
# :return (batch_size, output_num)
# """
# return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
#
#
# @partial(jit, static_argnames=['N'])
# @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
# def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
# pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
# """
# jax forward for single input shaped (input_num, )
# pop_nodes, pop_connections are population of genomes
#
# :argument inputs: (input_num, )
# :argument N: int
# :argument input_idx: (input_num, )
# :argument output_idx: (output_num, )
# :argument pop_cal_seqs: (pop_size, N)
# :argument pop_nodes: (pop_size, N, 5)
# :argument pop_connections: (pop_size, 2, N, N)
#
# :return (pop_size, output_num)
# """
# return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
#
#
# @partial(jit, static_argnames=['N'])
# @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
# def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
# pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
# """
# jax forward for batch input shaped (batch, input_num)
# pop_nodes, pop_connections are population of genomes
#
# :argument batch_inputs: (batch_size, input_num)
# :argument N: int
# :argument input_idx: (input_num, )
# :argument output_idx: (output_num, )
# :argument pop_cal_seqs: (pop_size, N)
# :argument pop_nodes: (pop_size, N, 5)
# :argument pop_connections: (pop_size, 2, N, N)
#
# :return (pop_size, batch_size, output_num)
# """
# return forward_batch(batch_inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)