optimize import

This commit is contained in:
wls2002
2023-06-29 09:41:49 +08:00
parent d28cef1a87
commit 01b7731231
14 changed files with 29 additions and 58 deletions

View File

@@ -6,10 +6,7 @@ import jax
from jax import jit, vmap
from configs import Configer
from algorithms.neat import initialize_genomes
from algorithms.neat.population import create_next_generation, speciate, update_species
from algorithms.neat import unflatten_connections, topological_sort, create_forward_function
from algorithms import neat
class Pipeline:
@@ -32,7 +29,7 @@ class Pipeline:
self.generation = 0
self.best_genome = None
self.pop_nodes, self.pop_cons = initialize_genomes(self.N, self.C, self.config)
self.pop_nodes, self.pop_cons = neat.initialize_genomes(self.N, self.C, self.config)
self.species_info = np.full((self.S, 3), np.nan)
self.species_info[0, :] = 0, -np.inf, 0
self.idx2species = np.zeros(self.P, dtype=np.float32)
@@ -47,9 +44,9 @@ class Pipeline:
self.evaluate_time = 0
self.pop_unflatten_connections = jit(vmap(unflatten_connections))
self.pop_topological_sort = jit(vmap(topological_sort))
self.forward = create_forward_function(config)
self.pop_unflatten_connections = jit(vmap(neat.unflatten_connections))
self.pop_topological_sort = jit(vmap(neat.topological_sort))
self.forward = neat.create_forward_function(config)
def ask(self):
"""
@@ -84,13 +81,13 @@ class Pipeline:
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
self.species_info, self.center_nodes, self.center_cons, winner, loser, elite_mask = \
update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes,
self.center_cons, self.generation, self.jit_config)
neat.update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes,
self.center_cons, self.generation, self.jit_config)
self.pop_nodes, self.pop_cons = create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser,
elite_mask, self.generation, self.jit_config)
self.pop_nodes, self.pop_cons = neat.create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser,
elite_mask, self.generation, self.jit_config)
self.idx2species, self.center_nodes, self.center_cons, self.species_info = speciate(
self.idx2species, self.center_nodes, self.center_cons, self.species_info = neat.speciate(
self.pop_nodes, self.pop_cons, self.species_info, self.center_nodes, self.center_cons, self.generation,
self.jit_config)