use black format all files;
remove "return state" for functions which will be executed in vmap; recover randkey as args in mutation methods
This commit is contained in:
@@ -4,7 +4,6 @@ from .base import BaseCrossover
|
||||
|
||||
|
||||
class DefaultCrossover(BaseCrossover):
|
||||
|
||||
def __call__(self, state, genome, nodes1, conns1, nodes2, conns2):
|
||||
"""
|
||||
use genome1 and genome2 to generate a new genome
|
||||
@@ -19,15 +18,21 @@ class DefaultCrossover(BaseCrossover):
|
||||
|
||||
# For not homologous genes, use the value of nodes1(winner)
|
||||
# For homologous genes, use the crossover result between nodes1 and nodes2
|
||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1,
|
||||
self.crossover_gene(randkey1, nodes1, nodes2, is_conn=False))
|
||||
new_nodes = jnp.where(
|
||||
jnp.isnan(nodes1) | jnp.isnan(nodes2),
|
||||
nodes1,
|
||||
self.crossover_gene(randkey1, nodes1, nodes2, is_conn=False),
|
||||
)
|
||||
|
||||
# crossover connections
|
||||
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
|
||||
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
|
||||
|
||||
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1,
|
||||
self.crossover_gene(randkey2, conns1, conns2, is_conn=True))
|
||||
new_conns = jnp.where(
|
||||
jnp.isnan(conns1) | jnp.isnan(conns2),
|
||||
conns1,
|
||||
self.crossover_gene(randkey2, conns1, conns2, is_conn=True),
|
||||
)
|
||||
|
||||
return state.update(randkey=randkey), new_nodes, new_conns
|
||||
|
||||
@@ -53,7 +58,9 @@ class DefaultCrossover(BaseCrossover):
|
||||
idx = jnp.arange(0, len(seq1))
|
||||
idx_fixed = jnp.dot(mask, idx)
|
||||
|
||||
refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan)
|
||||
refactor_ar2 = jnp.where(
|
||||
intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan
|
||||
)
|
||||
|
||||
return refactor_ar2
|
||||
|
||||
@@ -61,10 +68,6 @@ class DefaultCrossover(BaseCrossover):
|
||||
r = jax.random.uniform(rand_key, shape=g1.shape)
|
||||
new_gene = jnp.where(r > 0.5, g1, g2)
|
||||
if is_conn: # fix enabled
|
||||
enabled = jnp.where(
|
||||
g1[:, 2] + g2[:, 2] > 0, # any of them is enabled
|
||||
1,
|
||||
0
|
||||
)
|
||||
enabled = jnp.where(g1[:, 2] + g2[:, 2] > 0, 1, 0) # any of them is enabled
|
||||
new_gene = new_gene.at[:, 2].set(enabled)
|
||||
return new_gene
|
||||
|
||||
Reference in New Issue
Block a user