prepare for experiment
This commit is contained in:
@@ -16,9 +16,9 @@ class Pipeline:
|
||||
Neat algorithm pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, config, seed=42):
|
||||
def __init__(self, config, function_factory, seed=42):
|
||||
self.time_dict = {}
|
||||
self.function_factory = FunctionFactory(config)
|
||||
self.function_factory = function_factory
|
||||
|
||||
self.randkey = jax.random.PRNGKey(seed)
|
||||
np.random.seed(seed)
|
||||
@@ -31,18 +31,21 @@ class Pipeline:
|
||||
self.pop_size = config.neat.population.pop_size
|
||||
|
||||
self.species_controller = SpeciesController(config)
|
||||
self.initialize_func = self.function_factory.create_initialize()
|
||||
self.initialize_func = self.function_factory.create_initialize(self.N, self.C)
|
||||
self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx = self.initialize_func()
|
||||
|
||||
self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S)
|
||||
|
||||
self.generation = 0
|
||||
self.generation_time_list = []
|
||||
self.species_controller.init_speciate(self.pop_nodes, self.pop_cons)
|
||||
|
||||
self.best_fitness = float('-inf')
|
||||
self.best_genome = None
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
self.evaluate_time = 0
|
||||
|
||||
def ask(self):
|
||||
"""
|
||||
Create a forward function for the population.
|
||||
@@ -66,7 +69,9 @@ class Pipeline:
|
||||
new_node_keys,
|
||||
pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start)
|
||||
|
||||
idx2specie, new_center_nodes, new_center_cons, new_species_keys = jax.device_get([idx2specie, new_center_nodes, new_center_cons, new_species_keys])
|
||||
|
||||
self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys = \
|
||||
jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys])
|
||||
|
||||
self.species_controller.tell(idx2specie, new_center_nodes, new_center_cons, new_species_keys, self.generation)
|
||||
|
||||
@@ -75,7 +80,12 @@ class Pipeline:
|
||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||
for _ in range(self.config.neat.population.generation_limit):
|
||||
forward_func = self.ask()
|
||||
|
||||
tic = time.time()
|
||||
fitnesses = fitness_func(forward_func)
|
||||
self.evaluate_time += time.time() - tic
|
||||
|
||||
assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
|
||||
|
||||
if analysis is not None:
|
||||
if analysis == "default":
|
||||
@@ -104,6 +114,7 @@ class Pipeline:
|
||||
max_node_size = np.max(pop_node_sizes)
|
||||
if max_node_size >= self.N:
|
||||
self.N = int(self.N * self.expand_coe)
|
||||
# self.C = int(self.C * self.expand_coe)
|
||||
print(f"node expand to {self.N}!")
|
||||
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
|
||||
|
||||
@@ -116,6 +127,7 @@ class Pipeline:
|
||||
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
|
||||
max_con_size = np.max(pop_node_sizes)
|
||||
if max_con_size >= self.C:
|
||||
# self.N = int(self.N * self.expand_coe)
|
||||
self.C = int(self.C * self.expand_coe)
|
||||
print(f"connections expand to {self.C}!")
|
||||
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
|
||||
@@ -134,6 +146,7 @@ class Pipeline:
|
||||
|
||||
new_timestamp = time.time()
|
||||
cost_time = new_timestamp - self.generation_timestamp
|
||||
self.generation_time_list.append(cost_time)
|
||||
self.generation_timestamp = new_timestamp
|
||||
|
||||
max_idx = np.argmax(fitnesses)
|
||||
|
||||
Reference in New Issue
Block a user