add "update_by_batch" in gene;

add flatten_conns as an inverse function for unflatten_conns;
add "test_flatten.ipynb" as test for them.
This commit is contained in:
wls2002
2024-05-30 19:44:52 +08:00
parent cd92f411dc
commit 5bd6e5c357
9 changed files with 481 additions and 11 deletions

View File

@@ -18,3 +18,12 @@ class BaseNodeGene(BaseGene):
def forward(self, state, attrs, inputs, is_output_node=False):
raise NotImplementedError
def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False):
# default: do not update attrs, but to calculate batch_res
return (
jax.vmap(self.forward, in_axes=(None, None, 0, None))(
state, attrs, batch_inputs, is_output_node
),
attrs,
)