add function to put **all** compilation at the beginning of the execution.

This commit is contained in:
wls2002
2023-05-09 02:55:47 +08:00
parent 1f2327bbd6
commit 0fdc856f2d
6 changed files with 183 additions and 75 deletions

View File

@@ -46,15 +46,14 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}")
@partial(jit, static_argnames=['N'])
def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
cal_seqs: Array, nodes: Array, connections: Array) -> Array:
@jit
def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array,
input_idx: Array, output_idx: Array) -> Array:
"""
jax forward for single input shaped (input_num, )
nodes, connections are single genome
:argument inputs: (input_num, )
:argument N: int
:argument input_idx: (input_num, )
:argument output_idx: (output_num, )
:argument cal_seqs: (N, )
@@ -63,6 +62,7 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
:return (output_num, )
"""
N = nodes.shape[0]
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[input_idx].set(inputs)
@@ -86,64 +86,64 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
return vals[output_idx]
@partial(jit, static_argnames=['N'])
@partial(vmap, in_axes=(0, None, None, None, None, None, None))
def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
cal_seqs: Array, nodes: Array, connections: Array) -> Array:
"""
jax forward for batch_inputs shaped (batch_size, input_num)
nodes, connections are single genome
:argument batch_inputs: (batch_size, input_num)
:argument N: int
:argument input_idx: (input_num, )
:argument output_idx: (output_num, )
:argument cal_seqs: (N, )
:argument nodes: (N, 5)
:argument connections: (2, N, N)
:return (batch_size, output_num)
"""
return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
@partial(jit, static_argnames=['N'])
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
"""
jax forward for single input shaped (input_num, )
pop_nodes, pop_connections are population of genomes
:argument inputs: (input_num, )
:argument N: int
:argument input_idx: (input_num, )
:argument output_idx: (output_num, )
:argument pop_cal_seqs: (pop_size, N)
:argument pop_nodes: (pop_size, N, 5)
:argument pop_connections: (pop_size, 2, N, N)
:return (pop_size, output_num)
"""
return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
@partial(jit, static_argnames=['N'])
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
"""
jax forward for batch input shaped (batch, input_num)
pop_nodes, pop_connections are population of genomes
:argument batch_inputs: (batch_size, input_num)
:argument N: int
:argument input_idx: (input_num, )
:argument output_idx: (output_num, )
:argument pop_cal_seqs: (pop_size, N)
:argument pop_nodes: (pop_size, N, 5)
:argument pop_connections: (pop_size, 2, N, N)
:return (pop_size, batch_size, output_num)
"""
return forward_batch(batch_inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
# @partial(jit, static_argnames=['N'])
# @partial(vmap, in_axes=(0, None, None, None, None, None, None))
# def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
# cal_seqs: Array, nodes: Array, connections: Array) -> Array:
# """
# jax forward for batch_inputs shaped (batch_size, input_num)
# nodes, connections are single genome
#
# :argument batch_inputs: (batch_size, input_num)
# :argument N: int
# :argument input_idx: (input_num, )
# :argument output_idx: (output_num, )
# :argument cal_seqs: (N, )
# :argument nodes: (N, 5)
# :argument connections: (2, N, N)
#
# :return (batch_size, output_num)
# """
# return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
#
#
# @partial(jit, static_argnames=['N'])
# @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
# def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
# pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
# """
# jax forward for single input shaped (input_num, )
# pop_nodes, pop_connections are population of genomes
#
# :argument inputs: (input_num, )
# :argument N: int
# :argument input_idx: (input_num, )
# :argument output_idx: (output_num, )
# :argument pop_cal_seqs: (pop_size, N)
# :argument pop_nodes: (pop_size, N, 5)
# :argument pop_connections: (pop_size, 2, N, N)
#
# :return (pop_size, output_num)
# """
# return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
#
#
# @partial(jit, static_argnames=['N'])
# @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
# def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
# pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
# """
# jax forward for batch input shaped (batch, input_num)
# pop_nodes, pop_connections are population of genomes
#
# :argument batch_inputs: (batch_size, input_num)
# :argument N: int
# :argument input_idx: (input_num, )
# :argument output_idx: (output_num, )
# :argument pop_cal_seqs: (pop_size, N)
# :argument pop_nodes: (pop_size, N, 5)
# :argument pop_connections: (pop_size, 2, N, N)
#
# :return (pop_size, batch_size, output_num)
# """
# return forward_batch(batch_inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)