move crossover_gene from ga.crossover to gene.basegene.
This commit is contained in:
@@ -10,6 +10,8 @@ class DefaultCrossover(BaseCrossover):
|
|||||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||||
"""
|
"""
|
||||||
randkey1, randkey2 = jax.random.split(randkey, 2)
|
randkey1, randkey2 = jax.random.split(randkey, 2)
|
||||||
|
randkeys1 = jax.random.split(randkey1, genome.max_nodes)
|
||||||
|
randkeys2 = jax.random.split(randkey2, genome.max_conns)
|
||||||
|
|
||||||
# crossover nodes
|
# crossover nodes
|
||||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||||
@@ -21,7 +23,7 @@ class DefaultCrossover(BaseCrossover):
|
|||||||
new_nodes = jnp.where(
|
new_nodes = jnp.where(
|
||||||
jnp.isnan(nodes1) | jnp.isnan(nodes2),
|
jnp.isnan(nodes1) | jnp.isnan(nodes2),
|
||||||
nodes1,
|
nodes1,
|
||||||
self.crossover_gene(randkey1, nodes1, nodes2, is_conn=False),
|
jax.vmap(genome.node_gene.crossover, in_axes=(None, 0, 0, 0))(state, randkeys1, nodes1, nodes2),
|
||||||
)
|
)
|
||||||
|
|
||||||
# crossover connections
|
# crossover connections
|
||||||
@@ -31,7 +33,7 @@ class DefaultCrossover(BaseCrossover):
|
|||||||
new_conns = jnp.where(
|
new_conns = jnp.where(
|
||||||
jnp.isnan(conns1) | jnp.isnan(conns2),
|
jnp.isnan(conns1) | jnp.isnan(conns2),
|
||||||
conns1,
|
conns1,
|
||||||
self.crossover_gene(randkey2, conns1, conns2, is_conn=True),
|
jax.vmap(genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))(state, randkeys2, conns1, conns2),
|
||||||
)
|
)
|
||||||
|
|
||||||
return new_nodes, new_conns
|
return new_nodes, new_conns
|
||||||
@@ -64,11 +66,3 @@ class DefaultCrossover(BaseCrossover):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return refactor_ar2
|
return refactor_ar2
|
||||||
|
|
||||||
def crossover_gene(self, randkey, g1, g2, is_conn):
|
|
||||||
r = jax.random.uniform(randkey, shape=g1.shape)
|
|
||||||
new_gene = jnp.where(r > 0.5, g1, g2)
|
|
||||||
if is_conn: # fix enabled
|
|
||||||
enabled = jnp.where(g1[:, 2] + g2[:, 2] > 0, 1, 0) # any of them is enabled
|
|
||||||
new_gene = new_gene.at[:, 2].set(enabled)
|
|
||||||
return new_gene
|
|
||||||
|
|||||||
@@ -23,6 +23,9 @@ class BaseGene:
|
|||||||
def mutate(self, state, randkey, gene):
|
def mutate(self, state, randkey, gene):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def crossover(self, state, randkey, gene1, gene2):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def distance(self, state, gene1, gene2):
|
def distance(self, state, gene1, gene2):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import jax, jax.numpy as jnp
|
||||||
from .. import BaseGene
|
from .. import BaseGene
|
||||||
|
|
||||||
|
|
||||||
@@ -8,5 +9,23 @@ class BaseConnGene(BaseGene):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
def crossover(self, state, randkey, gene1, gene2):
|
||||||
|
def crossover_attr():
|
||||||
|
return jnp.where(
|
||||||
|
jax.random.normal(randkey, gene1.shape) > 0,
|
||||||
|
gene1,
|
||||||
|
gene2,
|
||||||
|
)
|
||||||
|
|
||||||
|
return jax.lax.cond(
|
||||||
|
gene1[2] == gene2[2], # if both genes are enabled or disabled
|
||||||
|
crossover_attr, # then randomly pick attributes from gene1 or gene2
|
||||||
|
lambda: jnp.where( # one gene is enabled and the other is disabled
|
||||||
|
gene1[2], # if gene1 is enabled
|
||||||
|
gene1, # then return gene1
|
||||||
|
gene2, # else return gene2
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, state, attrs, inputs):
|
def forward(self, state, attrs, inputs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import jax, jax.numpy as jnp
|
||||||
from .. import BaseGene
|
from .. import BaseGene
|
||||||
|
|
||||||
|
|
||||||
@@ -8,5 +9,12 @@ class BaseNodeGene(BaseGene):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
def crossover(self, state, randkey, gene1, gene2):
|
||||||
|
return jnp.where(
|
||||||
|
jax.random.normal(randkey, gene1.shape) > 0,
|
||||||
|
gene1,
|
||||||
|
gene2,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
Reference in New Issue
Block a user