All function with state will update the state and return it.
Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
@@ -3,8 +3,8 @@ from utils import State
|
||||
|
||||
class BaseCrossover:
|
||||
|
||||
def setup(self, key, state=State()):
|
||||
def setup(self, state=State()):
|
||||
return state
|
||||
|
||||
def __call__(self, state, key, genome, nodes1, nodes2, conns1, conns2):
|
||||
def __call__(self, state, genome, nodes1, nodes2, conns1, conns2):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -5,12 +5,12 @@ from .base import BaseCrossover
|
||||
|
||||
class DefaultCrossover(BaseCrossover):
|
||||
|
||||
def __call__(self, state, key, genome, nodes1, conns1, nodes2, conns2):
|
||||
def __call__(self, state, genome, nodes1, conns1, nodes2, conns2):
|
||||
"""
|
||||
use genome1 and genome2 to generate a new genome
|
||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||
"""
|
||||
randkey_1, randkey_2, key = jax.random.split(key, 3)
|
||||
randkey1, randkey2, randkey = jax.random.split(state.randkey, 3)
|
||||
|
||||
# crossover nodes
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
@@ -20,16 +20,16 @@ class DefaultCrossover(BaseCrossover):
|
||||
# For not homologous genes, use the value of nodes1(winner)
|
||||
# For homologous genes, use the crossover result between nodes1 and nodes2
|
||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1,
|
||||
self.crossover_gene(randkey_1, nodes1, nodes2, is_conn=False))
|
||||
self.crossover_gene(randkey1, nodes1, nodes2, is_conn=False))
|
||||
|
||||
# crossover connections
|
||||
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
|
||||
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
|
||||
|
||||
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1,
|
||||
self.crossover_gene(randkey_2, conns1, conns2, is_conn=True))
|
||||
self.crossover_gene(randkey2, conns1, conns2, is_conn=True))
|
||||
|
||||
return new_nodes, new_conns
|
||||
return state.update(randkey=randkey), new_nodes, new_conns
|
||||
|
||||
def align_array(self, seq1, seq2, ar2, is_conn: bool):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user