optimize import
This commit is contained in:
@@ -6,10 +6,7 @@ import jax
|
||||
from jax import jit, vmap
|
||||
|
||||
from configs import Configer
|
||||
from algorithms.neat import initialize_genomes
|
||||
|
||||
from algorithms.neat.population import create_next_generation, speciate, update_species
|
||||
from algorithms.neat import unflatten_connections, topological_sort, create_forward_function
|
||||
from algorithms import neat
|
||||
|
||||
|
||||
class Pipeline:
|
||||
@@ -32,7 +29,7 @@ class Pipeline:
|
||||
self.generation = 0
|
||||
self.best_genome = None
|
||||
|
||||
self.pop_nodes, self.pop_cons = initialize_genomes(self.N, self.C, self.config)
|
||||
self.pop_nodes, self.pop_cons = neat.initialize_genomes(self.N, self.C, self.config)
|
||||
self.species_info = np.full((self.S, 3), np.nan)
|
||||
self.species_info[0, :] = 0, -np.inf, 0
|
||||
self.idx2species = np.zeros(self.P, dtype=np.float32)
|
||||
@@ -47,9 +44,9 @@ class Pipeline:
|
||||
|
||||
self.evaluate_time = 0
|
||||
|
||||
self.pop_unflatten_connections = jit(vmap(unflatten_connections))
|
||||
self.pop_topological_sort = jit(vmap(topological_sort))
|
||||
self.forward = create_forward_function(config)
|
||||
self.pop_unflatten_connections = jit(vmap(neat.unflatten_connections))
|
||||
self.pop_topological_sort = jit(vmap(neat.topological_sort))
|
||||
self.forward = neat.create_forward_function(config)
|
||||
|
||||
def ask(self):
|
||||
"""
|
||||
@@ -84,13 +81,13 @@ class Pipeline:
|
||||
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
|
||||
|
||||
self.species_info, self.center_nodes, self.center_cons, winner, loser, elite_mask = \
|
||||
update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes,
|
||||
self.center_cons, self.generation, self.jit_config)
|
||||
neat.update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes,
|
||||
self.center_cons, self.generation, self.jit_config)
|
||||
|
||||
self.pop_nodes, self.pop_cons = create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser,
|
||||
elite_mask, self.generation, self.jit_config)
|
||||
self.pop_nodes, self.pop_cons = neat.create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser,
|
||||
elite_mask, self.generation, self.jit_config)
|
||||
|
||||
self.idx2species, self.center_nodes, self.center_cons, self.species_info = speciate(
|
||||
self.idx2species, self.center_nodes, self.center_cons, self.species_info = neat.speciate(
|
||||
self.pop_nodes, self.pop_cons, self.species_info, self.center_nodes, self.center_cons, self.generation,
|
||||
self.jit_config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user