update Evox releated for EvoX-v0.9.1
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -118,4 +118,6 @@ cython_debug/
|
|||||||
|
|
||||||
tutorials/.ipynb_checkpoints/*
|
tutorials/.ipynb_checkpoints/*
|
||||||
docs/_build/*
|
docs/_build/*
|
||||||
examples/func_fit/evolving_state.pkl
|
examples/func_fit/evolving_state.pkl
|
||||||
|
|
||||||
|
src/evox/*
|
||||||
@@ -48,7 +48,7 @@ def nan2inf(x):
|
|||||||
workflow = workflows.StdWorkflow(
|
workflow = workflows.StdWorkflow(
|
||||||
algorithm=evox_algorithm,
|
algorithm=evox_algorithm,
|
||||||
problem=problem,
|
problem=problem,
|
||||||
candidate_transforms=[jax.jit(jax.vmap(evox_algorithm.transform))],
|
solution_transforms=[jax.jit(jax.vmap(evox_algorithm.transform))],
|
||||||
fitness_transforms=[nan2inf],
|
fitness_transforms=[nan2inf],
|
||||||
monitors=[monitor],
|
monitors=[monitor],
|
||||||
opt_direction="max",
|
opt_direction="max",
|
||||||
@@ -62,5 +62,5 @@ state = workflow.enable_multi_devices(state)
|
|||||||
|
|
||||||
# run the workflow for 100 steps
|
# run the workflow for 100 steps
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
train_info, state = workflow.step(state)
|
state = workflow.step(state)
|
||||||
monitor.show()
|
monitor.show()
|
||||||
|
|||||||
@@ -44,22 +44,35 @@ class TensorNEATMonitor(Monitor):
|
|||||||
if not os.path.exists(self.genome_dir):
|
if not os.path.exists(self.genome_dir):
|
||||||
os.makedirs(self.genome_dir)
|
os.makedirs(self.genome_dir)
|
||||||
|
|
||||||
|
def clear_history(self):
|
||||||
|
self.alg_state: TensorNEATState = None
|
||||||
|
self.fitness = None
|
||||||
|
self.best_fitness = -np.inf
|
||||||
|
self.best_genome = None
|
||||||
|
|
||||||
def hooks(self):
|
def hooks(self):
|
||||||
return ["pre_tell"]
|
return ["pre_tell"]
|
||||||
|
|
||||||
def pre_tell(self, state: EvoXState, cand_sol, transformed_cand_sol, fitness, transformed_fitness):
|
def pre_tell(self, monitor_state, workflow_state, transformed_fitness):
|
||||||
io_callback(
|
io_callback(
|
||||||
self.store_info,
|
self.store_info,
|
||||||
None,
|
None,
|
||||||
state,
|
workflow_state,
|
||||||
transformed_fitness,
|
transformed_fitness,
|
||||||
)
|
)
|
||||||
|
return monitor_state
|
||||||
|
|
||||||
def store_info(self, state: EvoXState, fitness):
|
def store_info(self, state: EvoXState, fitness):
|
||||||
self.alg_state: TensorNEATState = state.query_state("algorithm").alg_state
|
self.alg_state: TensorNEATState = state.query_state("algorithm").alg_state
|
||||||
self.fitness = jax.device_get(fitness)
|
self.fitness = jax.device_get(fitness)
|
||||||
|
|
||||||
def show(self):
|
def show(self):
|
||||||
|
io_callback(
|
||||||
|
self._show,
|
||||||
|
None
|
||||||
|
)
|
||||||
|
|
||||||
|
def _show(self):
|
||||||
pop = self.tensorneat_algorithm.ask(self.alg_state)
|
pop = self.tensorneat_algorithm.ask(self.alg_state)
|
||||||
generation = int(self.alg_state.generation)
|
generation = int(self.alg_state.generation)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user