move crossover_gene from ga.crossover to gene.basegene.

This commit is contained in:
wls2002
2024-05-30 15:06:08 +08:00
parent 9f6154d128
commit 20320105e6
4 changed files with 34 additions and 10 deletions

View File

@@ -23,6 +23,9 @@ class BaseGene:
def mutate(self, state, randkey, gene):
raise NotImplementedError
def crossover(self, state, randkey, gene1, gene2):
raise NotImplementedError
def distance(self, state, gene1, gene2):
raise NotImplementedError

View File

@@ -1,3 +1,4 @@
import jax, jax.numpy as jnp
from .. import BaseGene
@@ -8,5 +9,23 @@ class BaseConnGene(BaseGene):
def __init__(self):
super().__init__()
def crossover(self, state, randkey, gene1, gene2):
def crossover_attr():
return jnp.where(
jax.random.normal(randkey, gene1.shape) > 0,
gene1,
gene2,
)
return jax.lax.cond(
gene1[2] == gene2[2], # if both genes are enabled or disabled
crossover_attr, # then randomly pick attributes from gene1 or gene2
lambda: jnp.where( # one gene is enabled and the other is disabled
gene1[2], # if gene1 is enabled
gene1, # then return gene1
gene2, # else return gene2
),
)
def forward(self, state, attrs, inputs):
raise NotImplementedError

View File

@@ -1,3 +1,4 @@
import jax, jax.numpy as jnp
from .. import BaseGene
@@ -8,5 +9,12 @@ class BaseNodeGene(BaseGene):
def __init__(self):
super().__init__()
def crossover(self, state, randkey, gene1, gene2):
return jnp.where(
jax.random.normal(randkey, gene1.shape) > 0,
gene1,
gene2,
)
def forward(self, state, attrs, inputs, is_output_node=False):
raise NotImplementedError