fix bug in restore genome.

This commit is contained in:
wls2002
2024-05-31 19:43:14 +08:00
parent bc8267bad0
commit d6e9ff5d9a
4 changed files with 62 additions and 20 deletions

View File

@@ -50,6 +50,11 @@ class Pipeline:
self.fetch_data = lambda data: data
else:
raise NotImplementedError
else:
if isinstance(problem, RLEnv):
assert not problem.record_episode, "record_episode must be False"
elif isinstance(problem, FuncFit):
assert not problem.return_data, "return_data must be False"
def setup(self, state=State()):
print("initializing")
@@ -90,6 +95,13 @@ class Pipeline:
self.problem.evaluate, in_axes=(None, 0, None, 0)
)(state, keys, self.algorithm.forward, pop_transformed)
# update population
pop_nodes, pop_conns = jax.vmap(self.algorithm.restore, in_axes=(None, 0))(
state, pop_transformed
)
state = state.update(pop_nodes=pop_nodes, pop_conns=pop_conns)
# update data for next generation
data = self.fetch_data(raw_data)
assert (
data.ndim == 3
@@ -119,9 +131,10 @@ class Pipeline:
# replace nan with -inf
fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses)
previous_pop = self.algorithm.ask(state)
state = self.algorithm.tell(state, fitnesses)
return state.update(randkey=randkey), fitnesses
return state.update(randkey=randkey), previous_pop, fitnesses
def auto_run(self, state):
print("start compile")
@@ -135,9 +148,7 @@ class Pipeline:
self.generation_timestamp = time.time()
previous_pop = self.algorithm.ask(state)
state, fitnesses = compiled_step(state)
state, previous_pop, fitnesses = compiled_step(state)
fitnesses = jax.device_get(fitnesses)