change fitness from list to array

optimize the code of reproduction.
This commit is contained in:
wls2002
2023-05-11 08:14:58 +08:00
parent b271a56827
commit 299ff1f8f1
4 changed files with 37 additions and 34 deletions

View File

@@ -13,7 +13,7 @@ class Species(object):
self.created = generation
self.last_improved = generation
self.representative: Tuple[NDArray, NDArray] = (None, None) # (nodes, connections)
self.members: List[int] = [] # idx in pop_nodes, pop_connections,
self.members: NDArray = None # idx in pop_nodes, pop_connections,
self.fitness = None
self.member_fitnesses = None
self.adjusted_fitness = None
@@ -24,7 +24,7 @@ class Species(object):
self.members = members
def get_fitnesses(self, fitnesses):
return [fitnesses[m] for m in self.members]
return fitnesses[self.members]
class SpeciesController:
@@ -55,7 +55,7 @@ class SpeciesController:
pop_size = pop_nodes.shape[0]
species_id = next(self.species_idxer)
s = Species(species_id, 0)
members = list(range(pop_size))
members = np.array(list(range(pop_size)))
s.update((pop_nodes[0], pop_connections[0]), members)
self.species[species_id] = s
@@ -81,6 +81,7 @@ class SpeciesController:
for sid in previous_species_list
])
# TODO: Use jit to wrapper function find_min_with_mask to accelerate this part
for i, sid in enumerate(previous_species_list):
distances = total_distances[i]
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
@@ -145,7 +146,7 @@ class SpeciesController:
s = Species(sid, generation)
self.species[sid] = s
members = new_members[sid]
members = np.array(new_members[sid])
s.update((pop_nodes[rid], pop_connections[rid]), members)
def update_species_fitnesses(self, fitnesses):
@@ -195,9 +196,10 @@ class SpeciesController:
result.append((sid, s, is_stagnant))
return result
def reproduce(self, generation: int) -> List[Union[int, Tuple[int, int]]]:
def reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
"""
code modified from neat-python!
:param fitnesses:
:param generation:
:return: crossover_pair for next generation.
# int -> idx in the pop_nodes, pop_connections of elitism
@@ -208,21 +210,22 @@ class SpeciesController:
# The average adjusted fitness scheme (normalized to the interval
# [0, 1]) allows the use of negative fitness values without
# interfering with the shared fitness scheme.
all_fitnesses = []
min_fitness = np.inf
max_fitness = -np.inf
remaining_species = []
for stag_sid, stag_s, stagnant in self.stagnation(generation):
if not stagnant:
all_fitnesses.extend(stag_s.member_fitnesses)
min_fitness = min(min_fitness, np.min(stag_s.member_fitnesses))
max_fitness = max(max_fitness, np.max(stag_s.member_fitnesses))
remaining_species.append(stag_s)
# No species left.
if not remaining_species:
self.species = {}
return []
assert remaining_species
# Compute each species' member size in the next generation.
min_fitness = min(all_fitnesses)
max_fitness = max(all_fitnesses)
# Do not allow the fitness range to be zero, as we divide by it below.
# TODO: The ``1.0`` below is rather arbitrary, and should be configurable.
fitness_range = max(1.0, max_fitness - min_fitness)
@@ -242,21 +245,23 @@ class SpeciesController:
# int -> idx in the pop_nodes, pop_connections of elitism
# (int, int) -> the father and mother idx to be crossover
crossover_pair: List[Union[int, Tuple[int, int]]] = []
part1, part2, elite_mask = [], [], []
for spawn, s in zip(spawn_amounts, remaining_species):
assert spawn >= self.genome_elitism
# retain remain species to next generation
old_members, fitnesses = s.members, s.member_fitnesses
old_members, member_fitnesses = s.members, s.member_fitnesses
s.members = []
self.species[s.key] = s
# add elitism genomes to next generation
sorted_members, sorted_fitnesses = sort_element_with_fitnesses(old_members, fitnesses)
sorted_members, sorted_fitnesses = sort_element_with_fitnesses(old_members, member_fitnesses)
if self.genome_elitism > 0:
for m in sorted_members[:self.genome_elitism]:
crossover_pair.append(m)
part1.append(m)
part2.append(m)
elite_mask.append(True)
spawn -= 1
if spawn <= 0:
@@ -269,15 +274,16 @@ class SpeciesController:
sorted_members = sorted_members[:repro_cutoff]
list_idx1, list_idx2 = np.random.choice(len(sorted_members), size=(2, spawn), replace=True)
for c1, c2 in zip(list_idx1, list_idx2):
idx1, fitness1 = sorted_members[c1], sorted_fitnesses[c1]
idx2, fitness2 = sorted_members[c2], sorted_fitnesses[c2]
if fitness1 >= fitness2:
crossover_pair.append((idx1, idx2))
else:
crossover_pair.append((idx2, idx1))
part1.extend(sorted_members[list_idx1])
part2.extend(sorted_members[list_idx2])
elite_mask.extend([False] * spawn)
return crossover_pair
part1_fitness, part2_fitness = fitnesses[part1], fitnesses[part2]
is_part1_win = part1_fitness >= part2_fitness
winner_part = np.where(is_part1_win, part1, part2)
loser_part = np.where(is_part1_win, part2, part1)
return winner_part, loser_part, np.array(elite_mask)
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
@@ -326,10 +332,7 @@ def find_min_with_mask(arr: NDArray, mask: NDArray) -> int:
return min_idx
def sort_element_with_fitnesses(members: List[int], fitnesses: List[float]) \
-> Tuple[List[int], List[float]]:
combined = zip(members, fitnesses)
sorted_combined = sorted(combined, key=lambda x: x[1], reverse=True)
sorted_members = [item[0] for item in sorted_combined]
sorted_fitnesses = [item[1] for item in sorted_combined]
return sorted_members, sorted_fitnesses
def sort_element_with_fitnesses(members: NDArray, fitnesses: NDArray) \
-> Tuple[NDArray, NDArray]:
sorted_idx = np.argsort(fitnesses)[::-1]
return members[sorted_idx], fitnesses[sorted_idx]

View File

@@ -25,7 +25,7 @@ from problems import Sin, Xor, DIY
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
def main():
config = Configer.load_config()
config.neat.population.pop_size = 50
# config.neat.population.pop_size = 50
problem = Xor()
# problem = Sin()
# problem = DIY(func=lambda x: (np.sin(x) + np.exp(x) - x ** 2) / (np.cos(x) + np.sqrt(x)) - np.log(x + 1))

View File

@@ -19,7 +19,7 @@ class FunctionFittingProblem(Problem):
outs = pop_batch_forward(self.inputs)
outs = jax.device_get(outs)
fitnesses = -np.mean((self.target - outs) ** 2, axis=(1, 2))
return fitnesses.tolist()
return fitnesses
def draw(self, batch_func):
outs = batch_func(self.inputs)

View File

@@ -13,7 +13,7 @@
"fitness_criterion": "max",
"fitness_threshold": -0.001,
"generation_limit": 1000,
"pop_size": 30,
"pop_size": 1000,
"reset_on_extinction": "False"
},
"gene": {