prepare for experiment

This commit is contained in:
wls2002
2023-05-14 15:27:17 +08:00
parent 72c9d4167a
commit 2b79f2c903
11 changed files with 252 additions and 62 deletions

View File

@@ -16,9 +16,9 @@ class Pipeline:
Neat algorithm pipeline.
"""
def __init__(self, config, seed=42):
def __init__(self, config, function_factory, seed=42):
self.time_dict = {}
self.function_factory = FunctionFactory(config)
self.function_factory = function_factory
self.randkey = jax.random.PRNGKey(seed)
np.random.seed(seed)
@@ -31,18 +31,21 @@ class Pipeline:
self.pop_size = config.neat.population.pop_size
self.species_controller = SpeciesController(config)
self.initialize_func = self.function_factory.create_initialize()
self.initialize_func = self.function_factory.create_initialize(self.N, self.C)
self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx = self.initialize_func()
self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S)
self.generation = 0
self.generation_time_list = []
self.species_controller.init_speciate(self.pop_nodes, self.pop_cons)
self.best_fitness = float('-inf')
self.best_genome = None
self.generation_timestamp = time.time()
self.evaluate_time = 0
def ask(self):
"""
Create a forward function for the population.
@@ -66,7 +69,9 @@ class Pipeline:
new_node_keys,
pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start)
idx2specie, new_center_nodes, new_center_cons, new_species_keys = jax.device_get([idx2specie, new_center_nodes, new_center_cons, new_species_keys])
self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys = \
jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys])
self.species_controller.tell(idx2specie, new_center_nodes, new_center_cons, new_species_keys, self.generation)
@@ -75,7 +80,12 @@ class Pipeline:
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config.neat.population.generation_limit):
forward_func = self.ask()
tic = time.time()
fitnesses = fitness_func(forward_func)
self.evaluate_time += time.time() - tic
assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
if analysis is not None:
if analysis == "default":
@@ -104,6 +114,7 @@ class Pipeline:
max_node_size = np.max(pop_node_sizes)
if max_node_size >= self.N:
self.N = int(self.N * self.expand_coe)
# self.C = int(self.C * self.expand_coe)
print(f"node expand to {self.N}!")
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
@@ -116,6 +127,7 @@ class Pipeline:
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
max_con_size = np.max(pop_node_sizes)
if max_con_size >= self.C:
# self.N = int(self.N * self.expand_coe)
self.C = int(self.C * self.expand_coe)
print(f"connections expand to {self.C}!")
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
@@ -134,6 +146,7 @@ class Pipeline:
new_timestamp = time.time()
cost_time = new_timestamp - self.generation_timestamp
self.generation_time_list.append(cost_time)
self.generation_timestamp = new_timestamp
max_idx = np.argmax(fitnesses)