use jit().lower.compile in create functions

This commit is contained in:
wls2002
2023-05-08 02:35:04 +08:00
parent 497d89fc69
commit d4a75b9394
9 changed files with 120 additions and 77 deletions

View File

@@ -28,10 +28,8 @@ class Pipeline:
self.species_controller = SpeciesController(config)
self.initialize_func = create_initialize_function(config)
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True)
self.crossover_func = create_crossover_function(batch=True)
self.o2o_distance = create_distance_function(self.config, type='o2o')
self.o2m_distance = create_distance_function(self.config, type='o2m')
self.compile_functions()
self.generation = 0
self.species_controller.speciate(self.pop_nodes, self.pop_connections,
@@ -142,6 +140,15 @@ class Pipeline:
for s in self.species_controller.species.values():
s.representative = expand_single(*s.representative, self.N)
# update functions
self.compile_functions()
def compile_functions(self):
self.mutate_func = create_mutate_function(self.N, self.config, batch=True)
self.crossover_func = create_crossover_function(self.N, self.config, batch=True)
self.o2o_distance = create_distance_function(self.N, self.config, type='o2o')
self.o2m_distance = create_distance_function(self.N, self.config, type='o2m')
def default_analysis(self, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
species_sizes = [len(s.members) for s in self.species_controller.species.values()]