use black format all files;
remove "return state" for functions which will be executed in vmap; recover randkey as args in mutation methods
This commit is contained in:
@@ -2,7 +2,6 @@ from utils import State
|
||||
|
||||
|
||||
class BaseCrossover:
|
||||
|
||||
def setup(self, state=State()):
|
||||
return state
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from .base import BaseCrossover
|
||||
|
||||
|
||||
class DefaultCrossover(BaseCrossover):
|
||||
|
||||
def __call__(self, state, genome, nodes1, conns1, nodes2, conns2):
|
||||
"""
|
||||
use genome1 and genome2 to generate a new genome
|
||||
@@ -19,15 +18,21 @@ class DefaultCrossover(BaseCrossover):
|
||||
|
||||
# For not homologous genes, use the value of nodes1(winner)
|
||||
# For homologous genes, use the crossover result between nodes1 and nodes2
|
||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1,
|
||||
self.crossover_gene(randkey1, nodes1, nodes2, is_conn=False))
|
||||
new_nodes = jnp.where(
|
||||
jnp.isnan(nodes1) | jnp.isnan(nodes2),
|
||||
nodes1,
|
||||
self.crossover_gene(randkey1, nodes1, nodes2, is_conn=False),
|
||||
)
|
||||
|
||||
# crossover connections
|
||||
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
|
||||
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
|
||||
|
||||
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1,
|
||||
self.crossover_gene(randkey2, conns1, conns2, is_conn=True))
|
||||
new_conns = jnp.where(
|
||||
jnp.isnan(conns1) | jnp.isnan(conns2),
|
||||
conns1,
|
||||
self.crossover_gene(randkey2, conns1, conns2, is_conn=True),
|
||||
)
|
||||
|
||||
return state.update(randkey=randkey), new_nodes, new_conns
|
||||
|
||||
@@ -53,7 +58,9 @@ class DefaultCrossover(BaseCrossover):
|
||||
idx = jnp.arange(0, len(seq1))
|
||||
idx_fixed = jnp.dot(mask, idx)
|
||||
|
||||
refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan)
|
||||
refactor_ar2 = jnp.where(
|
||||
intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan
|
||||
)
|
||||
|
||||
return refactor_ar2
|
||||
|
||||
@@ -61,10 +68,6 @@ class DefaultCrossover(BaseCrossover):
|
||||
r = jax.random.uniform(rand_key, shape=g1.shape)
|
||||
new_gene = jnp.where(r > 0.5, g1, g2)
|
||||
if is_conn: # fix enabled
|
||||
enabled = jnp.where(
|
||||
g1[:, 2] + g2[:, 2] > 0, # any of them is enabled
|
||||
1,
|
||||
0
|
||||
)
|
||||
enabled = jnp.where(g1[:, 2] + g2[:, 2] > 0, 1, 0) # any of them is enabled
|
||||
new_gene = new_gene.at[:, 2].set(enabled)
|
||||
return new_gene
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
from .base import BaseMutation
|
||||
from .default import DefaultMutation
|
||||
from .default import DefaultMutation
|
||||
|
||||
@@ -2,7 +2,6 @@ from utils import State
|
||||
|
||||
|
||||
class BaseMutation:
|
||||
|
||||
def setup(self, state=State()):
|
||||
return state
|
||||
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from . import BaseMutation
|
||||
from utils import fetch_first, fetch_random, I_INT, unflatten_conns, check_cycles
|
||||
from utils import fetch_first, fetch_random, I_INF, unflatten_conns, check_cycles
|
||||
|
||||
|
||||
class DefaultMutation(BaseMutation):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conn_add: float = 0.4,
|
||||
conn_delete: float = 0,
|
||||
node_add: float = 0.2,
|
||||
node_delete: float = 0,
|
||||
self,
|
||||
conn_add: float = 0.4,
|
||||
conn_delete: float = 0,
|
||||
node_add: float = 0.2,
|
||||
node_delete: float = 0,
|
||||
):
|
||||
self.conn_add = conn_add
|
||||
self.conn_delete = conn_delete
|
||||
@@ -34,25 +33,45 @@ class DefaultMutation(BaseMutation):
|
||||
new_conns = conns_.at[idx, 2].set(False)
|
||||
|
||||
# add a new node
|
||||
new_nodes = genome.add_node(nodes_, new_node_key, genome.node_gene.new_custom_attrs())
|
||||
new_nodes = genome.add_node(
|
||||
nodes_, new_node_key, genome.node_gene.new_custom_attrs()
|
||||
)
|
||||
|
||||
# add two new connections
|
||||
new_conns = genome.add_conn(new_conns, i_key, new_node_key, True, genome.conn_gene.new_custom_attrs())
|
||||
new_conns = genome.add_conn(new_conns, new_node_key, o_key, True, genome.conn_gene.new_custom_attrs())
|
||||
new_conns = genome.add_conn(
|
||||
new_conns,
|
||||
i_key,
|
||||
new_node_key,
|
||||
True,
|
||||
genome.conn_gene.new_custom_attrs(),
|
||||
)
|
||||
new_conns = genome.add_conn(
|
||||
new_conns,
|
||||
new_node_key,
|
||||
o_key,
|
||||
True,
|
||||
genome.conn_gene.new_custom_attrs(),
|
||||
)
|
||||
|
||||
return new_nodes, new_conns
|
||||
|
||||
return jax.lax.cond(
|
||||
idx == I_INT,
|
||||
idx == I_INF,
|
||||
lambda: (nodes_, conns_), # do nothing
|
||||
successful_add_node
|
||||
successful_add_node,
|
||||
)
|
||||
|
||||
def mutate_delete_node(key_, nodes_, conns_):
|
||||
|
||||
# randomly choose a node
|
||||
key, idx = self.choice_node_key(key_, nodes_, genome.input_idx, genome.output_idx,
|
||||
allow_input_keys=False, allow_output_keys=False)
|
||||
key, idx = self.choice_node_key(
|
||||
key_,
|
||||
nodes_,
|
||||
genome.input_idx,
|
||||
genome.output_idx,
|
||||
allow_input_keys=False,
|
||||
allow_output_keys=False,
|
||||
)
|
||||
|
||||
def successful_delete_node():
|
||||
# delete the node
|
||||
@@ -62,15 +81,15 @@ class DefaultMutation(BaseMutation):
|
||||
new_conns = jnp.where(
|
||||
((conns_[:, 0] == key) | (conns_[:, 1] == key))[:, None],
|
||||
jnp.nan,
|
||||
conns_
|
||||
conns_,
|
||||
)
|
||||
|
||||
return new_nodes, new_conns
|
||||
|
||||
return jax.lax.cond(
|
||||
idx == I_INT,
|
||||
idx == I_INF,
|
||||
lambda: (nodes_, conns_), # do nothing
|
||||
successful_delete_node
|
||||
successful_delete_node,
|
||||
)
|
||||
|
||||
def mutate_add_conn(key_, nodes_, conns_):
|
||||
@@ -78,26 +97,40 @@ class DefaultMutation(BaseMutation):
|
||||
k1_, k2_ = jax.random.split(key_, num=2)
|
||||
|
||||
# input node of the connection can be any node
|
||||
i_key, from_idx = self.choice_node_key(k1_, nodes_, genome.input_idx, genome.output_idx,
|
||||
allow_input_keys=True, allow_output_keys=True)
|
||||
i_key, from_idx = self.choice_node_key(
|
||||
k1_,
|
||||
nodes_,
|
||||
genome.input_idx,
|
||||
genome.output_idx,
|
||||
allow_input_keys=True,
|
||||
allow_output_keys=True,
|
||||
)
|
||||
|
||||
# output node of the connection can be any node except input node
|
||||
o_key, to_idx = self.choice_node_key(k2_, nodes_, genome.input_idx, genome.output_idx,
|
||||
allow_input_keys=False, allow_output_keys=True)
|
||||
o_key, to_idx = self.choice_node_key(
|
||||
k2_,
|
||||
nodes_,
|
||||
genome.input_idx,
|
||||
genome.output_idx,
|
||||
allow_input_keys=False,
|
||||
allow_output_keys=True,
|
||||
)
|
||||
|
||||
conn_pos = fetch_first((conns_[:, 0] == i_key) & (conns_[:, 1] == o_key))
|
||||
is_already_exist = conn_pos != I_INT
|
||||
is_already_exist = conn_pos != I_INF
|
||||
|
||||
def nothing():
|
||||
return nodes_, conns_
|
||||
|
||||
def successful():
|
||||
return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs())
|
||||
return nodes_, genome.add_conn(
|
||||
conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs()
|
||||
)
|
||||
|
||||
def already_exist():
|
||||
return nodes_, conns_.at[conn_pos, 2].set(True)
|
||||
|
||||
if genome.network_type == 'feedforward':
|
||||
if genome.network_type == "feedforward":
|
||||
u_cons = unflatten_conns(nodes_, conns_)
|
||||
cons_exist = ~jnp.isnan(u_cons[0, :, :])
|
||||
is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx)
|
||||
@@ -105,20 +138,11 @@ class DefaultMutation(BaseMutation):
|
||||
return jax.lax.cond(
|
||||
is_already_exist,
|
||||
already_exist,
|
||||
lambda:
|
||||
jax.lax.cond(
|
||||
is_cycle,
|
||||
nothing,
|
||||
successful
|
||||
)
|
||||
lambda: jax.lax.cond(is_cycle, nothing, successful),
|
||||
)
|
||||
|
||||
elif genome.network_type == 'recurrent':
|
||||
return jax.lax.cond(
|
||||
is_already_exist,
|
||||
already_exist,
|
||||
successful
|
||||
)
|
||||
elif genome.network_type == "recurrent":
|
||||
return jax.lax.cond(is_already_exist, already_exist, successful)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid network type: {genome.network_type}")
|
||||
@@ -131,9 +155,9 @@ class DefaultMutation(BaseMutation):
|
||||
return nodes_, genome.delete_conn_by_pos(conns_, idx)
|
||||
|
||||
return jax.lax.cond(
|
||||
idx == I_INT,
|
||||
idx == I_INF,
|
||||
lambda: (nodes_, conns_), # nothing
|
||||
successfully_delete_connection
|
||||
successfully_delete_connection,
|
||||
)
|
||||
|
||||
k1, k2, k3, k4 = jax.random.split(key, num=4)
|
||||
@@ -142,10 +166,18 @@ class DefaultMutation(BaseMutation):
|
||||
def no(key_, nodes_, conns_):
|
||||
return nodes_, conns_
|
||||
|
||||
nodes, conns = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns)
|
||||
nodes, conns = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns)
|
||||
nodes, conns = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns)
|
||||
nodes, conns = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns)
|
||||
nodes, conns = jax.lax.cond(
|
||||
r1 < self.node_add, mutate_add_node, no, k1, nodes, conns
|
||||
)
|
||||
nodes, conns = jax.lax.cond(
|
||||
r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns
|
||||
)
|
||||
nodes, conns = jax.lax.cond(
|
||||
r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns
|
||||
)
|
||||
nodes, conns = jax.lax.cond(
|
||||
r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns
|
||||
)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
@@ -163,8 +195,15 @@ class DefaultMutation(BaseMutation):
|
||||
|
||||
return new_nodes, new_conns
|
||||
|
||||
def choice_node_key(self, key, nodes, input_idx, output_idx,
|
||||
allow_input_keys: bool = False, allow_output_keys: bool = False):
|
||||
def choice_node_key(
|
||||
self,
|
||||
key,
|
||||
nodes,
|
||||
input_idx,
|
||||
output_idx,
|
||||
allow_input_keys: bool = False,
|
||||
allow_output_keys: bool = False,
|
||||
):
|
||||
"""
|
||||
Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node.
|
||||
:param key:
|
||||
@@ -186,7 +225,7 @@ class DefaultMutation(BaseMutation):
|
||||
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_idx))
|
||||
|
||||
idx = fetch_random(key, mask)
|
||||
key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
|
||||
key = jnp.where(idx != I_INF, nodes[idx, 0], jnp.nan)
|
||||
return key, idx
|
||||
|
||||
def choice_connection_key(self, key, conns):
|
||||
@@ -196,7 +235,7 @@ class DefaultMutation(BaseMutation):
|
||||
"""
|
||||
|
||||
idx = fetch_random(key, ~jnp.isnan(conns[:, 0]))
|
||||
i_key = jnp.where(idx != I_INT, conns[idx, 0], jnp.nan)
|
||||
o_key = jnp.where(idx != I_INT, conns[idx, 1], jnp.nan)
|
||||
i_key = jnp.where(idx != I_INF, conns[idx, 0], jnp.nan)
|
||||
o_key = jnp.where(idx != I_INF, conns[idx, 1], jnp.nan)
|
||||
|
||||
return i_key, o_key, idx
|
||||
|
||||
Reference in New Issue
Block a user