reduce the use of device_get in speciate
This commit is contained in:
@@ -92,10 +92,10 @@ class SpeciesController:
|
||||
# First, fast match the population to previous species
|
||||
if previous_species_list: # exist previous species
|
||||
rid_list = [new_representatives[sid] for sid in previous_species_list]
|
||||
res_pop_distance = [
|
||||
jax.device_get(o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections))
|
||||
res_pop_distance = jax.device_get([
|
||||
o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
|
||||
for rid in rid_list
|
||||
]
|
||||
])
|
||||
|
||||
pop_res_distance = np.stack(res_pop_distance, axis=0).T
|
||||
for i in range(pop_res_distance.shape[0]):
|
||||
@@ -118,10 +118,10 @@ class SpeciesController:
|
||||
if len(new_representatives) != 0:
|
||||
# the representatives of new species
|
||||
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
|
||||
distances = [
|
||||
jax.device_get(o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]))
|
||||
distances = jax.device_get([
|
||||
o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
|
||||
for r in rid
|
||||
]
|
||||
])
|
||||
distances = np.array(distances)
|
||||
min_idx = np.argmin(distances)
|
||||
min_val = distances[min_idx]
|
||||
|
||||
Reference in New Issue
Block a user