use black format all files;

remove "return state" for functions which will be executed in vmap;
recover randkey as args in mutation methods
This commit is contained in:
wls2002
2024-05-26 15:46:04 +08:00
parent 79d53ea7af
commit cf69b916af
38 changed files with 932 additions and 582 deletions

View File

@@ -1,16 +1,15 @@
import jax, jax.numpy as jnp
from . import BaseMutation
from utils import fetch_first, fetch_random, I_INT, unflatten_conns, check_cycles
from utils import fetch_first, fetch_random, I_INF, unflatten_conns, check_cycles
class DefaultMutation(BaseMutation):
def __init__(
self,
conn_add: float = 0.4,
conn_delete: float = 0,
node_add: float = 0.2,
node_delete: float = 0,
self,
conn_add: float = 0.4,
conn_delete: float = 0,
node_add: float = 0.2,
node_delete: float = 0,
):
self.conn_add = conn_add
self.conn_delete = conn_delete
@@ -34,25 +33,45 @@ class DefaultMutation(BaseMutation):
new_conns = conns_.at[idx, 2].set(False)
# add a new node
new_nodes = genome.add_node(nodes_, new_node_key, genome.node_gene.new_custom_attrs())
new_nodes = genome.add_node(
nodes_, new_node_key, genome.node_gene.new_custom_attrs()
)
# add two new connections
new_conns = genome.add_conn(new_conns, i_key, new_node_key, True, genome.conn_gene.new_custom_attrs())
new_conns = genome.add_conn(new_conns, new_node_key, o_key, True, genome.conn_gene.new_custom_attrs())
new_conns = genome.add_conn(
new_conns,
i_key,
new_node_key,
True,
genome.conn_gene.new_custom_attrs(),
)
new_conns = genome.add_conn(
new_conns,
new_node_key,
o_key,
True,
genome.conn_gene.new_custom_attrs(),
)
return new_nodes, new_conns
return jax.lax.cond(
idx == I_INT,
idx == I_INF,
lambda: (nodes_, conns_), # do nothing
successful_add_node
successful_add_node,
)
def mutate_delete_node(key_, nodes_, conns_):
# randomly choose a node
key, idx = self.choice_node_key(key_, nodes_, genome.input_idx, genome.output_idx,
allow_input_keys=False, allow_output_keys=False)
key, idx = self.choice_node_key(
key_,
nodes_,
genome.input_idx,
genome.output_idx,
allow_input_keys=False,
allow_output_keys=False,
)
def successful_delete_node():
# delete the node
@@ -62,15 +81,15 @@ class DefaultMutation(BaseMutation):
new_conns = jnp.where(
((conns_[:, 0] == key) | (conns_[:, 1] == key))[:, None],
jnp.nan,
conns_
conns_,
)
return new_nodes, new_conns
return jax.lax.cond(
idx == I_INT,
idx == I_INF,
lambda: (nodes_, conns_), # do nothing
successful_delete_node
successful_delete_node,
)
def mutate_add_conn(key_, nodes_, conns_):
@@ -78,26 +97,40 @@ class DefaultMutation(BaseMutation):
k1_, k2_ = jax.random.split(key_, num=2)
# input node of the connection can be any node
i_key, from_idx = self.choice_node_key(k1_, nodes_, genome.input_idx, genome.output_idx,
allow_input_keys=True, allow_output_keys=True)
i_key, from_idx = self.choice_node_key(
k1_,
nodes_,
genome.input_idx,
genome.output_idx,
allow_input_keys=True,
allow_output_keys=True,
)
# output node of the connection can be any node except input node
o_key, to_idx = self.choice_node_key(k2_, nodes_, genome.input_idx, genome.output_idx,
allow_input_keys=False, allow_output_keys=True)
o_key, to_idx = self.choice_node_key(
k2_,
nodes_,
genome.input_idx,
genome.output_idx,
allow_input_keys=False,
allow_output_keys=True,
)
conn_pos = fetch_first((conns_[:, 0] == i_key) & (conns_[:, 1] == o_key))
is_already_exist = conn_pos != I_INT
is_already_exist = conn_pos != I_INF
def nothing():
return nodes_, conns_
def successful():
return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs())
return nodes_, genome.add_conn(
conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs()
)
def already_exist():
return nodes_, conns_.at[conn_pos, 2].set(True)
if genome.network_type == 'feedforward':
if genome.network_type == "feedforward":
u_cons = unflatten_conns(nodes_, conns_)
cons_exist = ~jnp.isnan(u_cons[0, :, :])
is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx)
@@ -105,20 +138,11 @@ class DefaultMutation(BaseMutation):
return jax.lax.cond(
is_already_exist,
already_exist,
lambda:
jax.lax.cond(
is_cycle,
nothing,
successful
)
lambda: jax.lax.cond(is_cycle, nothing, successful),
)
elif genome.network_type == 'recurrent':
return jax.lax.cond(
is_already_exist,
already_exist,
successful
)
elif genome.network_type == "recurrent":
return jax.lax.cond(is_already_exist, already_exist, successful)
else:
raise ValueError(f"Invalid network type: {genome.network_type}")
@@ -131,9 +155,9 @@ class DefaultMutation(BaseMutation):
return nodes_, genome.delete_conn_by_pos(conns_, idx)
return jax.lax.cond(
idx == I_INT,
idx == I_INF,
lambda: (nodes_, conns_), # nothing
successfully_delete_connection
successfully_delete_connection,
)
k1, k2, k3, k4 = jax.random.split(key, num=4)
@@ -142,10 +166,18 @@ class DefaultMutation(BaseMutation):
def no(key_, nodes_, conns_):
return nodes_, conns_
nodes, conns = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns)
nodes, conns = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns)
nodes, conns = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns)
nodes, conns = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns)
nodes, conns = jax.lax.cond(
r1 < self.node_add, mutate_add_node, no, k1, nodes, conns
)
nodes, conns = jax.lax.cond(
r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns
)
nodes, conns = jax.lax.cond(
r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns
)
nodes, conns = jax.lax.cond(
r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns
)
return nodes, conns
@@ -163,8 +195,15 @@ class DefaultMutation(BaseMutation):
return new_nodes, new_conns
def choice_node_key(self, key, nodes, input_idx, output_idx,
allow_input_keys: bool = False, allow_output_keys: bool = False):
def choice_node_key(
self,
key,
nodes,
input_idx,
output_idx,
allow_input_keys: bool = False,
allow_output_keys: bool = False,
):
"""
Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node.
:param key:
@@ -186,7 +225,7 @@ class DefaultMutation(BaseMutation):
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_idx))
idx = fetch_random(key, mask)
key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
key = jnp.where(idx != I_INF, nodes[idx, 0], jnp.nan)
return key, idx
def choice_connection_key(self, key, conns):
@@ -196,7 +235,7 @@ class DefaultMutation(BaseMutation):
"""
idx = fetch_random(key, ~jnp.isnan(conns[:, 0]))
i_key = jnp.where(idx != I_INT, conns[idx, 0], jnp.nan)
o_key = jnp.where(idx != I_INT, conns[idx, 1], jnp.nan)
i_key = jnp.where(idx != I_INF, conns[idx, 0], jnp.nan)
o_key = jnp.where(idx != I_INF, conns[idx, 1], jnp.nan)
return i_key, o_key, idx