modify method cal_spawn_numbers
spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate']
This commit is contained in:
39
pipeline.py
39
pipeline.py
@@ -48,6 +48,26 @@ class Pipeline:
|
||||
self.pop_topological_sort = jit(vmap(neat.topological_sort))
|
||||
self.forward = neat.create_forward_function(config)
|
||||
|
||||
# fitness_lower = np.zeros(self.P, dtype=np.float32)
|
||||
# randkey_lower = np.zeros(2, dtype=np.uint32)
|
||||
# pop_nodes_lower = np.zeros((self.P, self.N, 5), dtype=np.float32)
|
||||
# pop_cons_lower = np.zeros((self.P, self.C, 4), dtype=np.float32)
|
||||
# species_info_lower = np.zeros((self.S, 4), dtype=np.float32)
|
||||
# idx2species_lower = np.zeros(self.P, dtype=np.float32)
|
||||
# center_nodes_lower = np.zeros((self.S, self.N, 5), dtype=np.float32)
|
||||
# center_cons_lower = np.zeros((self.S, self.C, 4), dtype=np.float32)
|
||||
#
|
||||
# self.tell_func = jit(neat.tell).lower(fitness_lower,
|
||||
# randkey_lower,
|
||||
# pop_nodes_lower,
|
||||
# pop_cons_lower,
|
||||
# species_info_lower,
|
||||
# idx2species_lower,
|
||||
# center_nodes_lower,
|
||||
# center_cons_lower,
|
||||
# 0,
|
||||
# self.jit_config).compile()
|
||||
|
||||
def ask(self):
|
||||
"""
|
||||
Creates a function that receives a genome and returns a forward function.
|
||||
@@ -75,21 +95,12 @@ class Pipeline:
|
||||
assert self.config['forward_way'] == 'common'
|
||||
return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
||||
|
||||
def tell(self, fitnesses):
|
||||
self.generation += 1
|
||||
def tell(self, fitness):
|
||||
|
||||
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
|
||||
|
||||
self.species_info, self.center_nodes, self.center_cons, winner, loser, elite_mask = \
|
||||
neat.update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes,
|
||||
self.center_cons, self.generation, self.jit_config)
|
||||
|
||||
self.pop_nodes, self.pop_cons = neat.create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser,
|
||||
elite_mask, self.generation, self.jit_config)
|
||||
|
||||
self.idx2species, self.center_nodes, self.center_cons, self.species_info = neat.speciate(
|
||||
self.pop_nodes, self.pop_cons, self.species_info, self.center_nodes, self.center_cons, self.generation,
|
||||
self.jit_config)
|
||||
self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \
|
||||
self.center_cons, self.generation = neat.tell(fitness, self.randkey, self.pop_nodes, self.pop_cons,
|
||||
self.species_info, self.idx2species, self.center_nodes,
|
||||
self.center_cons, self.generation, self.jit_config)
|
||||
|
||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||
for _ in range(self.config['generation_limit']):
|
||||
|
||||
Reference in New Issue
Block a user