This commit is contained in:
wls2002
2023-05-13 20:58:03 +08:00
parent 90a9cc322d
commit 72c9d4167a
10 changed files with 372 additions and 529 deletions

View File

@@ -11,7 +11,8 @@ from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_
from .graph import check_cycles
@partial(jit, static_argnames=('single_structure_mutate',))
# TODO: Temporally delete single_structural_mutation, for i need to run it as soon as possible.
@jit
def mutate(rand_key: Array,
nodes: Array,
connections: Array,
@@ -44,7 +45,7 @@ def mutate(rand_key: Array,
delete_node_rate: float = 0.2,
add_connection_rate: float = 0.4,
delete_connection_rate: float = 0.4,
single_structure_mutate: bool = True):
):
"""
:param output_idx:
:param input_idx:
@@ -78,65 +79,26 @@ def mutate(rand_key: Array,
:param delete_node_rate:
:param add_connection_rate:
:param delete_connection_rate:
:param single_structure_mutate: a genome is structurally mutate at most once
:return:
"""
# mutate_structure
def nothing(rk, n, c):
return n, c
def m_add_node(rk, n, c):
return mutate_add_node(rk, n, c, new_node_key, bias_mean, response_mean, act_default, agg_default)
def m_delete_node(rk, n, c):
return mutate_delete_node(rk, n, c, input_idx, output_idx)
def m_add_connection(rk, n, c):
return mutate_add_connection(rk, n, c, input_idx, output_idx)
def m_delete_connection(rk, n, c):
return mutate_delete_connection(rk, n, c)
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
mutate_structure_li = [nothing, m_add_node, m_delete_node, m_add_connection, m_delete_connection]
# mutate add node
aux_nodes, aux_connections = m_add_node(r1, nodes, connections)
nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes)
connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections)
if single_structure_mutate:
r1, r2, rand_key = jax.random.split(rand_key, 3)
d = jnp.maximum(1, add_node_rate + delete_node_rate + add_connection_rate + delete_connection_rate)
# shorten variable names for beauty
anr, dnr = add_node_rate / d, delete_node_rate / d
acr, dcr = add_connection_rate / d, delete_connection_rate / d
r = rand(r1)
branch = 0
branch = jnp.where(r <= anr, 1, branch)
branch = jnp.where((anr < r) & (r <= anr + dnr), 2, branch)
branch = jnp.where((anr + dnr < r) & (r <= anr + dnr + acr), 3, branch)
branch = jnp.where((anr + dnr + acr) < r & r <= (anr + dnr + acr + dcr), 4, branch)
nodes, connections = jax.lax.switch(branch, mutate_structure_li, (r2, nodes, connections))
else:
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
# mutate add node
aux_nodes, aux_connections = m_add_node(r1, nodes, connections)
nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes)
connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections)
# mutate delete node
aux_nodes, aux_connections = m_delete_node(r2, nodes, connections)
nodes = jnp.where(rand(r2) < delete_node_rate, aux_nodes, nodes)
connections = jnp.where(rand(r2) < delete_node_rate, aux_connections, connections)
# mutate add connection
aux_nodes, aux_connections = m_add_connection(r3, nodes, connections)
nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)
# mutate delete connection
aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections)
nodes = jnp.where(rand(r4) < delete_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r4) < delete_connection_rate, aux_connections, connections)
# mutate add connection
aux_nodes, aux_connections = m_add_connection(r3, nodes, connections)
nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)
nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength,
bias_mutate_rate, bias_replace_rate, response_mean, response_std,
@@ -379,9 +341,9 @@ def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array,
# randomly choose two nodes
k1, k2 = jax.random.split(rand_key, num=2)
i_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys,
allow_input_keys=True, allow_output_keys=True)
allow_input_keys=True, allow_output_keys=True)
o_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys,
allow_input_keys=False, allow_output_keys=True)
allow_input_keys=False, allow_output_keys=True)
con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))