From 1ff10b89bd3b7b39d2df816da1cae78b2e1298a5 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Sat, 19 Apr 2025 20:15:52 +0800 Subject: [PATCH] add new example (xor_without_pipeline) --- examples/func_fit/xor_without_pipeline.py | 86 +++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 examples/func_fit/xor_without_pipeline.py diff --git a/examples/func_fit/xor_without_pipeline.py b/examples/func_fit/xor_without_pipeline.py new file mode 100644 index 0000000..909ed3c --- /dev/null +++ b/examples/func_fit/xor_without_pipeline.py @@ -0,0 +1,86 @@ +import jax, jax.numpy as jnp +import numpy as np +from tensorneat import algorithm, genome, problem +from tensorneat.common import ACT, State + +POPSIZE = 10000 + +# STEP 0: create initial state which contains randkey +state = State( + randkey = jax.random.key(42) +) + +# STEP 1: define the NEAT algorithm and prepare necessary functions +algorithm = algorithm.NEAT( + pop_size=POPSIZE, + species_size=20, + survival_threshold=0.01, + genome=genome.DefaultGenome( + num_inputs=3, + num_outputs=1, + max_nodes=7, + output_transform=ACT.sigmoid, + ), +) +population_transform = jax.vmap(algorithm.transform, in_axes=(None, 0)) +jit_population_transform = jax.jit(population_transform) +jit_algorithm_tell = jax.jit(algorithm.tell) + +state = algorithm.setup(state) # setup algorithm and then save infos in previous created state + + +# STEP 2: define the function that return the fitness of population +problem = problem.XOR3d() # problem that used in pipeline +state = problem.setup(state) # setup algorithm and then save infos in previous created state + +# function for evaluate a single network +def fitness_single(state, randkey, transformed): + return problem.evaluate(state, randkey, algorithm.forward, transformed) + +# funciton for evalute the population +fitness_population = jax.vmap(fitness_single, in_axes=(None, 0, 0)) +jit_fitness_population = jax.jit(fitness_population) + + +# STEP 3: Run NEAT algorithm to solve the problem +while True: + population = algorithm.ask(state) + # network in TensorNEAT need to be transformed before calculation + pop_transformed = jit_population_transform(state, population) + + randkeys = jax.random.split(state.randkey, POPSIZE) + # evaluate networks and obtain their fitness + pop_fitness = jit_fitness_population(state, randkeys, pop_transformed) + # replace nan with -inf, necessary + pop_fitness = jnp.where(jnp.isnan(pop_fitness), -jnp.inf, pop_fitness) + + # Do whatever you want here + cpu_pop_fitness = jax.device_get(pop_fitness) + best_idx = np.argmax(cpu_pop_fitness) + print(f"best fitness: {cpu_pop_fitness[best_idx]}") + + if cpu_pop_fitness[best_idx] > -1e-6: # stop check + best = ( + population[0][best_idx], + population[1][best_idx], + ) # population = (pop_nodes, pop_conns) + break + + # tell the fitness to algorithm and update population + state = jit_algorithm_tell(state, pop_fitness) + + +# STEP 4: Do anything you want to the best network +network = algorithm.genome.network_dict(state, *best) +print(algorithm.genome.repr(state, *best)) + +# validate the output for the best +best_transformed = algorithm.transform(state, best) +output = algorithm.forward(state, best_transformed, jnp.array([1, 0, 1])) +print(f"{output=}") + + + + + +