update Evox releated for EvoX-v0.9.1

This commit is contained in:
wls2002
2025-03-26 10:05:49 +08:00
parent 55f1c626d3
commit f017e8357a
3 changed files with 20 additions and 5 deletions

4
.gitignore vendored
View File

@@ -118,4 +118,6 @@ cython_debug/
tutorials/.ipynb_checkpoints/* tutorials/.ipynb_checkpoints/*
docs/_build/* docs/_build/*
examples/func_fit/evolving_state.pkl examples/func_fit/evolving_state.pkl
src/evox/*

View File

@@ -48,7 +48,7 @@ def nan2inf(x):
workflow = workflows.StdWorkflow( workflow = workflows.StdWorkflow(
algorithm=evox_algorithm, algorithm=evox_algorithm,
problem=problem, problem=problem,
candidate_transforms=[jax.jit(jax.vmap(evox_algorithm.transform))], solution_transforms=[jax.jit(jax.vmap(evox_algorithm.transform))],
fitness_transforms=[nan2inf], fitness_transforms=[nan2inf],
monitors=[monitor], monitors=[monitor],
opt_direction="max", opt_direction="max",
@@ -62,5 +62,5 @@ state = workflow.enable_multi_devices(state)
# run the workflow for 100 steps # run the workflow for 100 steps
for i in range(100): for i in range(100):
train_info, state = workflow.step(state) state = workflow.step(state)
monitor.show() monitor.show()

View File

@@ -44,22 +44,35 @@ class TensorNEATMonitor(Monitor):
if not os.path.exists(self.genome_dir): if not os.path.exists(self.genome_dir):
os.makedirs(self.genome_dir) os.makedirs(self.genome_dir)
def clear_history(self):
self.alg_state: TensorNEATState = None
self.fitness = None
self.best_fitness = -np.inf
self.best_genome = None
def hooks(self): def hooks(self):
return ["pre_tell"] return ["pre_tell"]
def pre_tell(self, state: EvoXState, cand_sol, transformed_cand_sol, fitness, transformed_fitness): def pre_tell(self, monitor_state, workflow_state, transformed_fitness):
io_callback( io_callback(
self.store_info, self.store_info,
None, None,
state, workflow_state,
transformed_fitness, transformed_fitness,
) )
return monitor_state
def store_info(self, state: EvoXState, fitness): def store_info(self, state: EvoXState, fitness):
self.alg_state: TensorNEATState = state.query_state("algorithm").alg_state self.alg_state: TensorNEATState = state.query_state("algorithm").alg_state
self.fitness = jax.device_get(fitness) self.fitness = jax.device_get(fitness)
def show(self): def show(self):
io_callback(
self._show,
None
)
def _show(self):
pop = self.tensorneat_algorithm.ask(self.alg_state) pop = self.tensorneat_algorithm.ask(self.alg_state)
generation = int(self.alg_state.generation) generation = int(self.alg_state.generation)