add debug mode for create_xx_functions for detail time cost analysis

This commit is contained in:
wls2002
2023-05-08 15:42:25 +08:00
parent d4a75b9394
commit e201d03157
8 changed files with 70 additions and 38 deletions

View File

@@ -86,6 +86,7 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
return vals[output_idx]
@partial(jit, static_argnames=['N'])
@partial(vmap, in_axes=(0, None, None, None, None, None, None))
def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
cal_seqs: Array, nodes: Array, connections: Array) -> Array:
@@ -106,6 +107,7 @@ def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Arr
return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
@partial(jit, static_argnames=['N'])
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
@@ -126,6 +128,7 @@ def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Arra
return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
@partial(jit, static_argnames=['N'])
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array: