Merge branch 'EMI-Group:main' into main

This commit is contained in:
WhymustIhaveaname
2025-02-23 18:17:28 +08:00
committed by GitHub
4 changed files with 119 additions and 2 deletions

1
.gitignore vendored
View File

@@ -118,3 +118,4 @@ cython_debug/
tutorials/.ipynb_checkpoints/*
docs/_build/*
examples/func_fit/evolving_state.pkl

View File

@@ -0,0 +1,52 @@
import jax
import numpy as np
from tensorneat.common import State
from tensorneat.pipeline import Pipeline
from tensorneat import algorithm, genome, problem
from tensorneat.common import ACT
# neccessary settings
algorithm = algorithm.NEAT(
pop_size=1000,
species_size=20,
survival_threshold=0.01,
genome=genome.DefaultGenome(
num_inputs=3,
num_outputs=1,
max_nodes=7,
output_transform=ACT.sigmoid,
),
)
problem = problem.XOR3d()
pipeline = Pipeline(
algorithm,
problem,
generation_limit=200, # actually useless when we don't using auto_run()
fitness_target=-1e-6, # actually useless when we don't using auto_run()
seed=42,
)
# load the previous evolving state
state = State.load("./evolving_state.pkl")
print("load the evolving state from ./evolving_state.pkl")
# compile step to speed up
compiled_step = jax.jit(pipeline.step).lower(state).compile()
current_generation = 0
# run 50 generations
for i in range(50):
state, previous_pop, fitnesses = compiled_step(state)
fitnesses = jax.device_get(fitnesses) # move fitness from gpu to cpu for printing
print(f"Generation {current_generation}, best fitness: {max(fitnesses)}")
current_generation += 1
# obtain the best individual
best_idx = np.argmax(fitnesses)
best_nodes, best_conns = previous_pop[0][best_idx], previous_pop[1][best_idx]
# make it inference
transformed = algorithm.genome.transform(state, best_nodes, best_conns)
xor3d_outputs = jax.vmap(algorithm.genome.forward, in_axes=(None, None, 0))(state, transformed, problem.inputs)
print(f"{xor3d_outputs=}")

View File

@@ -0,0 +1,51 @@
import jax
import numpy as np
from tensorneat.pipeline import Pipeline
from tensorneat import algorithm, genome, problem
from tensorneat.common import ACT
# neccessary settings
algorithm = algorithm.NEAT(
pop_size=1000,
species_size=20,
survival_threshold=0.01,
genome=genome.DefaultGenome(
num_inputs=3,
num_outputs=1,
max_nodes=7,
output_transform=ACT.sigmoid,
),
)
problem = problem.XOR3d()
pipeline = Pipeline(
algorithm,
problem,
generation_limit=200, # actually useless when we don't using auto_run()
fitness_target=-1e-6, # actually useless when we don't using auto_run()
seed=42,
)
state = pipeline.setup()
# compile step to speed up
compiled_step = jax.jit(pipeline.step).lower(state).compile()
current_generation = 0
# run 50 generations
for i in range(50):
state, previous_pop, fitnesses = compiled_step(state)
fitnesses = jax.device_get(fitnesses) # move fitness from gpu to cpu for printing
print(f"Generation {current_generation}, best fitness: {max(fitnesses)}")
current_generation += 1
# obtain the best individual
best_idx = np.argmax(fitnesses)
best_nodes, best_conns = previous_pop[0][best_idx], previous_pop[1][best_idx]
# make it inference
transformed = algorithm.genome.transform(state, best_nodes, best_conns)
xor3d_outputs = jax.vmap(algorithm.genome.forward, in_axes=(None, None, 0))(state, transformed, problem.inputs)
print(f"{xor3d_outputs=}")
# save the evolving state
state.save("./evolving_state.pkl")
print("save the evolving state to ./evolving_state.pkl")

View File

@@ -22,6 +22,7 @@ class Pipeline(StatefulBaseClass):
is_save: bool = False,
save_dir=None,
show_problem_details: bool = False,
using_multidevice: bool = False,
):
assert problem.jitable, "Currently, problem must be jitable"
@@ -58,6 +59,11 @@ class Pipeline(StatefulBaseClass):
self.show_problem_details = show_problem_details
self.using_multidevice = using_multidevice
if self.using_multidevice:
assert jax.device_count() > 1, f"using_multidevice requires more than 1 device, but {jax.device_count()=} devices are available"
print(f"Using {jax.device_count()} devices!")
def setup(self, state=State()):
print("initializing")
state = state.register(randkey=jax.random.PRNGKey(self.seed))
@@ -77,6 +83,13 @@ class Pipeline(StatefulBaseClass):
return state
def step(self, state):
"""
returns:
state, previous_pop, fitnesses
state: updated state
previous_pop: previous population
fitnesses: fitnesses of previous population
"""
randkey_, randkey = jax.random.split(state.randkey)
@@ -86,12 +99,12 @@ class Pipeline(StatefulBaseClass):
state, pop
)
if jax.device_count() == 1:
if not self.using_multidevice:
keys = jax.random.split(randkey_, self.pop_size)
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
state, keys, self.algorithm.forward, pop_transformed
)
else:
else: # using_multidevice
num_devices = jax.device_count()
assert self.pop_size % num_devices == 0, "if you want to use multiple gpus, pop_size must be divisible by jax.device_count()"
pop_size_per_device = self.pop_size // num_devices