Merge branch 'main' into advance
This commit is contained in:
@@ -69,7 +69,7 @@ class Pipeline:
|
||||
pop_transformed
|
||||
)
|
||||
|
||||
fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
|
||||
# fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
|
||||
|
||||
alg_state = self.algorithm.tell(state.alg, fitnesses)
|
||||
|
||||
@@ -80,9 +80,12 @@ class Pipeline:
|
||||
|
||||
def auto_run(self, ini_state):
|
||||
state = ini_state
|
||||
print("start compile")
|
||||
tic = time.time()
|
||||
compiled_step = jax.jit(self.step).lower(ini_state).compile()
|
||||
|
||||
for w in range(self.generation_limit):
|
||||
print(f"compile finished, cost time: {time.time() - tic:.6f}s", )
|
||||
for _ in range(self.generation_limit):
|
||||
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
@@ -92,11 +95,6 @@ class Pipeline:
|
||||
state, fitnesses = compiled_step(state)
|
||||
|
||||
fitnesses = jax.device_get(fitnesses)
|
||||
for idx, fitnesses_i in enumerate(fitnesses):
|
||||
if np.isnan(fitnesses_i):
|
||||
print("Fitness is nan")
|
||||
print(previous_pop[0][idx], previous_pop[1][idx])
|
||||
assert False
|
||||
|
||||
self.analysis(state, previous_pop, fitnesses)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user