diff --git a/.gitignore b/.gitignore index 28be23c..e80a817 100644 --- a/.gitignore +++ b/.gitignore @@ -118,3 +118,4 @@ cython_debug/ tutorials/.ipynb_checkpoints/* docs/_build/* +examples/func_fit/evolving_state.pkl \ No newline at end of file diff --git a/examples/func_fit/xor_restore_evolving.py b/examples/func_fit/xor_restore_evolving.py new file mode 100644 index 0000000..74075e1 --- /dev/null +++ b/examples/func_fit/xor_restore_evolving.py @@ -0,0 +1,52 @@ +import jax +import numpy as np +from tensorneat.common import State +from tensorneat.pipeline import Pipeline +from tensorneat import algorithm, genome, problem +from tensorneat.common import ACT + +# necessary settings +algorithm = algorithm.NEAT( +    pop_size=1000, +    species_size=20, +    survival_threshold=0.01, +    genome=genome.DefaultGenome( +        num_inputs=3, +        num_outputs=1, +        max_nodes=7, +        output_transform=ACT.sigmoid, +    ), +) +problem = problem.XOR3d() + +pipeline = Pipeline( +    algorithm, +    problem, +    generation_limit=200,  # actually useless when we don't use auto_run() +    fitness_target=-1e-6,  # actually useless when we don't use auto_run() +    seed=42, +) + +# load the previous evolving state +state = State.load("./evolving_state.pkl") +print("load the evolving state from ./evolving_state.pkl") + + +# compile step to speed up +compiled_step = jax.jit(pipeline.step).lower(state).compile() + +current_generation = 0 +# run 50 generations +for i in range(50): +    state, previous_pop, fitnesses = compiled_step(state) +    fitnesses = jax.device_get(fitnesses)  # move fitness from gpu to cpu for printing +    print(f"Generation {current_generation}, best fitness: {max(fitnesses)}") +    current_generation += 1 + +# obtain the best individual +best_idx = np.argmax(fitnesses) +best_nodes, best_conns = previous_pop[0][best_idx], previous_pop[1][best_idx] +# transform it for inference +transformed = algorithm.genome.transform(state, best_nodes, best_conns) +xor3d_outputs = jax.vmap(algorithm.genome.forward, in_axes=(None, None, 
0))(state, transformed, problem.inputs) +print(f"{xor3d_outputs=}") \ No newline at end of file diff --git a/examples/func_fit/xor_save_the_evolving_state.py b/examples/func_fit/xor_save_the_evolving_state.py new file mode 100644 index 0000000..9159eb7 --- /dev/null +++ b/examples/func_fit/xor_save_the_evolving_state.py @@ -0,0 +1,51 @@ +import jax +import numpy as np +from tensorneat.pipeline import Pipeline +from tensorneat import algorithm, genome, problem +from tensorneat.common import ACT + +# necessary settings +algorithm = algorithm.NEAT( +    pop_size=1000, +    species_size=20, +    survival_threshold=0.01, +    genome=genome.DefaultGenome( +        num_inputs=3, +        num_outputs=1, +        max_nodes=7, +        output_transform=ACT.sigmoid, +    ), +) +problem = problem.XOR3d() + +pipeline = Pipeline( +    algorithm, +    problem, +    generation_limit=200,  # actually useless when we don't use auto_run() +    fitness_target=-1e-6,  # actually useless when we don't use auto_run() +    seed=42, +) +state = pipeline.setup() + +# compile step to speed up +compiled_step = jax.jit(pipeline.step).lower(state).compile() + +current_generation = 0 +# run 50 generations +for i in range(50): +    state, previous_pop, fitnesses = compiled_step(state) +    fitnesses = jax.device_get(fitnesses)  # move fitness from gpu to cpu for printing +    print(f"Generation {current_generation}, best fitness: {max(fitnesses)}") +    current_generation += 1 + +# obtain the best individual +best_idx = np.argmax(fitnesses) +best_nodes, best_conns = previous_pop[0][best_idx], previous_pop[1][best_idx] +# transform it for inference +transformed = algorithm.genome.transform(state, best_nodes, best_conns) +xor3d_outputs = jax.vmap(algorithm.genome.forward, in_axes=(None, None, 0))(state, transformed, problem.inputs) +print(f"{xor3d_outputs=}") + +# save the evolving state +state.save("./evolving_state.pkl") +print("save the evolving state to ./evolving_state.pkl") \ No newline at end of file diff --git a/src/tensorneat/pipeline.py 
b/src/tensorneat/pipeline.py index f85516c..95b1ce3 100644 --- a/src/tensorneat/pipeline.py +++ b/src/tensorneat/pipeline.py @@ -22,6 +22,7 @@ class Pipeline(StatefulBaseClass): is_save: bool = False, save_dir=None, show_problem_details: bool = False, + using_multidevice: bool = False, ): assert problem.jitable, "Currently, problem must be jitable" @@ -58,6 +59,11 @@ class Pipeline(StatefulBaseClass): self.show_problem_details = show_problem_details + self.using_multidevice = using_multidevice + if self.using_multidevice: + assert jax.device_count() > 1, f"using_multidevice requires more than 1 device, but {jax.device_count()=} devices are available" + print(f"Using {jax.device_count()} devices!") + def setup(self, state=State()): print("initializing") state = state.register(randkey=jax.random.PRNGKey(self.seed)) @@ -77,6 +83,13 @@ class Pipeline(StatefulBaseClass): return state def step(self, state): + """ + returns: + state, previous_pop, fitnesses + state: updated state + previous_pop: previous population + fitnesses: fitnesses of previous population + """ randkey_, randkey = jax.random.split(state.randkey) @@ -86,12 +99,12 @@ class Pipeline(StatefulBaseClass): state, pop ) - if jax.device_count() == 1: + if not self.using_multidevice: keys = jax.random.split(randkey_, self.pop_size) fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))( state, keys, self.algorithm.forward, pop_transformed ) - else: + else: # using_multidevice num_devices = jax.device_count() assert self.pop_size % num_devices == 0, "if you want to use multiple gpus, pop_size must be divisible by jax.device_count()" pop_size_per_device = self.pop_size // num_devices