use black format all files;

remove "return state" for functions which will be executed in vmap;
recover randkey as args in mutation methods
This commit is contained in:
wls2002
2024-05-26 15:46:04 +08:00
parent 79d53ea7af
commit cf69b916af
38 changed files with 932 additions and 582 deletions

View File

@@ -2,7 +2,6 @@ from utils import State
class BaseCrossover:
def setup(self, state=State()):
return state

View File

@@ -4,7 +4,6 @@ from .base import BaseCrossover
class DefaultCrossover(BaseCrossover):
def __call__(self, state, genome, nodes1, conns1, nodes2, conns2):
"""
use genome1 and genome2 to generate a new genome
@@ -19,15 +18,21 @@ class DefaultCrossover(BaseCrossover):
# For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1,
self.crossover_gene(randkey1, nodes1, nodes2, is_conn=False))
new_nodes = jnp.where(
jnp.isnan(nodes1) | jnp.isnan(nodes2),
nodes1,
self.crossover_gene(randkey1, nodes1, nodes2, is_conn=False),
)
# crossover connections
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1,
self.crossover_gene(randkey2, conns1, conns2, is_conn=True))
new_conns = jnp.where(
jnp.isnan(conns1) | jnp.isnan(conns2),
conns1,
self.crossover_gene(randkey2, conns1, conns2, is_conn=True),
)
return state.update(randkey=randkey), new_nodes, new_conns
@@ -53,7 +58,9 @@ class DefaultCrossover(BaseCrossover):
idx = jnp.arange(0, len(seq1))
idx_fixed = jnp.dot(mask, idx)
refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan)
refactor_ar2 = jnp.where(
intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan
)
return refactor_ar2
@@ -61,10 +68,6 @@ class DefaultCrossover(BaseCrossover):
r = jax.random.uniform(rand_key, shape=g1.shape)
new_gene = jnp.where(r > 0.5, g1, g2)
if is_conn: # fix enabled
enabled = jnp.where(
g1[:, 2] + g2[:, 2] > 0, # any of them is enabled
1,
0
)
enabled = jnp.where(g1[:, 2] + g2[:, 2] > 0, 1, 0) # any of them is enabled
new_gene = new_gene.at[:, 2].set(enabled)
return new_gene

View File

@@ -1,2 +1,2 @@
from .base import BaseMutation
from .default import DefaultMutation
from .default import DefaultMutation

View File

@@ -2,7 +2,6 @@ from utils import State
class BaseMutation:
def setup(self, state=State()):
return state

View File

@@ -1,16 +1,15 @@
import jax, jax.numpy as jnp
from . import BaseMutation
from utils import fetch_first, fetch_random, I_INT, unflatten_conns, check_cycles
from utils import fetch_first, fetch_random, I_INF, unflatten_conns, check_cycles
class DefaultMutation(BaseMutation):
def __init__(
self,
conn_add: float = 0.4,
conn_delete: float = 0,
node_add: float = 0.2,
node_delete: float = 0,
self,
conn_add: float = 0.4,
conn_delete: float = 0,
node_add: float = 0.2,
node_delete: float = 0,
):
self.conn_add = conn_add
self.conn_delete = conn_delete
@@ -34,25 +33,45 @@ class DefaultMutation(BaseMutation):
new_conns = conns_.at[idx, 2].set(False)
# add a new node
new_nodes = genome.add_node(nodes_, new_node_key, genome.node_gene.new_custom_attrs())
new_nodes = genome.add_node(
nodes_, new_node_key, genome.node_gene.new_custom_attrs()
)
# add two new connections
new_conns = genome.add_conn(new_conns, i_key, new_node_key, True, genome.conn_gene.new_custom_attrs())
new_conns = genome.add_conn(new_conns, new_node_key, o_key, True, genome.conn_gene.new_custom_attrs())
new_conns = genome.add_conn(
new_conns,
i_key,
new_node_key,
True,
genome.conn_gene.new_custom_attrs(),
)
new_conns = genome.add_conn(
new_conns,
new_node_key,
o_key,
True,
genome.conn_gene.new_custom_attrs(),
)
return new_nodes, new_conns
return jax.lax.cond(
idx == I_INT,
idx == I_INF,
lambda: (nodes_, conns_), # do nothing
successful_add_node
successful_add_node,
)
def mutate_delete_node(key_, nodes_, conns_):
# randomly choose a node
key, idx = self.choice_node_key(key_, nodes_, genome.input_idx, genome.output_idx,
allow_input_keys=False, allow_output_keys=False)
key, idx = self.choice_node_key(
key_,
nodes_,
genome.input_idx,
genome.output_idx,
allow_input_keys=False,
allow_output_keys=False,
)
def successful_delete_node():
# delete the node
@@ -62,15 +81,15 @@ class DefaultMutation(BaseMutation):
new_conns = jnp.where(
((conns_[:, 0] == key) | (conns_[:, 1] == key))[:, None],
jnp.nan,
conns_
conns_,
)
return new_nodes, new_conns
return jax.lax.cond(
idx == I_INT,
idx == I_INF,
lambda: (nodes_, conns_), # do nothing
successful_delete_node
successful_delete_node,
)
def mutate_add_conn(key_, nodes_, conns_):
@@ -78,26 +97,40 @@ class DefaultMutation(BaseMutation):
k1_, k2_ = jax.random.split(key_, num=2)
# input node of the connection can be any node
i_key, from_idx = self.choice_node_key(k1_, nodes_, genome.input_idx, genome.output_idx,
allow_input_keys=True, allow_output_keys=True)
i_key, from_idx = self.choice_node_key(
k1_,
nodes_,
genome.input_idx,
genome.output_idx,
allow_input_keys=True,
allow_output_keys=True,
)
# output node of the connection can be any node except input node
o_key, to_idx = self.choice_node_key(k2_, nodes_, genome.input_idx, genome.output_idx,
allow_input_keys=False, allow_output_keys=True)
o_key, to_idx = self.choice_node_key(
k2_,
nodes_,
genome.input_idx,
genome.output_idx,
allow_input_keys=False,
allow_output_keys=True,
)
conn_pos = fetch_first((conns_[:, 0] == i_key) & (conns_[:, 1] == o_key))
is_already_exist = conn_pos != I_INT
is_already_exist = conn_pos != I_INF
def nothing():
return nodes_, conns_
def successful():
return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs())
return nodes_, genome.add_conn(
conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs()
)
def already_exist():
return nodes_, conns_.at[conn_pos, 2].set(True)
if genome.network_type == 'feedforward':
if genome.network_type == "feedforward":
u_cons = unflatten_conns(nodes_, conns_)
cons_exist = ~jnp.isnan(u_cons[0, :, :])
is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx)
@@ -105,20 +138,11 @@ class DefaultMutation(BaseMutation):
return jax.lax.cond(
is_already_exist,
already_exist,
lambda:
jax.lax.cond(
is_cycle,
nothing,
successful
)
lambda: jax.lax.cond(is_cycle, nothing, successful),
)
elif genome.network_type == 'recurrent':
return jax.lax.cond(
is_already_exist,
already_exist,
successful
)
elif genome.network_type == "recurrent":
return jax.lax.cond(is_already_exist, already_exist, successful)
else:
raise ValueError(f"Invalid network type: {genome.network_type}")
@@ -131,9 +155,9 @@ class DefaultMutation(BaseMutation):
return nodes_, genome.delete_conn_by_pos(conns_, idx)
return jax.lax.cond(
idx == I_INT,
idx == I_INF,
lambda: (nodes_, conns_), # nothing
successfully_delete_connection
successfully_delete_connection,
)
k1, k2, k3, k4 = jax.random.split(key, num=4)
@@ -142,10 +166,18 @@ class DefaultMutation(BaseMutation):
def no(key_, nodes_, conns_):
return nodes_, conns_
nodes, conns = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns)
nodes, conns = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns)
nodes, conns = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns)
nodes, conns = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns)
nodes, conns = jax.lax.cond(
r1 < self.node_add, mutate_add_node, no, k1, nodes, conns
)
nodes, conns = jax.lax.cond(
r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns
)
nodes, conns = jax.lax.cond(
r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns
)
nodes, conns = jax.lax.cond(
r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns
)
return nodes, conns
@@ -163,8 +195,15 @@ class DefaultMutation(BaseMutation):
return new_nodes, new_conns
def choice_node_key(self, key, nodes, input_idx, output_idx,
allow_input_keys: bool = False, allow_output_keys: bool = False):
def choice_node_key(
self,
key,
nodes,
input_idx,
output_idx,
allow_input_keys: bool = False,
allow_output_keys: bool = False,
):
"""
Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node.
:param key:
@@ -186,7 +225,7 @@ class DefaultMutation(BaseMutation):
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_idx))
idx = fetch_random(key, mask)
key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
key = jnp.where(idx != I_INF, nodes[idx, 0], jnp.nan)
return key, idx
def choice_connection_key(self, key, conns):
@@ -196,7 +235,7 @@ class DefaultMutation(BaseMutation):
"""
idx = fetch_random(key, ~jnp.isnan(conns[:, 0]))
i_key = jnp.where(idx != I_INT, conns[idx, 0], jnp.nan)
o_key = jnp.where(idx != I_INT, conns[idx, 1], jnp.nan)
i_key = jnp.where(idx != I_INF, conns[idx, 0], jnp.nan)
o_key = jnp.where(idx != I_INF, conns[idx, 1], jnp.nan)
return i_key, o_key, idx