complete fully stateful!
use black to format all files!
This commit is contained in:
@@ -5,5 +5,5 @@ class BaseCrossover:
|
||||
def setup(self, state=State()):
|
||||
return state
|
||||
|
||||
def __call__(self, state, genome, nodes1, nodes2, conns1, conns2):
|
||||
def __call__(self, state, randkey, genome, nodes1, nodes2, conns1, conns2):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -4,12 +4,12 @@ from .base import BaseCrossover
|
||||
|
||||
|
||||
class DefaultCrossover(BaseCrossover):
|
||||
def __call__(self, state, genome, nodes1, conns1, nodes2, conns2):
|
||||
def __call__(self, state, randkey, genome, nodes1, conns1, nodes2, conns2):
|
||||
"""
|
||||
use genome1 and genome2 to generate a new genome
|
||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||
"""
|
||||
randkey1, randkey2, randkey = jax.random.split(state.randkey, 3)
|
||||
randkey1, randkey2 = jax.random.split(randkey, 2)
|
||||
|
||||
# crossover nodes
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
@@ -34,11 +34,12 @@ class DefaultCrossover(BaseCrossover):
|
||||
self.crossover_gene(randkey2, conns1, conns2, is_conn=True),
|
||||
)
|
||||
|
||||
return state.update(randkey=randkey), new_nodes, new_conns
|
||||
return new_nodes, new_conns
|
||||
|
||||
def align_array(self, seq1, seq2, ar2, is_conn: bool):
|
||||
"""
|
||||
After I review this code, I found that it is the most difficult part of the code. Please never change it!
|
||||
After I review this code, I found that it is the most difficult part of the code.
|
||||
Please consider carefully before change it!
|
||||
make ar2 align with ar1.
|
||||
:param seq1:
|
||||
:param seq2:
|
||||
@@ -64,8 +65,8 @@ class DefaultCrossover(BaseCrossover):
|
||||
|
||||
return refactor_ar2
|
||||
|
||||
def crossover_gene(self, rand_key, g1, g2, is_conn):
|
||||
r = jax.random.uniform(rand_key, shape=g1.shape)
|
||||
def crossover_gene(self, randkey, g1, g2, is_conn):
|
||||
r = jax.random.uniform(randkey, shape=g1.shape)
|
||||
new_gene = jnp.where(r > 0.5, g1, g2)
|
||||
if is_conn: # fix enabled
|
||||
enabled = jnp.where(g1[:, 2] + g2[:, 2] > 0, 1, 0) # any of them is enabled
|
||||
|
||||
Reference in New Issue
Block a user