diff --git a/tensorneat/algorithm/neat/ga/mutation/default.py b/tensorneat/algorithm/neat/ga/mutation/default.py index 81252bd..54f492f 100644 --- a/tensorneat/algorithm/neat/ga/mutation/default.py +++ b/tensorneat/algorithm/neat/ga/mutation/default.py @@ -37,6 +37,10 @@ class DefaultMutation(BaseMutation): return nodes, conns def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key): + + remain_node_space = jnp.isnan(nodes[:, 0]).sum() + remain_conn_space = jnp.isnan(conns[:, 0]).sum() + def mutate_add_node(key_, nodes_, conns_): i_key, o_key, idx = self.choice_connection_key(key_, conns_) @@ -68,7 +72,7 @@ class DefaultMutation(BaseMutation): return new_nodes, new_conns return jax.lax.cond( - idx == I_INF, + (idx == I_INF) & (remain_node_space < 1) & (remain_conn_space < 2), lambda: (nodes_, conns_), # do nothing successful_add_node, ) @@ -150,7 +154,9 @@ class DefaultMutation(BaseMutation): return jax.lax.cond( is_already_exist, already_exist, - lambda: jax.lax.cond(is_cycle, nothing, successful), + lambda: jax.lax.cond( + is_cycle & (remain_conn_space < 1), nothing, successful + ), ) elif genome.network_type == "recurrent": diff --git a/tensorneat/examples/func_fit/xor.py b/tensorneat/examples/func_fit/xor.py index aaf83a2..652c68c 100644 --- a/tensorneat/examples/func_fit/xor.py +++ b/tensorneat/examples/func_fit/xor.py @@ -11,8 +11,8 @@ if __name__ == "__main__": genome=DefaultGenome( num_inputs=3, num_outputs=1, - max_nodes=100, - max_conns=200, + max_nodes=5, + max_conns=10, node_gene=DefaultNodeGene( activation_default=Act.tanh, activation_options=(Act.tanh,), diff --git a/tensorneat/pipeline.py b/tensorneat/pipeline.py index 6e37744..caa5afa 100644 --- a/tensorneat/pipeline.py +++ b/tensorneat/pipeline.py @@ -173,8 +173,27 @@ class Pipeline: member_count = jax.device_get(self.algorithm.member_count(state)) species_sizes = [int(i) for i in member_count if i > 0] + pop = jax.device_get(pop) + pop_nodes, pop_conns = pop # (P, N, NL), (P, C, CL) + nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1) # (P,) + conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1) # (P,) + + max_node_cnt, min_node_cnt, mean_node_cnt = ( + max(nodes_cnt), + min(nodes_cnt), + np.mean(nodes_cnt), + ) + + max_conn_cnt, min_conn_cnt, mean_conn_cnt = ( + max(conns_cnt), + min(conns_cnt), + np.mean(conns_cnt), + ) + print( f"Generation: {self.algorithm.generation(state)}, Cost time: {cost_time * 1000:.2f}ms\n", + f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n", + f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n", f"\tspecies: {len(species_sizes)}, {species_sizes}\n", f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n", )