diff --git a/.gitignore b/.gitignore index e80a817..6e61a05 100644 --- a/.gitignore +++ b/.gitignore @@ -118,4 +118,6 @@ cython_debug/ tutorials/.ipynb_checkpoints/* docs/_build/* -examples/func_fit/evolving_state.pkl \ No newline at end of file +examples/func_fit/evolving_state.pkl + +src/evox/* \ No newline at end of file diff --git a/examples/with_evox/walker2d_evox.py b/examples/with_evox/walker2d_evox.py index 263ba1c..4df46f7 100644 --- a/examples/with_evox/walker2d_evox.py +++ b/examples/with_evox/walker2d_evox.py @@ -48,7 +48,7 @@ def nan2inf(x): workflow = workflows.StdWorkflow( algorithm=evox_algorithm, problem=problem, - candidate_transforms=[jax.jit(jax.vmap(evox_algorithm.transform))], + solution_transforms=[jax.jit(jax.vmap(evox_algorithm.transform))], fitness_transforms=[nan2inf], monitors=[monitor], opt_direction="max", @@ -62,5 +62,5 @@ state = workflow.enable_multi_devices(state) # run the workflow for 100 steps for i in range(100): - train_info, state = workflow.step(state) + state = workflow.step(state) monitor.show() diff --git a/src/tensorneat/common/evox_adaptors/tensorneat_monitor.py b/src/tensorneat/common/evox_adaptors/tensorneat_monitor.py index 26c7eda..73aba98 100644 --- a/src/tensorneat/common/evox_adaptors/tensorneat_monitor.py +++ b/src/tensorneat/common/evox_adaptors/tensorneat_monitor.py @@ -44,22 +44,35 @@ class TensorNEATMonitor(Monitor): if not os.path.exists(self.genome_dir): os.makedirs(self.genome_dir) + def clear_history(self): + self.alg_state: TensorNEATState = None + self.fitness = None + self.best_fitness = -np.inf + self.best_genome = None + def hooks(self): return ["pre_tell"] - def pre_tell(self, state: EvoXState, cand_sol, transformed_cand_sol, fitness, transformed_fitness): + def pre_tell(self, monitor_state, workflow_state, transformed_fitness): io_callback( self.store_info, None, - state, + workflow_state, transformed_fitness, ) + return monitor_state def store_info(self, state: EvoXState, fitness): self.alg_state: TensorNEATState = state.query_state("algorithm").alg_state self.fitness = jax.device_get(fitness) def show(self): + io_callback( + self._show, + None + ) + + def _show(self): pop = self.tensorneat_algorithm.ask(self.alg_state) generation = int(self.alg_state.generation)