try to accelerate the speed of speciate

This commit is contained in:
wls2002
2023-05-08 18:41:19 +08:00
parent 8653f49826
commit ee6bb01eff
4 changed files with 12 additions and 13 deletions

View File

@@ -76,11 +76,13 @@ class SpeciesController:
new_representatives = {}
new_members = {}
for sid, species in self.species.items():
# calculate the distance between the representative and the population
r_nodes, r_connections = species.representative
distances = o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections)
distances = jax.device_get(distances)
total_distances = jax.device_get([
o2m_distance(*self.species[sid].representative, pop_nodes, pop_connections)
for sid in previous_species_list
])
for i, sid in enumerate(previous_species_list):
distances = total_distances[i]
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
new_representatives[sid] = min_idx