modify method cal_spawn_numbers

spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate']
This commit is contained in:
wls2002
2023-07-01 13:36:19 +08:00
parent 896082900a
commit f6dcb97df8
7 changed files with 64 additions and 21 deletions

View File

@@ -48,6 +48,26 @@ class Pipeline:
self.pop_topological_sort = jit(vmap(neat.topological_sort))
self.forward = neat.create_forward_function(config)
# fitness_lower = np.zeros(self.P, dtype=np.float32)
# randkey_lower = np.zeros(2, dtype=np.uint32)
# pop_nodes_lower = np.zeros((self.P, self.N, 5), dtype=np.float32)
# pop_cons_lower = np.zeros((self.P, self.C, 4), dtype=np.float32)
# species_info_lower = np.zeros((self.S, 4), dtype=np.float32)
# idx2species_lower = np.zeros(self.P, dtype=np.float32)
# center_nodes_lower = np.zeros((self.S, self.N, 5), dtype=np.float32)
# center_cons_lower = np.zeros((self.S, self.C, 4), dtype=np.float32)
#
# self.tell_func = jit(neat.tell).lower(fitness_lower,
# randkey_lower,
# pop_nodes_lower,
# pop_cons_lower,
# species_info_lower,
# idx2species_lower,
# center_nodes_lower,
# center_cons_lower,
# 0,
# self.jit_config).compile()
def ask(self):
"""
Creates a function that receives a genome and returns a forward function.
@@ -75,21 +95,12 @@ class Pipeline:
assert self.config['forward_way'] == 'common'
return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons)
def tell(self, fitnesses):
self.generation += 1
def tell(self, fitness):
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
self.species_info, self.center_nodes, self.center_cons, winner, loser, elite_mask = \
neat.update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes,
self.center_cons, self.generation, self.jit_config)
self.pop_nodes, self.pop_cons = neat.create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser,
elite_mask, self.generation, self.jit_config)
self.idx2species, self.center_nodes, self.center_cons, self.species_info = neat.speciate(
self.pop_nodes, self.pop_cons, self.species_info, self.center_nodes, self.center_cons, self.generation,
self.jit_config)
self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \
self.center_cons, self.generation = neat.tell(fitness, self.randkey, self.pop_nodes, self.pop_cons,
self.species_info, self.idx2species, self.center_nodes,
self.center_cons, self.generation, self.jit_config)
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config['generation_limit']):