accelerate: unify jnp and np

This commit is contained in:
wls2002
2023-05-08 00:46:48 +08:00
parent cf47c5bb38
commit c705b5cfe2
3 changed files with 13 additions and 13 deletions

View File

@@ -68,6 +68,7 @@ class SpeciesController:
# calculate the distance between the representative and the population
r_nodes, r_connections = species.representative
distances = self.o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections)
distances = jax.device_get(distances)
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
new_representatives[sid] = min_idx
@@ -80,7 +81,7 @@ class SpeciesController:
if previous_species_list: # exist previous species
rid_list = [new_representatives[sid] for sid in previous_species_list]
res_pop_distance = [
self.o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
jax.device_get(self.o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections))
for rid in rid_list
]