modify NEAT package; successfully run xor example

This commit is contained in:
root
2024-07-11 10:10:16 +08:00
parent 52d5f046d3
commit 4a631f9464
14 changed files with 420 additions and 502 deletions

View File

@@ -5,10 +5,10 @@ import jax, jax.numpy as jnp
import datetime, time
import numpy as np
from algorithm import BaseAlgorithm
from problem import BaseProblem
from problem.rl_env import RLEnv
from problem.func_fit import FuncFit
from tensorneat.algorithm import BaseAlgorithm
from tensorneat.problem import BaseProblem
from tensorneat.problem.rl_env import RLEnv
from tensorneat.problem.func_fit import FuncFit
from tensorneat.common import State, StatefulBaseClass
@@ -187,7 +187,7 @@ class Pipeline(StatefulBaseClass):
print("Fitness limit reached!")
break
if self.algorithm.generation(state) >= self.generation_limit:
if int(state.generation) >= self.generation_limit:
print("Generation limit reached!")
if self.is_save:
@@ -203,6 +203,8 @@ class Pipeline(StatefulBaseClass):
return state, self.best_genome
def analysis(self, state, pop, fitnesses):
generation = int(state.generation)
valid_fitnesses = fitnesses[~np.isinf(fitnesses)]
@@ -223,8 +225,12 @@ class Pipeline(StatefulBaseClass):
self.best_genome = pop[0][max_idx], pop[1][max_idx]
if self.is_save:
# save best
best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx]))
with open(os.path.join(self.genome_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f:
file_name = os.path.join(
self.genome_dir, f"{generation}.npz"
)
with open(file_name, "wb") as f:
np.savez(
f,
nodes=best_genome[0],
@@ -232,42 +238,18 @@ class Pipeline(StatefulBaseClass):
fitness=self.best_fitness,
)
# save best if save path is not None
member_count = jax.device_get(self.algorithm.member_count(state))
species_sizes = [int(i) for i in member_count if i > 0]
pop = jax.device_get(pop)
pop_nodes, pop_conns = pop # (P, N, NL), (P, C, CL)
nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1) # (P,)
conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1) # (P,)
max_node_cnt, min_node_cnt, mean_node_cnt = (
max(nodes_cnt),
min(nodes_cnt),
np.mean(nodes_cnt),
)
max_conn_cnt, min_conn_cnt, mean_conn_cnt = (
max(conns_cnt),
min(conns_cnt),
np.mean(conns_cnt),
)
# append log
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
f.write(
f"{generation},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
)
print(
f"Generation: {self.algorithm.generation(state)}, Cost time: {cost_time * 1000:.2f}ms\n",
f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n",
f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n",
f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
f"Generation: {generation}, Cost time: {cost_time * 1000:.2f}ms\n",
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
)
# append log
if self.is_save:
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
f.write(
f"{self.algorithm.generation(state)},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
)
self.algorithm.show_details(state, fitnesses)
def show(self, state, best, *args, **kwargs):
transformed = self.algorithm.transform(state, best)