Merge branch 'main' into advance

This commit is contained in:
WLS2002
2024-05-24 19:42:03 +08:00
committed by GitHub
17 changed files with 156 additions and 82 deletions

View File

@@ -69,7 +69,7 @@ class Pipeline:
pop_transformed
)
fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
# fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
alg_state = self.algorithm.tell(state.alg, fitnesses)
@@ -80,9 +80,12 @@ class Pipeline:
def auto_run(self, ini_state):
state = ini_state
print("start compile")
tic = time.time()
compiled_step = jax.jit(self.step).lower(ini_state).compile()
for w in range(self.generation_limit):
print(f"compile finished, cost time: {time.time() - tic:.6f}s", )
for _ in range(self.generation_limit):
self.generation_timestamp = time.time()
@@ -92,11 +95,6 @@ class Pipeline:
state, fitnesses = compiled_step(state)
fitnesses = jax.device_get(fitnesses)
for idx, fitnesses_i in enumerate(fitnesses):
if np.isnan(fitnesses_i):
print("Fitness is nan")
print(previous_pop[0][idx], previous_pop[1][idx])
assert False
self.analysis(state, previous_pop, fitnesses)