prepare for experiment

This commit is contained in:
wls2002
2023-05-14 15:27:17 +08:00
parent 72c9d4167a
commit 2b79f2c903
11 changed files with 252 additions and 62 deletions

View File

@@ -133,5 +133,8 @@ act_name2key = {
def act(idx, z):
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
return jax.lax.switch(idx, ACT_TOTAL_LIST, z)
res = jax.lax.switch(idx, ACT_TOTAL_LIST, z)
return jnp.where(jnp.isnan(res), jnp.nan, res)
# return jax.lax.switch(idx, ACT_TOTAL_LIST, z)

View File

@@ -88,6 +88,12 @@ def mutate(rand_key: Array,
def m_add_connection(rk, n, c):
return mutate_add_connection(rk, n, c, input_idx, output_idx)
def m_delete_node(rk, n, c):
return mutate_delete_node(rk, n, c, input_idx, output_idx)
def m_delete_connection(rk, n, c):
return mutate_delete_connection(rk, n, c)
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
# mutate add node
@@ -100,6 +106,16 @@ def mutate(rand_key: Array,
nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)
# mutate delete node
aux_nodes, aux_connections = m_delete_node(r2, nodes, connections)
nodes = jnp.where(rand(r2) < delete_node_rate, aux_nodes, nodes)
connections = jnp.where(rand(r2) < delete_node_rate, aux_connections, connections)
# mutate delete connection
aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections)
nodes = jnp.where(rand(r4) < delete_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r4) < delete_connection_rate, aux_connections, connections)
nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength,
bias_mutate_rate, bias_replace_rate, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,

View File

@@ -14,6 +14,8 @@ EMPTY_CON = jnp.full((1, 4), jnp.nan)
def unflatten_connections(nodes, cons):
"""
transform the (C, 4) connections to (2, N, N)
this function is only used for transform a genome to the forward function, so here we set the weight of un=enabled
connections to nan, that means we dont consider such connection when forward;
:param cons:
:param nodes:
:return:
@@ -29,6 +31,10 @@ def unflatten_connections(nodes, cons):
# however, it will do nothing set values in an array
res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
# (2, N, N), (2, N, N), (2, N, N)
# res = jnp.where(res[1, :, :] == 0, jnp.nan, res)
return res