accelerate: unify jnp and np
This commit is contained in:
@@ -68,6 +68,7 @@ class SpeciesController:
|
||||
# calculate the distance between the representative and the population
|
||||
r_nodes, r_connections = species.representative
|
||||
distances = self.o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections)
|
||||
distances = jax.device_get(distances)
|
||||
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
|
||||
|
||||
new_representatives[sid] = min_idx
|
||||
@@ -80,7 +81,7 @@ class SpeciesController:
|
||||
if previous_species_list: # exist previous species
|
||||
rid_list = [new_representatives[sid] for sid in previous_species_list]
|
||||
res_pop_distance = [
|
||||
self.o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
|
||||
jax.device_get(self.o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections))
|
||||
for rid in rid_list
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user