FAST!
This commit is contained in:
@@ -11,7 +11,8 @@ from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_
|
||||
from .graph import check_cycles
|
||||
|
||||
|
||||
@partial(jit, static_argnames=('single_structure_mutate',))
|
||||
# TODO: Temporally delete single_structural_mutation, for i need to run it as soon as possible.
|
||||
@jit
|
||||
def mutate(rand_key: Array,
|
||||
nodes: Array,
|
||||
connections: Array,
|
||||
@@ -44,7 +45,7 @@ def mutate(rand_key: Array,
|
||||
delete_node_rate: float = 0.2,
|
||||
add_connection_rate: float = 0.4,
|
||||
delete_connection_rate: float = 0.4,
|
||||
single_structure_mutate: bool = True):
|
||||
):
|
||||
"""
|
||||
:param output_idx:
|
||||
:param input_idx:
|
||||
@@ -78,65 +79,26 @@ def mutate(rand_key: Array,
|
||||
:param delete_node_rate:
|
||||
:param add_connection_rate:
|
||||
:param delete_connection_rate:
|
||||
:param single_structure_mutate: a genome is structurally mutate at most once
|
||||
:return:
|
||||
"""
|
||||
|
||||
# mutate_structure
|
||||
def nothing(rk, n, c):
|
||||
return n, c
|
||||
|
||||
def m_add_node(rk, n, c):
|
||||
return mutate_add_node(rk, n, c, new_node_key, bias_mean, response_mean, act_default, agg_default)
|
||||
|
||||
def m_delete_node(rk, n, c):
|
||||
return mutate_delete_node(rk, n, c, input_idx, output_idx)
|
||||
|
||||
def m_add_connection(rk, n, c):
|
||||
return mutate_add_connection(rk, n, c, input_idx, output_idx)
|
||||
|
||||
def m_delete_connection(rk, n, c):
|
||||
return mutate_delete_connection(rk, n, c)
|
||||
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
|
||||
|
||||
mutate_structure_li = [nothing, m_add_node, m_delete_node, m_add_connection, m_delete_connection]
|
||||
# mutate add node
|
||||
aux_nodes, aux_connections = m_add_node(r1, nodes, connections)
|
||||
nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes)
|
||||
connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections)
|
||||
|
||||
if single_structure_mutate:
|
||||
r1, r2, rand_key = jax.random.split(rand_key, 3)
|
||||
d = jnp.maximum(1, add_node_rate + delete_node_rate + add_connection_rate + delete_connection_rate)
|
||||
|
||||
# shorten variable names for beauty
|
||||
anr, dnr = add_node_rate / d, delete_node_rate / d
|
||||
acr, dcr = add_connection_rate / d, delete_connection_rate / d
|
||||
|
||||
r = rand(r1)
|
||||
branch = 0
|
||||
branch = jnp.where(r <= anr, 1, branch)
|
||||
branch = jnp.where((anr < r) & (r <= anr + dnr), 2, branch)
|
||||
branch = jnp.where((anr + dnr < r) & (r <= anr + dnr + acr), 3, branch)
|
||||
branch = jnp.where((anr + dnr + acr) < r & r <= (anr + dnr + acr + dcr), 4, branch)
|
||||
nodes, connections = jax.lax.switch(branch, mutate_structure_li, (r2, nodes, connections))
|
||||
else:
|
||||
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
|
||||
|
||||
# mutate add node
|
||||
aux_nodes, aux_connections = m_add_node(r1, nodes, connections)
|
||||
nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes)
|
||||
connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections)
|
||||
|
||||
# mutate delete node
|
||||
aux_nodes, aux_connections = m_delete_node(r2, nodes, connections)
|
||||
nodes = jnp.where(rand(r2) < delete_node_rate, aux_nodes, nodes)
|
||||
connections = jnp.where(rand(r2) < delete_node_rate, aux_connections, connections)
|
||||
|
||||
# mutate add connection
|
||||
aux_nodes, aux_connections = m_add_connection(r3, nodes, connections)
|
||||
nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
|
||||
connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)
|
||||
|
||||
# mutate delete connection
|
||||
aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections)
|
||||
nodes = jnp.where(rand(r4) < delete_connection_rate, aux_nodes, nodes)
|
||||
connections = jnp.where(rand(r4) < delete_connection_rate, aux_connections, connections)
|
||||
# mutate add connection
|
||||
aux_nodes, aux_connections = m_add_connection(r3, nodes, connections)
|
||||
nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
|
||||
connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)
|
||||
|
||||
nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength,
|
||||
bias_mutate_rate, bias_replace_rate, response_mean, response_std,
|
||||
@@ -379,9 +341,9 @@ def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array,
|
||||
# randomly choose two nodes
|
||||
k1, k2 = jax.random.split(rand_key, num=2)
|
||||
i_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys,
|
||||
allow_input_keys=True, allow_output_keys=True)
|
||||
allow_input_keys=True, allow_output_keys=True)
|
||||
o_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys,
|
||||
allow_input_keys=False, allow_output_keys=True)
|
||||
allow_input_keys=False, allow_output_keys=True)
|
||||
|
||||
con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user