This commit is contained in:
wls2002
2023-05-06 18:33:30 +08:00
parent 73ac1bcfe0
commit 14fed83193
10 changed files with 206 additions and 22 deletions

View File

@@ -67,7 +67,7 @@ class SpeciesController:
for sid, species in self.species.items():
# calculate the distance between the representative and the population
r_nodes, r_connections = species.representative
distances = self.o2m_distance_func(r_nodes, r_connections, pop_nodes, pop_connections)
distances = self.o2m_distance_wrapper(r_nodes, r_connections, pop_nodes, pop_connections)
distances = jax.device_get(distances) # fetch the data from gpu
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
@@ -82,7 +82,7 @@ class SpeciesController:
rid_list = [new_representatives[sid] for sid in previous_species_list]
res_pop_distance = [
jax.device_get(
self.o2m_distance_func(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
self.o2m_distance_wrapper(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
)
for rid in rid_list
]
@@ -110,7 +110,7 @@ class SpeciesController:
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
distances = [
self.o2o_distance_func(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
self.o2o_distance_wrapper(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
for r in rid
]
distances = np.array(distances)
@@ -267,6 +267,36 @@ class SpeciesController:
return crossover_pair
def o2m_distance_wrapper(self, r_nodes, r_connections, pop_nodes, pop_connections):
# distances = self.o2m_distance_func(r_nodes, r_connections, pop_nodes, pop_connections)
# for d in distances:
# if np.isnan(d):
# print("fuck!!!!!!!!!!!!!!")
# print(distances)
# assert False
# return distances
distances = []
for nodes, connections in zip(pop_nodes, pop_connections):
d = self.o2o_distance_func(r_nodes, r_connections, nodes, connections)
if np.isnan(d) or d > 20:
np.save("too_large_distance_r_nodes.npy", r_nodes)
np.save("too_large_distance_r_connections.npy", r_connections)
np.save("too_large_distance_nodes", nodes)
np.save("too_large_distance_connections.npy", connections)
d = self.o2o_distance_func(r_nodes, r_connections, nodes, connections)
assert False
distances.append(d)
distances = np.stack(distances, axis=0)
# print(distances)
return distances
def o2o_distance_wrapper(self, *keys):
d = self.o2o_distance_func(*keys)
if np.isnan(d):
print("fuck!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
assert False
return d
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
"""