use black format all files;

remove "return state" for functions which will be executed in vmap;
recover randkey as args in mutation methods
This commit is contained in:
wls2002
2024-05-26 15:46:04 +08:00
parent 79d53ea7af
commit cf69b916af
38 changed files with 932 additions and 582 deletions

View File

@@ -4,7 +4,6 @@ from .base import BaseCrossover
class DefaultCrossover(BaseCrossover):
def __call__(self, state, genome, nodes1, conns1, nodes2, conns2):
"""
use genome1 and genome2 to generate a new genome
@@ -19,15 +18,21 @@ class DefaultCrossover(BaseCrossover):
# For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1,
self.crossover_gene(randkey1, nodes1, nodes2, is_conn=False))
new_nodes = jnp.where(
jnp.isnan(nodes1) | jnp.isnan(nodes2),
nodes1,
self.crossover_gene(randkey1, nodes1, nodes2, is_conn=False),
)
# crossover connections
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1,
self.crossover_gene(randkey2, conns1, conns2, is_conn=True))
new_conns = jnp.where(
jnp.isnan(conns1) | jnp.isnan(conns2),
conns1,
self.crossover_gene(randkey2, conns1, conns2, is_conn=True),
)
return state.update(randkey=randkey), new_nodes, new_conns
@@ -53,7 +58,9 @@ class DefaultCrossover(BaseCrossover):
idx = jnp.arange(0, len(seq1))
idx_fixed = jnp.dot(mask, idx)
refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan)
refactor_ar2 = jnp.where(
intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan
)
return refactor_ar2
@@ -61,10 +68,6 @@ class DefaultCrossover(BaseCrossover):
r = jax.random.uniform(rand_key, shape=g1.shape)
new_gene = jnp.where(r > 0.5, g1, g2)
if is_conn: # fix enabled
enabled = jnp.where(
g1[:, 2] + g2[:, 2] > 0, # any of them is enabled
1,
0
)
enabled = jnp.where(g1[:, 2] + g2[:, 2] > 0, 1, 0) # any of them is enabled
new_gene = new_gene.at[:, 2].set(enabled)
return new_gene