add function to put **all** compilation at the beginning of the execution.
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
from .genome import expand, expand_single, pop_analysis, initialize_genomes
|
||||
from .forward import create_forward_function
|
||||
from .forward import create_forward_function, forward_single
|
||||
from .activations import act_name2key
|
||||
from .aggregations import agg_name2key
|
||||
from .crossover import crossover
|
||||
from .mutate import mutate
|
||||
from .distance import distance
|
||||
from .graph import topological_sort
|
||||
|
||||
@@ -46,15 +46,14 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
|
||||
raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}")
|
||||
|
||||
|
||||
@partial(jit, static_argnames=['N'])
|
||||
def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
cal_seqs: Array, nodes: Array, connections: Array) -> Array:
|
||||
@jit
|
||||
def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array,
|
||||
input_idx: Array, output_idx: Array) -> Array:
|
||||
"""
|
||||
jax forward for single input shaped (input_num, )
|
||||
nodes, connections are single genome
|
||||
|
||||
:argument inputs: (input_num, )
|
||||
:argument N: int
|
||||
:argument input_idx: (input_num, )
|
||||
:argument output_idx: (output_num, )
|
||||
:argument cal_seqs: (N, )
|
||||
@@ -63,6 +62,7 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
|
||||
:return (output_num, )
|
||||
"""
|
||||
N = nodes.shape[0]
|
||||
ini_vals = jnp.full((N,), jnp.nan)
|
||||
ini_vals = ini_vals.at[input_idx].set(inputs)
|
||||
|
||||
@@ -86,64 +86,64 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
return vals[output_idx]
|
||||
|
||||
|
||||
@partial(jit, static_argnames=['N'])
|
||||
@partial(vmap, in_axes=(0, None, None, None, None, None, None))
|
||||
def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
cal_seqs: Array, nodes: Array, connections: Array) -> Array:
|
||||
"""
|
||||
jax forward for batch_inputs shaped (batch_size, input_num)
|
||||
nodes, connections are single genome
|
||||
|
||||
:argument batch_inputs: (batch_size, input_num)
|
||||
:argument N: int
|
||||
:argument input_idx: (input_num, )
|
||||
:argument output_idx: (output_num, )
|
||||
:argument cal_seqs: (N, )
|
||||
:argument nodes: (N, 5)
|
||||
:argument connections: (2, N, N)
|
||||
|
||||
:return (batch_size, output_num)
|
||||
"""
|
||||
return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
|
||||
|
||||
|
||||
@partial(jit, static_argnames=['N'])
|
||||
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
|
||||
def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
|
||||
"""
|
||||
jax forward for single input shaped (input_num, )
|
||||
pop_nodes, pop_connections are population of genomes
|
||||
|
||||
:argument inputs: (input_num, )
|
||||
:argument N: int
|
||||
:argument input_idx: (input_num, )
|
||||
:argument output_idx: (output_num, )
|
||||
:argument pop_cal_seqs: (pop_size, N)
|
||||
:argument pop_nodes: (pop_size, N, 5)
|
||||
:argument pop_connections: (pop_size, 2, N, N)
|
||||
|
||||
:return (pop_size, output_num)
|
||||
"""
|
||||
return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
|
||||
|
||||
|
||||
@partial(jit, static_argnames=['N'])
|
||||
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
|
||||
def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
|
||||
"""
|
||||
jax forward for batch input shaped (batch, input_num)
|
||||
pop_nodes, pop_connections are population of genomes
|
||||
|
||||
:argument batch_inputs: (batch_size, input_num)
|
||||
:argument N: int
|
||||
:argument input_idx: (input_num, )
|
||||
:argument output_idx: (output_num, )
|
||||
:argument pop_cal_seqs: (pop_size, N)
|
||||
:argument pop_nodes: (pop_size, N, 5)
|
||||
:argument pop_connections: (pop_size, 2, N, N)
|
||||
|
||||
:return (pop_size, batch_size, output_num)
|
||||
"""
|
||||
return forward_batch(batch_inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
|
||||
# @partial(jit, static_argnames=['N'])
|
||||
# @partial(vmap, in_axes=(0, None, None, None, None, None, None))
|
||||
# def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
# cal_seqs: Array, nodes: Array, connections: Array) -> Array:
|
||||
# """
|
||||
# jax forward for batch_inputs shaped (batch_size, input_num)
|
||||
# nodes, connections are single genome
|
||||
#
|
||||
# :argument batch_inputs: (batch_size, input_num)
|
||||
# :argument N: int
|
||||
# :argument input_idx: (input_num, )
|
||||
# :argument output_idx: (output_num, )
|
||||
# :argument cal_seqs: (N, )
|
||||
# :argument nodes: (N, 5)
|
||||
# :argument connections: (2, N, N)
|
||||
#
|
||||
# :return (batch_size, output_num)
|
||||
# """
|
||||
# return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
|
||||
#
|
||||
#
|
||||
# @partial(jit, static_argnames=['N'])
|
||||
# @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
|
||||
# def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
# pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
|
||||
# """
|
||||
# jax forward for single input shaped (input_num, )
|
||||
# pop_nodes, pop_connections are population of genomes
|
||||
#
|
||||
# :argument inputs: (input_num, )
|
||||
# :argument N: int
|
||||
# :argument input_idx: (input_num, )
|
||||
# :argument output_idx: (output_num, )
|
||||
# :argument pop_cal_seqs: (pop_size, N)
|
||||
# :argument pop_nodes: (pop_size, N, 5)
|
||||
# :argument pop_connections: (pop_size, 2, N, N)
|
||||
#
|
||||
# :return (pop_size, output_num)
|
||||
# """
|
||||
# return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
|
||||
#
|
||||
#
|
||||
# @partial(jit, static_argnames=['N'])
|
||||
# @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
|
||||
# def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
# pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
|
||||
# """
|
||||
# jax forward for batch input shaped (batch, input_num)
|
||||
# pop_nodes, pop_connections are population of genomes
|
||||
#
|
||||
# :argument batch_inputs: (batch_size, input_num)
|
||||
# :argument N: int
|
||||
# :argument input_idx: (input_num, )
|
||||
# :argument output_idx: (output_num, )
|
||||
# :argument pop_cal_seqs: (pop_size, N)
|
||||
# :argument pop_nodes: (pop_size, N, 5)
|
||||
# :argument pop_connections: (pop_size, 2, N, N)
|
||||
#
|
||||
# :return (pop_size, batch_size, output_num)
|
||||
# """
|
||||
# return forward_batch(batch_inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
|
||||
|
||||
Reference in New Issue
Block a user