move crossover_gene from ga.crossover to gene.basegene.

This commit is contained in:
wls2002
2024-05-30 15:06:08 +08:00
parent 9f6154d128
commit 20320105e6
4 changed files with 34 additions and 10 deletions

View File

@@ -1,3 +1,4 @@
import jax, jax.numpy as jnp
from .. import BaseGene
@@ -8,5 +9,12 @@ class BaseNodeGene(BaseGene):
def __init__(self):
super().__init__()
def crossover(self, state, randkey, gene1, gene2):
return jnp.where(
jax.random.normal(randkey, gene1.shape) > 0,
gene1,
gene2,
)
def forward(self, state, attrs, inputs, is_output_node=False):
raise NotImplementedError