show the node cnt and conn cnt in pipeline;
mutate add node or add conn will not happen when there is no enough space for new nodes or conns.
This commit is contained in:
@@ -37,6 +37,10 @@ class DefaultMutation(BaseMutation):
|
|||||||
return nodes, conns
|
return nodes, conns
|
||||||
|
|
||||||
def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key):
|
def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key):
|
||||||
|
|
||||||
|
remain_node_space = jnp.isnan(nodes[:, 0]).sum()
|
||||||
|
remain_conn_space = jnp.isnan(conns[:, 0]).sum()
|
||||||
|
|
||||||
def mutate_add_node(key_, nodes_, conns_):
|
def mutate_add_node(key_, nodes_, conns_):
|
||||||
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
|
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
|
||||||
|
|
||||||
@@ -68,7 +72,7 @@ class DefaultMutation(BaseMutation):
|
|||||||
return new_nodes, new_conns
|
return new_nodes, new_conns
|
||||||
|
|
||||||
return jax.lax.cond(
|
return jax.lax.cond(
|
||||||
idx == I_INF,
|
(idx == I_INF) & (remain_node_space < 1) & (remain_conn_space < 2),
|
||||||
lambda: (nodes_, conns_), # do nothing
|
lambda: (nodes_, conns_), # do nothing
|
||||||
successful_add_node,
|
successful_add_node,
|
||||||
)
|
)
|
||||||
@@ -150,7 +154,9 @@ class DefaultMutation(BaseMutation):
|
|||||||
return jax.lax.cond(
|
return jax.lax.cond(
|
||||||
is_already_exist,
|
is_already_exist,
|
||||||
already_exist,
|
already_exist,
|
||||||
lambda: jax.lax.cond(is_cycle, nothing, successful),
|
lambda: jax.lax.cond(
|
||||||
|
is_cycle & (remain_conn_space < 1), nothing, successful
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
elif genome.network_type == "recurrent":
|
elif genome.network_type == "recurrent":
|
||||||
|
|||||||
@@ -11,8 +11,8 @@ if __name__ == "__main__":
|
|||||||
genome=DefaultGenome(
|
genome=DefaultGenome(
|
||||||
num_inputs=3,
|
num_inputs=3,
|
||||||
num_outputs=1,
|
num_outputs=1,
|
||||||
max_nodes=100,
|
max_nodes=5,
|
||||||
max_conns=200,
|
max_conns=10,
|
||||||
node_gene=DefaultNodeGene(
|
node_gene=DefaultNodeGene(
|
||||||
activation_default=Act.tanh,
|
activation_default=Act.tanh,
|
||||||
activation_options=(Act.tanh,),
|
activation_options=(Act.tanh,),
|
||||||
|
|||||||
@@ -173,8 +173,27 @@ class Pipeline:
|
|||||||
member_count = jax.device_get(self.algorithm.member_count(state))
|
member_count = jax.device_get(self.algorithm.member_count(state))
|
||||||
species_sizes = [int(i) for i in member_count if i > 0]
|
species_sizes = [int(i) for i in member_count if i > 0]
|
||||||
|
|
||||||
|
pop = jax.device_get(pop)
|
||||||
|
pop_nodes, pop_conns = pop # (P, N, NL), (P, C, CL)
|
||||||
|
nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1) # (P,)
|
||||||
|
conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1) # (P,)
|
||||||
|
|
||||||
|
max_node_cnt, min_node_cnt, mean_node_cnt = (
|
||||||
|
max(nodes_cnt),
|
||||||
|
min(nodes_cnt),
|
||||||
|
np.mean(nodes_cnt),
|
||||||
|
)
|
||||||
|
|
||||||
|
max_conn_cnt, min_conn_cnt, mean_conn_cnt = (
|
||||||
|
max(conns_cnt),
|
||||||
|
min(conns_cnt),
|
||||||
|
np.mean(conns_cnt),
|
||||||
|
)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"Generation: {self.algorithm.generation(state)}, Cost time: {cost_time * 1000:.2f}ms\n",
|
f"Generation: {self.algorithm.generation(state)}, Cost time: {cost_time * 1000:.2f}ms\n",
|
||||||
|
f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n",
|
||||||
|
f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n",
|
||||||
f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
|
f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
|
||||||
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
|
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user