fix bug in restore genome.
This commit is contained in:
@@ -50,6 +50,11 @@ class Pipeline:
|
||||
self.fetch_data = lambda data: data
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
if isinstance(problem, RLEnv):
|
||||
assert not problem.record_episode, "record_episode must be False"
|
||||
elif isinstance(problem, FuncFit):
|
||||
assert not problem.return_data, "return_data must be False"
|
||||
|
||||
def setup(self, state=State()):
|
||||
print("initializing")
|
||||
@@ -90,6 +95,13 @@ class Pipeline:
|
||||
self.problem.evaluate, in_axes=(None, 0, None, 0)
|
||||
)(state, keys, self.algorithm.forward, pop_transformed)
|
||||
|
||||
# update population
|
||||
pop_nodes, pop_conns = jax.vmap(self.algorithm.restore, in_axes=(None, 0))(
|
||||
state, pop_transformed
|
||||
)
|
||||
state = state.update(pop_nodes=pop_nodes, pop_conns=pop_conns)
|
||||
|
||||
# update data for next generation
|
||||
data = self.fetch_data(raw_data)
|
||||
assert (
|
||||
data.ndim == 3
|
||||
@@ -119,9 +131,10 @@ class Pipeline:
|
||||
# replace nan with -inf
|
||||
fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses)
|
||||
|
||||
previous_pop = self.algorithm.ask(state)
|
||||
state = self.algorithm.tell(state, fitnesses)
|
||||
|
||||
return state.update(randkey=randkey), fitnesses
|
||||
return state.update(randkey=randkey), previous_pop, fitnesses
|
||||
|
||||
def auto_run(self, state):
|
||||
print("start compile")
|
||||
@@ -135,9 +148,7 @@ class Pipeline:
|
||||
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
previous_pop = self.algorithm.ask(state)
|
||||
|
||||
state, fitnesses = compiled_step(state)
|
||||
state, previous_pop, fitnesses = compiled_step(state)
|
||||
|
||||
fitnesses = jax.device_get(fitnesses)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user