debug-branch

This commit is contained in:
wls2002
2023-05-06 21:04:28 +08:00
parent 14fed83193
commit a85e6eba78
20 changed files with 1719 additions and 233 deletions

View File

@@ -1,10 +1,9 @@
from typing import List, Tuple, Dict, Union
from itertools import count
import jax
import numpy as np
from numpy.typing import NDArray
from .genome import distance
from .genome.numpy import distance
class Species(object):
@@ -46,10 +45,6 @@ class SpeciesController:
self.species_idxer = count(0)
self.species: Dict[int, Species] = {} # species_id -> species
self.o2m_distance_func = jax.vmap(distance, in_axes=(None, None, 0, 0)) # one to many
# self.o2o_distance_func = np_distance # one to one
self.o2o_distance_func = distance
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None:
"""
:param pop_nodes:
@@ -67,8 +62,7 @@ class SpeciesController:
for sid, species in self.species.items():
# calculate the distance between the representative and the population
r_nodes, r_connections = species.representative
distances = self.o2m_distance_wrapper(r_nodes, r_connections, pop_nodes, pop_connections)
distances = jax.device_get(distances) # fetch the data from gpu
distances = o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections)
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
new_representatives[sid] = min_idx
@@ -81,9 +75,7 @@ class SpeciesController:
if previous_species_list: # exist previous species
rid_list = [new_representatives[sid] for sid in previous_species_list]
res_pop_distance = [
jax.device_get(
self.o2m_distance_wrapper(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
)
o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
for rid in rid_list
]
@@ -110,7 +102,7 @@ class SpeciesController:
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
distances = [
self.o2o_distance_wrapper(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
for r in rid
]
distances = np.array(distances)
@@ -267,36 +259,6 @@ class SpeciesController:
return crossover_pair
def o2m_distance_wrapper(self, r_nodes, r_connections, pop_nodes, pop_connections):
# distances = self.o2m_distance_func(r_nodes, r_connections, pop_nodes, pop_connections)
# for d in distances:
# if np.isnan(d):
# print("fuck!!!!!!!!!!!!!!")
# print(distances)
# assert False
# return distances
distances = []
for nodes, connections in zip(pop_nodes, pop_connections):
d = self.o2o_distance_func(r_nodes, r_connections, nodes, connections)
if np.isnan(d) or d > 20:
np.save("too_large_distance_r_nodes.npy", r_nodes)
np.save("too_large_distance_r_connections.npy", r_connections)
np.save("too_large_distance_nodes", nodes)
np.save("too_large_distance_connections.npy", connections)
d = self.o2o_distance_func(r_nodes, r_connections, nodes, connections)
assert False
distances.append(d)
distances = np.stack(distances, axis=0)
# print(distances)
return distances
def o2o_distance_wrapper(self, *keys):
d = self.o2o_distance_func(*keys)
if np.isnan(d):
print("fuck!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
assert False
return d
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
"""
@@ -351,3 +313,12 @@ def sort_element_with_fitnesses(members: List[int], fitnesses: List[float]) \
sorted_members = [item[0] for item in sorted_combined]
sorted_fitnesses = [item[1] for item in sorted_combined]
return sorted_members, sorted_fitnesses
def o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections):
distances = []
for nodes, connections in zip(pop_nodes, pop_connections):
d = distance(r_nodes, r_connections, nodes, connections)
distances.append(d)
distances = np.stack(distances, axis=0)
return distances