change fitness from list to array
optimize the code of reproduction.
This commit is contained in:
@@ -53,9 +53,9 @@ class Pipeline:
|
|||||||
|
|
||||||
self.species_controller.update_species_fitnesses(fitnesses)
|
self.species_controller.update_species_fitnesses(fitnesses)
|
||||||
|
|
||||||
crossover_pair = self.species_controller.reproduce(self.generation)
|
winner_part, loser_part, elite_mask = self.species_controller.reproduce(fitnesses, self.generation)
|
||||||
|
|
||||||
self.update_next_generation(crossover_pair)
|
self.update_next_generation(winner_part, loser_part, elite_mask)
|
||||||
|
|
||||||
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation,
|
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation,
|
||||||
self.o2o_distance, self.o2m_distance)
|
self.o2o_distance, self.o2m_distance)
|
||||||
@@ -82,36 +82,29 @@ class Pipeline:
|
|||||||
print("Generation limit reached!")
|
print("Generation limit reached!")
|
||||||
return self.best_genome
|
return self.best_genome
|
||||||
|
|
||||||
def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None:
|
def update_next_generation(self, winner_part, loser_part, elite_mask) -> None:
|
||||||
"""
|
"""
|
||||||
create the next generation
|
create next generation
|
||||||
:param crossover_pair: created from self.reproduce()
|
:param winner_part:
|
||||||
|
:param loser_part:
|
||||||
|
:param elite_mask:
|
||||||
|
:return:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert self.pop_nodes.shape[0] == self.pop_size
|
assert self.pop_nodes.shape[0] == self.pop_size
|
||||||
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
|
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
|
||||||
|
|
||||||
# crossover
|
|
||||||
# prepare elitism mask and crossover pair
|
|
||||||
elitism_mask = np.full(self.pop_size, False)
|
|
||||||
|
|
||||||
for i, pair in enumerate(crossover_pair):
|
|
||||||
if not isinstance(pair, tuple): # elitism
|
|
||||||
elitism_mask[i] = True
|
|
||||||
crossover_pair[i] = (pair, pair)
|
|
||||||
crossover_pair = np.array(crossover_pair)
|
|
||||||
|
|
||||||
crossover_rand_keys = jax.random.split(k1, self.pop_size)
|
crossover_rand_keys = jax.random.split(k1, self.pop_size)
|
||||||
mutate_rand_keys = jax.random.split(k2, self.pop_size)
|
mutate_rand_keys = jax.random.split(k2, self.pop_size)
|
||||||
|
|
||||||
# batch crossover
|
# batch crossover
|
||||||
wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes
|
wpn = self.pop_nodes[winner_part] # winner pop nodes
|
||||||
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
|
wpc = self.pop_connections[winner_part] # winner pop connections
|
||||||
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
|
lpn = self.pop_nodes[loser_part] # loser pop nodes
|
||||||
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
|
lpc = self.pop_connections[loser_part] # loser pop connections
|
||||||
|
|
||||||
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
|
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
|
||||||
lpc) # new pop nodes, new pop connections
|
lpc) # new pop nodes, new pop connections
|
||||||
|
|
||||||
# mutate
|
# mutate
|
||||||
new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size)
|
new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size)
|
||||||
|
|
||||||
@@ -120,8 +113,8 @@ class Pipeline:
|
|||||||
# elitism don't mutate
|
# elitism don't mutate
|
||||||
npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc])
|
npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc])
|
||||||
|
|
||||||
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
|
self.pop_nodes = np.where(elite_mask[:, None, None], npn, m_npn)
|
||||||
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
|
self.pop_connections = np.where(elite_mask[:, None, None, None], npc, m_npc)
|
||||||
|
|
||||||
def expand(self):
|
def expand(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user