203 lines
7.5 KiB
Python
203 lines
7.5 KiB
Python
import jax, jax.numpy as jnp
|
|
from . import BaseMutation
|
|
from utils import fetch_first, fetch_random, I_INT, unflatten_conns, check_cycles
|
|
|
|
|
|
class DefaultMutation(BaseMutation):
    """Default structural + value mutation operator.

    Array layout as used by this class:

    * ``nodes[:, 0]`` holds the node key; a NaN key marks an empty slot.
    * ``conns[:, 0]`` / ``conns[:, 1]`` hold the input / output node keys and
      ``conns[:, 2]`` the enabled flag; a NaN input key marks an empty slot.

    Each call may apply, in order, add-node, delete-node, add-connection and
    delete-connection mutations (each gated by its own probability), followed
    by a per-row value mutation of nodes and connections.
    """

    def __init__(
        self,
        conn_add: float = 0.4,
        conn_delete: float = 0,
        node_add: float = 0.2,
        node_delete: float = 0,
    ):
        """
        :param conn_add: probability of attempting to add a connection per call.
        :param conn_delete: probability of attempting to delete a connection per call.
        :param node_add: probability of attempting to add a node (by splitting
            an existing connection) per call.
        :param node_delete: probability of attempting to delete a node that is
            neither an input nor an output node per call.
        """
        self.conn_add = conn_add
        self.conn_delete = conn_delete
        self.node_add = node_add
        self.node_delete = node_delete

    def __call__(self, state, genome, nodes, conns, new_node_key):
        """Mutate one genome's structure, then its values.

        :param state: object exposing ``randkey`` and ``update(randkey=...)``.
        :param genome: genome definition providing the gene/array helpers.
        :param nodes: node array (see class docstring).
        :param conns: connection array (see class docstring).
        :param new_node_key: key given to a node created by the add-node mutation.
        :return: ``(new_state, nodes, conns)`` with the state's PRNG key advanced.
        """
        # ``jax.random.split`` defaults to ``num=2``; three keys are needed here,
        # otherwise the unpacking below raises a ValueError.
        k_structure, k_values, randkey = jax.random.split(state.randkey, num=3)

        nodes, conns = self.mutate_structure(k_structure, genome, nodes, conns, new_node_key)
        nodes, conns = self.mutate_values(k_values, genome, nodes, conns)

        return state.update(randkey=randkey), nodes, conns

    def mutate_structure(self, key, genome, nodes, conns, new_node_key):
        """Apply the four structural mutations, each gated by its probability.

        The mutations run in the order add-node, delete-node, add-connection,
        delete-connection; each one operates on the output of the previous one.

        :param key: JAX PRNG key.
        :return: ``(nodes, conns)`` after structural mutation.
        """

        def mutate_add_node(key_, nodes_, conns_):
            # Split a randomly chosen connection i -> o into i -> new -> o.
            i_key, o_key, idx = self.choice_connection_key(key_, conns_)

            def successful_add_node():
                # disable the connection being split
                new_conns = conns_.at[idx, 2].set(False)

                # add the new node with attributes from node_gene.new_custom_attrs()
                new_nodes = genome.add_node(nodes_, new_node_key, genome.node_gene.new_custom_attrs())

                # add the two replacement connections, both enabled
                new_conns = genome.add_conn(new_conns, i_key, new_node_key, True,
                                            genome.conn_gene.new_custom_attrs())
                new_conns = genome.add_conn(new_conns, new_node_key, o_key, True,
                                            genome.conn_gene.new_custom_attrs())

                return new_nodes, new_conns

            return jax.lax.cond(
                idx == I_INT,
                lambda: (nodes_, conns_),  # no connection to split: do nothing
                successful_add_node,
            )

        def mutate_delete_node(key_, nodes_, conns_):
            # Choose a random node that is neither an input nor an output node.
            node_key, idx = self.choice_node_key(
                key_, nodes_, genome.input_idx, genome.output_idx,
                allow_input_keys=False, allow_output_keys=False,
            )

            def successful_delete_node():
                new_nodes = genome.delete_node_by_pos(nodes_, idx)

                # blank out (set to NaN) every connection that starts or ends at the node
                touches_node = (conns_[:, 0] == node_key) | (conns_[:, 1] == node_key)
                new_conns = jnp.where(touches_node[:, None], jnp.nan, conns_)

                return new_nodes, new_conns

            return jax.lax.cond(
                idx == I_INT,
                lambda: (nodes_, conns_),  # no eligible node: do nothing
                successful_delete_node,
            )

        def mutate_add_conn(key_, nodes_, conns_):
            k1_, k2_ = jax.random.split(key_, num=2)

            # the source of the connection may be any node
            i_key, from_idx = self.choice_node_key(
                k1_, nodes_, genome.input_idx, genome.output_idx,
                allow_input_keys=True, allow_output_keys=True,
            )

            # the target of the connection may be any node except an input node
            o_key, to_idx = self.choice_node_key(
                k2_, nodes_, genome.input_idx, genome.output_idx,
                allow_input_keys=False, allow_output_keys=True,
            )

            conn_pos = fetch_first((conns_[:, 0] == i_key) & (conns_[:, 1] == o_key))
            is_already_exist = conn_pos != I_INT

            # If either endpoint could not be chosen its key is NaN; adding a
            # connection with a NaN key would write a bogus row, so skip it.
            no_endpoint = (from_idx == I_INT) | (to_idx == I_INT)

            def nothing():
                return nodes_, conns_

            def successful():
                return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs())

            def already_exist():
                # re-enable the existing connection instead of adding a duplicate
                return nodes_, conns_.at[conn_pos, 2].set(True)

            if genome.network_type == 'feedforward':
                u_cons = unflatten_conns(nodes_, conns_)
                cons_exist = ~jnp.isnan(u_cons[0, :, :])
                is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx)

                def add_new():
                    # skip the new connection when check_cycles flags it
                    return jax.lax.cond(is_cycle, nothing, successful)

            elif genome.network_type == 'recurrent':
                add_new = successful

            else:
                raise ValueError(f"Invalid network type: {genome.network_type}")

            def add_or_enable():
                return jax.lax.cond(is_already_exist, already_exist, add_new)

            return jax.lax.cond(no_endpoint, nothing, add_or_enable)

        def mutate_delete_conn(key_, nodes_, conns_):
            # randomly choose a connection to delete
            _, _, idx = self.choice_connection_key(key_, conns_)

            def successfully_delete_connection():
                return nodes_, genome.delete_conn_by_pos(conns_, idx)

            return jax.lax.cond(
                idx == I_INT,
                lambda: (nodes_, conns_),  # no connection: do nothing
                successfully_delete_connection,
            )

        # One key for the four gate draws and a distinct key for every mutation.
        # Reusing the gate key inside a mutation would correlate the decision
        # to mutate with the mutation's own random choices.
        k_gate, k_add_node, k_del_node, k_add_conn, k_del_conn = jax.random.split(key, num=5)
        r1, r2, r3, r4 = jax.random.uniform(k_gate, shape=(4,))

        def no(key_, nodes_, conns_):
            return nodes_, conns_

        nodes, conns = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k_add_node, nodes, conns)
        nodes, conns = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k_del_node, nodes, conns)
        nodes, conns = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k_add_conn, nodes, conns)
        nodes, conns = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k_del_conn, nodes, conns)

        return nodes, conns

    def mutate_values(self, key, genome, nodes, conns):
        """Mutate every node/connection row with the genes' ``mutate`` functions.

        Each row gets its own PRNG key. Entries that are NaN in the input are
        kept NaN in the output, so empty slots stay empty.

        :param key: JAX PRNG key.
        :return: ``(new_nodes, new_conns)``.
        """
        k1, k2 = jax.random.split(key, num=2)
        nodes_keys = jax.random.split(k1, num=nodes.shape[0])
        conns_keys = jax.random.split(k2, num=conns.shape[0])

        new_nodes = jax.vmap(genome.node_gene.mutate)(nodes_keys, nodes)
        new_conns = jax.vmap(genome.conn_gene.mutate)(conns_keys, conns)

        # NaN entries (empty slots) are not changed
        new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
        new_conns = jnp.where(jnp.isnan(conns), jnp.nan, new_conns)

        return new_nodes, new_conns

    def choice_node_key(self, key, nodes, input_idx, output_idx,
                        allow_input_keys: bool = False, allow_output_keys: bool = False):
        """Randomly choose one non-empty node.

        Input nodes are excluded unless ``allow_input_keys`` is true, and output
        nodes are excluded unless ``allow_output_keys`` is true.

        :param key: JAX PRNG key.
        :param nodes: node array; column 0 holds the node key (NaN for empty slots).
        :param input_idx: keys of the input nodes.
        :param output_idx: keys of the output nodes.
        :param allow_input_keys: whether input nodes may be chosen.
        :param allow_output_keys: whether output nodes may be chosen.
        :return: ``(node_key, idx)``: the chosen node's key and its row position.
            The key is NaN when ``idx == I_INT`` (presumably what
            ``fetch_random`` returns when no node is eligible — confirm).
        """
        node_keys = nodes[:, 0]
        mask = ~jnp.isnan(node_keys)

        if not allow_input_keys:
            mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_idx))

        if not allow_output_keys:
            mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_idx))

        idx = fetch_random(key, mask)
        node_key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
        return node_key, idx

    def choice_connection_key(self, key, conns):
        """Randomly choose one non-empty connection.

        :param key: JAX PRNG key.
        :param conns: connection array; columns 0/1 hold input/output node keys.
        :return: ``(i_key, o_key, idx)``; both keys are NaN when ``idx == I_INT``
            (presumably: no non-empty connection exists — confirm with ``fetch_random``).
        """
        idx = fetch_random(key, ~jnp.isnan(conns[:, 0]))
        i_key = jnp.where(idx != I_INT, conns[idx, 0], jnp.nan)
        o_key = jnp.where(idx != I_INT, conns[idx, 1], jnp.nan)

        return i_key, o_key, idx