reduce the use of device_get in speciate

This commit is contained in:
wls2002
2023-05-08 15:58:09 +08:00
parent 91206c796f
commit dde338696f
2 changed files with 12 additions and 54 deletions

View File

@@ -92,10 +92,10 @@ class SpeciesController:
# First, fast match the population to previous species
if previous_species_list: # exist previous species
rid_list = [new_representatives[sid] for sid in previous_species_list]
res_pop_distance = [
jax.device_get(o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections))
res_pop_distance = jax.device_get([
o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
for rid in rid_list
]
])
pop_res_distance = np.stack(res_pop_distance, axis=0).T
for i in range(pop_res_distance.shape[0]):
@@ -118,10 +118,10 @@ class SpeciesController:
if len(new_representatives) != 0:
# the representatives of new species
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
distances = [
jax.device_get(o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]))
distances = jax.device_get([
o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
for r in rid
]
])
distances = np.array(distances)
min_idx = np.argmin(distances)
min_val = distances[min_idx]