show the node cnt and conn cnt in pipeline;

mutate add node or add conn will not happen when there is no enough space for new nodes or conns.
This commit is contained in:
wls2002
2024-05-31 16:07:23 +08:00
parent 47b1cacb57
commit 3a7d05f133
3 changed files with 29 additions and 4 deletions

View File

@@ -37,6 +37,10 @@ class DefaultMutation(BaseMutation):
return nodes, conns
def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key):
remain_node_space = jnp.isnan(nodes[:, 0]).sum()
remain_conn_space = jnp.isnan(conns[:, 0]).sum()
def mutate_add_node(key_, nodes_, conns_):
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
@@ -68,7 +72,7 @@ class DefaultMutation(BaseMutation):
return new_nodes, new_conns
return jax.lax.cond(
idx == I_INF,
(idx == I_INF) & (remain_node_space < 1) & (remain_conn_space < 2),
lambda: (nodes_, conns_), # do nothing
successful_add_node,
)
@@ -150,7 +154,9 @@ class DefaultMutation(BaseMutation):
return jax.lax.cond(
is_already_exist,
already_exist,
lambda: jax.lax.cond(is_cycle, nothing, successful),
lambda: jax.lax.cond(
is_cycle & (remain_conn_space < 1), nothing, successful
),
)
elif genome.network_type == "recurrent":

View File

@@ -11,8 +11,8 @@ if __name__ == "__main__":
genome=DefaultGenome(
num_inputs=3,
num_outputs=1,
max_nodes=100,
max_conns=200,
max_nodes=5,
max_conns=10,
node_gene=DefaultNodeGene(
activation_default=Act.tanh,
activation_options=(Act.tanh,),

View File

@@ -173,8 +173,27 @@ class Pipeline:
member_count = jax.device_get(self.algorithm.member_count(state))
species_sizes = [int(i) for i in member_count if i > 0]
pop = jax.device_get(pop)
pop_nodes, pop_conns = pop # (P, N, NL), (P, C, CL)
nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1) # (P,)
conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1) # (P,)
max_node_cnt, min_node_cnt, mean_node_cnt = (
max(nodes_cnt),
min(nodes_cnt),
np.mean(nodes_cnt),
)
max_conn_cnt, min_conn_cnt, mean_conn_cnt = (
max(conns_cnt),
min(conns_cnt),
np.mean(conns_cnt),
)
print(
f"Generation: {self.algorithm.generation(state)}, Cost time: {cost_time * 1000:.2f}ms\n",
f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n",
f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n",
f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
)