Files
tensorneat-mend/tensorneat/algorithm/neat/gene/node/base.py
wls2002 5bd6e5c357 add "update_by_batch" in gene;
add flatten_conns as an inverse function for unflatten_conns;
add "test_flatten.ipynb" as test for them.
2024-05-30 19:44:52 +08:00

30 lines
823 B
Python

import jax, jax.numpy as jnp
from .. import BaseGene
class BaseNodeGene(BaseGene):
    """Base class for node genes.

    Concrete node genes are expected to subclass this and implement
    :meth:`forward`.
    """

    # Attribute names this gene declares as fixed; for node genes only the
    # node index. NOTE(review): how BaseGene consumes this list is not visible
    # here — confirm against BaseGene.
    fixed_attrs = ["index"]

    def __init__(self):
        super().__init__()

    def crossover(self, state, randkey, gene1, gene2):
        """Combine two parent genes element by element.

        For every element position, an independent standard-normal sample is
        drawn. The element is taken from ``gene1`` where the sample is
        positive, otherwise from ``gene2``.

        Args:
            state: Not used by this implementation.
            randkey: JAX PRNG key used to draw the selection mask.
            gene1: First parent; its ``shape`` determines the mask shape.
            gene2: Second parent.

        Returns:
            An array whose elements come from either ``gene1`` or ``gene2``.
        """
        take_first = jax.random.normal(randkey, gene1.shape) > 0
        return jnp.where(take_first, gene1, gene2)

    def forward(self, state, attrs, inputs, is_output_node=False):
        """Compute this node's output; must be provided by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False):
        """Evaluate :meth:`forward` over a batch of inputs.

        This default does not modify ``attrs``; it only computes the batched
        outputs.

        Returns:
            A tuple ``(batch_outputs, attrs)``. ``batch_outputs`` is
            :meth:`forward` mapped over axis 0 of ``batch_inputs``, and
            ``attrs`` is the input value returned unchanged.
        """
        # Only the inputs are mapped (axis 0); state, attrs and the
        # is_output_node flag are shared by every batch element.
        batched_forward = jax.vmap(self.forward, in_axes=(None, None, 0, None))
        batch_outputs = batched_forward(state, attrs, batch_inputs, is_output_node)
        return batch_outputs, attrs