modify NEAT package; successfully run xor example
This commit is contained in:
@@ -5,10 +5,10 @@ import jax, jax.numpy as jnp
|
||||
import datetime, time
|
||||
import numpy as np
|
||||
|
||||
from algorithm import BaseAlgorithm
|
||||
from problem import BaseProblem
|
||||
from problem.rl_env import RLEnv
|
||||
from problem.func_fit import FuncFit
|
||||
from tensorneat.algorithm import BaseAlgorithm
|
||||
from tensorneat.problem import BaseProblem
|
||||
from tensorneat.problem.rl_env import RLEnv
|
||||
from tensorneat.problem.func_fit import FuncFit
|
||||
from tensorneat.common import State, StatefulBaseClass
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class Pipeline(StatefulBaseClass):
|
||||
print("Fitness limit reached!")
|
||||
break
|
||||
|
||||
if self.algorithm.generation(state) >= self.generation_limit:
|
||||
if int(state.generation) >= self.generation_limit:
|
||||
print("Generation limit reached!")
|
||||
|
||||
if self.is_save:
|
||||
@@ -203,6 +203,8 @@ class Pipeline(StatefulBaseClass):
|
||||
return state, self.best_genome
|
||||
|
||||
def analysis(self, state, pop, fitnesses):
|
||||
|
||||
generation = int(state.generation)
|
||||
|
||||
valid_fitnesses = fitnesses[~np.isinf(fitnesses)]
|
||||
|
||||
@@ -223,8 +225,12 @@ class Pipeline(StatefulBaseClass):
|
||||
self.best_genome = pop[0][max_idx], pop[1][max_idx]
|
||||
|
||||
if self.is_save:
|
||||
# save best
|
||||
best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx]))
|
||||
with open(os.path.join(self.genome_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f:
|
||||
file_name = os.path.join(
|
||||
self.genome_dir, f"{generation}.npz"
|
||||
)
|
||||
with open(file_name, "wb") as f:
|
||||
np.savez(
|
||||
f,
|
||||
nodes=best_genome[0],
|
||||
@@ -232,42 +238,18 @@ class Pipeline(StatefulBaseClass):
|
||||
fitness=self.best_fitness,
|
||||
)
|
||||
|
||||
# save best if save path is not None
|
||||
|
||||
member_count = jax.device_get(self.algorithm.member_count(state))
|
||||
species_sizes = [int(i) for i in member_count if i > 0]
|
||||
|
||||
pop = jax.device_get(pop)
|
||||
pop_nodes, pop_conns = pop # (P, N, NL), (P, C, CL)
|
||||
nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1) # (P,)
|
||||
conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1) # (P,)
|
||||
|
||||
max_node_cnt, min_node_cnt, mean_node_cnt = (
|
||||
max(nodes_cnt),
|
||||
min(nodes_cnt),
|
||||
np.mean(nodes_cnt),
|
||||
)
|
||||
|
||||
max_conn_cnt, min_conn_cnt, mean_conn_cnt = (
|
||||
max(conns_cnt),
|
||||
min(conns_cnt),
|
||||
np.mean(conns_cnt),
|
||||
)
|
||||
# append log
|
||||
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
|
||||
f.write(
|
||||
f"{generation},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Generation: {self.algorithm.generation(state)}, Cost time: {cost_time * 1000:.2f}ms\n",
|
||||
f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n",
|
||||
f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n",
|
||||
f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
|
||||
f"Generation: {generation}, Cost time: {cost_time * 1000:.2f}ms\n",
|
||||
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
|
||||
)
|
||||
|
||||
# append log
|
||||
if self.is_save:
|
||||
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
|
||||
f.write(
|
||||
f"{self.algorithm.generation(state)},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
|
||||
)
|
||||
self.algorithm.show_details(state, fitnesses)
|
||||
|
||||
def show(self, state, best, *args, **kwargs):
|
||||
transformed = self.algorithm.transform(state, best)
|
||||
|
||||
Reference in New Issue
Block a user