fix bug in crossover: the child from two normal networks should always be normal.

This commit is contained in:
wls2002
2024-05-22 10:27:32 +08:00
parent d1559317d1
commit 6a37563696
11 changed files with 46 additions and 43 deletions

View File

@@ -69,7 +69,7 @@ class Pipeline:
pop_transformed
)
fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
# fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
alg_state = self.algorithm.tell(state.alg, fitnesses)
@@ -80,8 +80,10 @@ class Pipeline:
def auto_run(self, ini_state):
state = ini_state
print("start compile")
tic = time.time()
compiled_step = jax.jit(self.step).lower(ini_state).compile()
print(f"compile finished, cost time: {time.time() - tic:.6f}s", )
for _ in range(self.generation_limit):
self.generation_timestamp = time.time()
@@ -91,11 +93,6 @@ class Pipeline:
state, fitnesses = compiled_step(state)
fitnesses = jax.device_get(fitnesses)
for idx, fitnesses_i in enumerate(fitnesses):
if np.isnan(fitnesses_i):
print("Fitness is nan")
print(previous_pop[0][idx], previous_pop[1][idx])
assert False
self.analysis(state, previous_pop, fitnesses)