debuging
This commit is contained in:
@@ -67,7 +67,7 @@ class SpeciesController:
|
||||
for sid, species in self.species.items():
|
||||
# calculate the distance between the representative and the population
|
||||
r_nodes, r_connections = species.representative
|
||||
distances = self.o2m_distance_func(r_nodes, r_connections, pop_nodes, pop_connections)
|
||||
distances = self.o2m_distance_wrapper(r_nodes, r_connections, pop_nodes, pop_connections)
|
||||
distances = jax.device_get(distances) # fetch the data from gpu
|
||||
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
|
||||
|
||||
@@ -82,7 +82,7 @@ class SpeciesController:
|
||||
rid_list = [new_representatives[sid] for sid in previous_species_list]
|
||||
res_pop_distance = [
|
||||
jax.device_get(
|
||||
self.o2m_distance_func(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
|
||||
self.o2m_distance_wrapper(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
|
||||
)
|
||||
for rid in rid_list
|
||||
]
|
||||
@@ -110,7 +110,7 @@ class SpeciesController:
|
||||
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
|
||||
|
||||
distances = [
|
||||
self.o2o_distance_func(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
|
||||
self.o2o_distance_wrapper(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
|
||||
for r in rid
|
||||
]
|
||||
distances = np.array(distances)
|
||||
@@ -267,6 +267,36 @@ class SpeciesController:
|
||||
|
||||
return crossover_pair
|
||||
|
||||
def o2m_distance_wrapper(self, r_nodes, r_connections, pop_nodes, pop_connections):
|
||||
# distances = self.o2m_distance_func(r_nodes, r_connections, pop_nodes, pop_connections)
|
||||
# for d in distances:
|
||||
# if np.isnan(d):
|
||||
# print("fuck!!!!!!!!!!!!!!")
|
||||
# print(distances)
|
||||
# assert False
|
||||
# return distances
|
||||
distances = []
|
||||
for nodes, connections in zip(pop_nodes, pop_connections):
|
||||
d = self.o2o_distance_func(r_nodes, r_connections, nodes, connections)
|
||||
if np.isnan(d) or d > 20:
|
||||
np.save("too_large_distance_r_nodes.npy", r_nodes)
|
||||
np.save("too_large_distance_r_connections.npy", r_connections)
|
||||
np.save("too_large_distance_nodes", nodes)
|
||||
np.save("too_large_distance_connections.npy", connections)
|
||||
d = self.o2o_distance_func(r_nodes, r_connections, nodes, connections)
|
||||
assert False
|
||||
distances.append(d)
|
||||
distances = np.stack(distances, axis=0)
|
||||
# print(distances)
|
||||
return distances
|
||||
|
||||
def o2o_distance_wrapper(self, *keys):
|
||||
d = self.o2o_distance_func(*keys)
|
||||
if np.isnan(d):
|
||||
print("fuck!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||
assert False
|
||||
return d
|
||||
|
||||
|
||||
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user