prepare for experiment
This commit is contained in:
@@ -133,5 +133,8 @@ act_name2key = {
|
||||
def act(idx, z):
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
# change idx from float to int
|
||||
return jax.lax.switch(idx, ACT_TOTAL_LIST, z)
|
||||
res = jax.lax.switch(idx, ACT_TOTAL_LIST, z)
|
||||
return jnp.where(jnp.isnan(res), jnp.nan, res)
|
||||
|
||||
# return jax.lax.switch(idx, ACT_TOTAL_LIST, z)
|
||||
|
||||
|
||||
@@ -88,6 +88,12 @@ def mutate(rand_key: Array,
|
||||
def m_add_connection(rk, n, c):
|
||||
return mutate_add_connection(rk, n, c, input_idx, output_idx)
|
||||
|
||||
def m_delete_node(rk, n, c):
|
||||
return mutate_delete_node(rk, n, c, input_idx, output_idx)
|
||||
|
||||
def m_delete_connection(rk, n, c):
|
||||
return mutate_delete_connection(rk, n, c)
|
||||
|
||||
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
|
||||
|
||||
# mutate add node
|
||||
@@ -100,6 +106,16 @@ def mutate(rand_key: Array,
|
||||
nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
|
||||
connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)
|
||||
|
||||
# mutate delete node
|
||||
aux_nodes, aux_connections = m_delete_node(r2, nodes, connections)
|
||||
nodes = jnp.where(rand(r2) < delete_node_rate, aux_nodes, nodes)
|
||||
connections = jnp.where(rand(r2) < delete_node_rate, aux_connections, connections)
|
||||
|
||||
# mutate delete connection
|
||||
aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections)
|
||||
nodes = jnp.where(rand(r4) < delete_connection_rate, aux_nodes, nodes)
|
||||
connections = jnp.where(rand(r4) < delete_connection_rate, aux_connections, connections)
|
||||
|
||||
nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength,
|
||||
bias_mutate_rate, bias_replace_rate, response_mean, response_std,
|
||||
response_mutate_strength, response_mutate_rate, response_replace_rate,
|
||||
|
||||
@@ -14,6 +14,8 @@ EMPTY_CON = jnp.full((1, 4), jnp.nan)
|
||||
def unflatten_connections(nodes, cons):
|
||||
"""
|
||||
transform the (C, 4) connections to (2, N, N)
|
||||
this function is only used for transform a genome to the forward function, so here we set the weight of un=enabled
|
||||
connections to nan, that means we dont consider such connection when forward;
|
||||
:param cons:
|
||||
:param nodes:
|
||||
:return:
|
||||
@@ -29,6 +31,10 @@ def unflatten_connections(nodes, cons):
|
||||
# however, it will do nothing set values in an array
|
||||
res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
|
||||
res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
|
||||
|
||||
# (2, N, N), (2, N, N), (2, N, N)
|
||||
# res = jnp.where(res[1, :, :] == 0, jnp.nan, res)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user