add debug mode for create_xx_functions for detail time cost analysis
This commit is contained in:
@@ -86,6 +86,7 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
return vals[output_idx]
|
||||
|
||||
|
||||
@partial(jit, static_argnames=['N'])
|
||||
@partial(vmap, in_axes=(0, None, None, None, None, None, None))
|
||||
def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
cal_seqs: Array, nodes: Array, connections: Array) -> Array:
|
||||
@@ -106,6 +107,7 @@ def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Arr
|
||||
return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
|
||||
|
||||
|
||||
@partial(jit, static_argnames=['N'])
|
||||
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
|
||||
def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
|
||||
@@ -126,6 +128,7 @@ def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Arra
|
||||
return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
|
||||
|
||||
|
||||
@partial(jit, static_argnames=['N'])
|
||||
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
|
||||
def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
|
||||
|
||||
Reference in New Issue
Block a user