move crossover_gene from ga.crossover to gene.basegene.
This commit is contained in:
@@ -23,6 +23,9 @@ class BaseGene:
|
||||
def mutate(self, state, randkey, gene):
|
||||
raise NotImplementedError
|
||||
|
||||
def crossover(self, state, randkey, gene1, gene2):
|
||||
raise NotImplementedError
|
||||
|
||||
def distance(self, state, gene1, gene2):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from .. import BaseGene
|
||||
|
||||
|
||||
@@ -8,5 +9,23 @@ class BaseConnGene(BaseGene):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def crossover(self, state, randkey, gene1, gene2):
|
||||
def crossover_attr():
|
||||
return jnp.where(
|
||||
jax.random.normal(randkey, gene1.shape) > 0,
|
||||
gene1,
|
||||
gene2,
|
||||
)
|
||||
|
||||
return jax.lax.cond(
|
||||
gene1[2] == gene2[2], # if both genes are enabled or disabled
|
||||
crossover_attr, # then randomly pick attributes from gene1 or gene2
|
||||
lambda: jnp.where( # one gene is enabled and the other is disabled
|
||||
gene1[2], # if gene1 is enabled
|
||||
gene1, # then return gene1
|
||||
gene2, # else return gene2
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, state, attrs, inputs):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from .. import BaseGene
|
||||
|
||||
|
||||
@@ -8,5 +9,12 @@ class BaseNodeGene(BaseGene):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def crossover(self, state, randkey, gene1, gene2):
|
||||
return jnp.where(
|
||||
jax.random.normal(randkey, gene1.shape) > 0,
|
||||
gene1,
|
||||
gene2,
|
||||
)
|
||||
|
||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user