Add examples for saving evolving state and then restore evolving
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -118,3 +118,4 @@ cython_debug/
|
|||||||
|
|
||||||
tutorials/.ipynb_checkpoints/*
|
tutorials/.ipynb_checkpoints/*
|
||||||
docs/_build/*
|
docs/_build/*
|
||||||
|
examples/func_fit/evolving_state.pkl
|
||||||
52
examples/func_fit/xor_restore_evolving.py
Normal file
52
examples/func_fit/xor_restore_evolving.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import jax
|
||||||
|
import numpy as np
|
||||||
|
from tensorneat.common import State
|
||||||
|
from tensorneat.pipeline import Pipeline
|
||||||
|
from tensorneat import algorithm, genome, problem
|
||||||
|
from tensorneat.common import ACT
|
||||||
|
|
||||||
|
# neccessary settings
|
||||||
|
algorithm = algorithm.NEAT(
|
||||||
|
pop_size=1000,
|
||||||
|
species_size=20,
|
||||||
|
survival_threshold=0.01,
|
||||||
|
genome=genome.DefaultGenome(
|
||||||
|
num_inputs=3,
|
||||||
|
num_outputs=1,
|
||||||
|
max_nodes=7,
|
||||||
|
output_transform=ACT.sigmoid,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
problem = problem.XOR3d()
|
||||||
|
|
||||||
|
pipeline = Pipeline(
|
||||||
|
algorithm,
|
||||||
|
problem,
|
||||||
|
generation_limit=200, # actually useless when we don't using auto_run()
|
||||||
|
fitness_target=-1e-6, # actually useless when we don't using auto_run()
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load the previous evolving state
|
||||||
|
state = State.load("./evolving_state.pkl")
|
||||||
|
print("load the evolving state from ./evolving_state.pkl")
|
||||||
|
|
||||||
|
|
||||||
|
# compile step to speed up
|
||||||
|
compiled_step = jax.jit(pipeline.step).lower(state).compile()
|
||||||
|
|
||||||
|
current_generation = 0
|
||||||
|
# run 50 generations
|
||||||
|
for i in range(50):
|
||||||
|
state, previous_pop, fitnesses = compiled_step(state)
|
||||||
|
fitnesses = jax.device_get(fitnesses) # move fitness from gpu to cpu for printing
|
||||||
|
print(f"Generation {current_generation}, best fitness: {max(fitnesses)}")
|
||||||
|
current_generation += 1
|
||||||
|
|
||||||
|
# obtain the best individual
|
||||||
|
best_idx = np.argmax(fitnesses)
|
||||||
|
best_nodes, best_conns = previous_pop[0][best_idx], previous_pop[1][best_idx]
|
||||||
|
# make it inference
|
||||||
|
transformed = algorithm.genome.transform(state, best_nodes, best_conns)
|
||||||
|
xor3d_outputs = jax.vmap(algorithm.genome.forward, in_axes=(None, None, 0))(state, transformed, problem.inputs)
|
||||||
|
print(f"{xor3d_outputs=}")
|
||||||
51
examples/func_fit/xor_save_the_evolving_state.py
Normal file
51
examples/func_fit/xor_save_the_evolving_state.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import jax
|
||||||
|
import numpy as np
|
||||||
|
from tensorneat.pipeline import Pipeline
|
||||||
|
from tensorneat import algorithm, genome, problem
|
||||||
|
from tensorneat.common import ACT
|
||||||
|
|
||||||
|
# neccessary settings
|
||||||
|
algorithm = algorithm.NEAT(
|
||||||
|
pop_size=1000,
|
||||||
|
species_size=20,
|
||||||
|
survival_threshold=0.01,
|
||||||
|
genome=genome.DefaultGenome(
|
||||||
|
num_inputs=3,
|
||||||
|
num_outputs=1,
|
||||||
|
max_nodes=7,
|
||||||
|
output_transform=ACT.sigmoid,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
problem = problem.XOR3d()
|
||||||
|
|
||||||
|
pipeline = Pipeline(
|
||||||
|
algorithm,
|
||||||
|
problem,
|
||||||
|
generation_limit=200, # actually useless when we don't using auto_run()
|
||||||
|
fitness_target=-1e-6, # actually useless when we don't using auto_run()
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
state = pipeline.setup()
|
||||||
|
|
||||||
|
# compile step to speed up
|
||||||
|
compiled_step = jax.jit(pipeline.step).lower(state).compile()
|
||||||
|
|
||||||
|
current_generation = 0
|
||||||
|
# run 50 generations
|
||||||
|
for i in range(50):
|
||||||
|
state, previous_pop, fitnesses = compiled_step(state)
|
||||||
|
fitnesses = jax.device_get(fitnesses) # move fitness from gpu to cpu for printing
|
||||||
|
print(f"Generation {current_generation}, best fitness: {max(fitnesses)}")
|
||||||
|
current_generation += 1
|
||||||
|
|
||||||
|
# obtain the best individual
|
||||||
|
best_idx = np.argmax(fitnesses)
|
||||||
|
best_nodes, best_conns = previous_pop[0][best_idx], previous_pop[1][best_idx]
|
||||||
|
# make it inference
|
||||||
|
transformed = algorithm.genome.transform(state, best_nodes, best_conns)
|
||||||
|
xor3d_outputs = jax.vmap(algorithm.genome.forward, in_axes=(None, None, 0))(state, transformed, problem.inputs)
|
||||||
|
print(f"{xor3d_outputs=}")
|
||||||
|
|
||||||
|
# save the evolving state
|
||||||
|
state.save("./evolving_state.pkl")
|
||||||
|
print("save the evolving state to ./evolving_state.pkl")
|
||||||
@@ -83,6 +83,13 @@ class Pipeline(StatefulBaseClass):
|
|||||||
return state
|
return state
|
||||||
|
|
||||||
def step(self, state):
|
def step(self, state):
|
||||||
|
"""
|
||||||
|
returns:
|
||||||
|
state, previous_pop, fitnesses
|
||||||
|
state: updated state
|
||||||
|
previous_pop: previous population
|
||||||
|
fitnesses: fitnesses of previous population
|
||||||
|
"""
|
||||||
|
|
||||||
randkey_, randkey = jax.random.split(state.randkey)
|
randkey_, randkey = jax.random.split(state.randkey)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user