remove attr enable for conn

This commit is contained in:
wls2002
2024-05-31 22:06:25 +08:00
parent d6e9ff5d9a
commit 4ad9f0a85a
9 changed files with 43 additions and 108 deletions

View File

@@ -45,8 +45,8 @@ class DefaultMutation(BaseMutation):
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
def successful_add_node():
# disable the connection
new_conns = conns_.at[idx, 2].set(False)
# remove the original connection
new_conns = delete_conn_by_pos(conns_, idx)
# add a new node
new_nodes = add_node(
@@ -58,14 +58,12 @@ class DefaultMutation(BaseMutation):
new_conns,
i_key,
new_node_key,
True,
genome.conn_gene.new_custom_attrs(state),
)
new_conns = add_conn(
new_conns,
new_node_key,
o_key,
True,
genome.conn_gene.new_custom_attrs(state),
)
@@ -140,27 +138,26 @@ class DefaultMutation(BaseMutation):
def successful():
return nodes_, add_conn(
conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs(state)
conns_, i_key, o_key, genome.conn_gene.new_custom_attrs(state)
)
def already_exist():
return nodes_, conns_.at[conn_pos, 2].set(True)
if genome.network_type == "feedforward":
u_cons = unflatten_conns(nodes_, conns_)
cons_exist = ~jnp.isnan(u_cons[0, :, :])
is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx)
conns_exist = ~jnp.isnan(u_cons[0, :, :])
is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx)
return jax.lax.cond(
is_already_exist,
already_exist,
lambda: jax.lax.cond(
is_cycle & (remain_conn_space < 1), nothing, successful
),
is_already_exist | is_cycle | (remain_conn_space < 1),
nothing,
successful,
)
elif genome.network_type == "recurrent":
return jax.lax.cond(is_already_exist, already_exist, successful)
return jax.lax.cond(
is_already_exist | (remain_conn_space < 1),
nothing,
successful,
)
else:
raise ValueError(f"Invalid network type: {genome.network_type}")
@@ -169,19 +166,16 @@ class DefaultMutation(BaseMutation):
# randomly choose a connection
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
def successfully_delete_connection():
return nodes_, delete_conn_by_pos(conns_, idx)
return jax.lax.cond(
idx == I_INF,
lambda: (nodes_, conns_), # nothing
successfully_delete_connection,
lambda: (nodes_, delete_conn_by_pos(conns_, idx)), # success
)
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
def no(key_, nodes_, conns_):
def no(_, nodes_, conns_):
return nodes_, conns_
if self.node_add > 0: