remove attr enable for conn
This commit is contained in:
@@ -45,8 +45,8 @@ class DefaultMutation(BaseMutation):
|
||||
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
|
||||
|
||||
def successful_add_node():
|
||||
# disable the connection
|
||||
new_conns = conns_.at[idx, 2].set(False)
|
||||
# remove the original connection
|
||||
new_conns = delete_conn_by_pos(conns_, idx)
|
||||
|
||||
# add a new node
|
||||
new_nodes = add_node(
|
||||
@@ -58,14 +58,12 @@ class DefaultMutation(BaseMutation):
|
||||
new_conns,
|
||||
i_key,
|
||||
new_node_key,
|
||||
True,
|
||||
genome.conn_gene.new_custom_attrs(state),
|
||||
)
|
||||
new_conns = add_conn(
|
||||
new_conns,
|
||||
new_node_key,
|
||||
o_key,
|
||||
True,
|
||||
genome.conn_gene.new_custom_attrs(state),
|
||||
)
|
||||
|
||||
@@ -140,27 +138,26 @@ class DefaultMutation(BaseMutation):
|
||||
|
||||
def successful():
|
||||
return nodes_, add_conn(
|
||||
conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs(state)
|
||||
conns_, i_key, o_key, genome.conn_gene.new_custom_attrs(state)
|
||||
)
|
||||
|
||||
def already_exist():
|
||||
return nodes_, conns_.at[conn_pos, 2].set(True)
|
||||
|
||||
if genome.network_type == "feedforward":
|
||||
u_cons = unflatten_conns(nodes_, conns_)
|
||||
cons_exist = ~jnp.isnan(u_cons[0, :, :])
|
||||
is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx)
|
||||
conns_exist = ~jnp.isnan(u_cons[0, :, :])
|
||||
is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx)
|
||||
|
||||
return jax.lax.cond(
|
||||
is_already_exist,
|
||||
already_exist,
|
||||
lambda: jax.lax.cond(
|
||||
is_cycle & (remain_conn_space < 1), nothing, successful
|
||||
),
|
||||
is_already_exist | is_cycle | (remain_conn_space < 1),
|
||||
nothing,
|
||||
successful,
|
||||
)
|
||||
|
||||
elif genome.network_type == "recurrent":
|
||||
return jax.lax.cond(is_already_exist, already_exist, successful)
|
||||
return jax.lax.cond(
|
||||
is_already_exist | (remain_conn_space < 1),
|
||||
nothing,
|
||||
successful,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid network type: {genome.network_type}")
|
||||
@@ -169,19 +166,16 @@ class DefaultMutation(BaseMutation):
|
||||
# randomly choose a connection
|
||||
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
|
||||
|
||||
def successfully_delete_connection():
|
||||
return nodes_, delete_conn_by_pos(conns_, idx)
|
||||
|
||||
return jax.lax.cond(
|
||||
idx == I_INF,
|
||||
lambda: (nodes_, conns_), # nothing
|
||||
successfully_delete_connection,
|
||||
lambda: (nodes_, delete_conn_by_pos(conns_, idx)), # success
|
||||
)
|
||||
|
||||
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
|
||||
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
|
||||
|
||||
def no(key_, nodes_, conns_):
|
||||
def no(_, nodes_, conns_):
|
||||
return nodes_, conns_
|
||||
|
||||
if self.node_add > 0:
|
||||
|
||||
Reference in New Issue
Block a user