diff --git a/examples/tmp.py b/examples/tmp.py deleted file mode 100644 index 1d08549..0000000 --- a/examples/tmp.py +++ /dev/null @@ -1,21 +0,0 @@ -import jax, jax.numpy as jnp - -from tensorneat.algorithm import NEAT -from tensorneat.genome import DefaultGenome, RecurrentGenome - -key = jax.random.key(0) -genome = DefaultGenome(num_inputs=5, num_outputs=3, max_nodes=100, max_conns=500, init_hidden_layers=(1, 2 ,3)) -state = genome.setup() -nodes, conns = genome.initialize(state, key) -print(genome.repr(state, nodes, conns)) - -inputs = jnp.array([1, 2, 3, 4, 5]) -transformed = genome.transform(state, nodes, conns) -outputs = genome.forward(state, transformed, inputs) - -print(outputs) - -network = genome.network_dict(state, nodes, conns) -print(network) - -genome.visualize(network) diff --git a/examples/tmp2.py b/examples/tmp2.py deleted file mode 100644 index 626f593..0000000 --- a/examples/tmp2.py +++ /dev/null @@ -1,39 +0,0 @@ -import jax, jax.numpy as jnp - -from tensorneat.pipeline import Pipeline -from tensorneat.algorithm.neat import NEAT -from tensorneat.genome import DefaultGenome, DefaultNode, DefaultMutation, BiasNode -from tensorneat.problem.func_fit import CustomFuncFit -from tensorneat.common import Act, Agg - - -def pagie_polynomial(inputs): - x, y = inputs - return x + y - - -if __name__ == "__main__": - genome=DefaultGenome( - num_inputs=2, - num_outputs=1, - max_nodes=3, - max_conns=2, - init_hidden_layers=(), - node_gene=BiasNode( - activation_options=[Act.identity], - aggregation_options=[Agg.sum], - ), - output_transform=Act.identity, - mutation=DefaultMutation( - node_add=0, - node_delete=0, - conn_add=0.0, - conn_delete=0.0, - ) - ) - randkey = jax.random.PRNGKey(42) - state = genome.setup() - nodes, conns = genome.initialize(state, randkey) - print(genome) - - diff --git a/examples/with_evox/evox_algorithm_adaptor.py b/examples/with_evox/evox_algorithm_adaptor.py new file mode 100644 index 0000000..dea3afb --- /dev/null +++ b/examples/with_evox/evox_algorithm_adaptor.py @@ -0,0 +1,34 @@ +import jax.numpy as jnp + +from evox import Algorithm as EvoXAlgorithm, State as EvoXState, jit_class + +from tensorneat.algorithm import BaseAlgorithm as TensorNEATAlgorithm +from tensorneat.common import State as TensorNEATState + + +@jit_class +class EvoXAlgorithmAdaptor(EvoXAlgorithm): + def __init__(self, algorithm: TensorNEATAlgorithm): + self.algorithm = algorithm + self.fixed_state = None + + def setup(self, key): + neat_algorithm_state = TensorNEATState(randkey=key) + neat_algorithm_state = self.algorithm.setup(neat_algorithm_state) + self.fixed_state = neat_algorithm_state + return EvoXState(alg_state=neat_algorithm_state) + + def ask(self, state: EvoXState): + population = self.algorithm.ask(state.alg_state) + return population, state + + def tell(self, state: EvoXState, fitness): + fitness = jnp.where(jnp.isnan(fitness), -jnp.inf, fitness) + neat_algorithm_state = self.algorithm.tell(state.alg_state, fitness) + return state.replace(alg_state=neat_algorithm_state) + + def transform(self, individual): + return self.algorithm.transform(self.fixed_state, individual) + + def forward(self, transformed, inputs): + return self.algorithm.forward(self.fixed_state, transformed, inputs) diff --git a/examples/with_evox/example.py b/examples/with_evox/example.py new file mode 100644 index 0000000..b5efb22 --- /dev/null +++ b/examples/with_evox/example.py @@ -0,0 +1,65 @@ +import jax +import jax.numpy as jnp + +from evox import workflows, algorithms, problems + +from tensorneat.examples.with_evox.evox_algorithm_adaptor import EvoXAlgorithmAdaptor +from tensorneat.examples.with_evox.tensorneat_monitor import TensorNEATMonitor +from tensorneat.algorithm import NEAT +from tensorneat.algorithm.neat import DefaultSpecies, DefaultGenome, DefaultNodeGene +from tensorneat.common import Act + +neat_algorithm = NEAT( + species=DefaultSpecies( + genome=DefaultGenome( + num_inputs=17, + num_outputs=6, + max_nodes=200, + max_conns=500, + node_gene=DefaultNodeGene( + activation_options=(Act.standard_tanh,), + activation_default=Act.standard_tanh, + ), + output_transform=Act.tanh, + ), + pop_size=10000, + species_size=10, + ), +) +evox_algorithm = EvoXAlgorithmAdaptor(neat_algorithm) + +key = jax.random.PRNGKey(42) +model_key, workflow_key = jax.random.split(key) + +monitor = TensorNEATMonitor(neat_algorithm, is_save=False) +problem = problems.neuroevolution.Brax( + env_name="walker2d", + policy=evox_algorithm.forward, + max_episode_length=1000, + num_episodes=1, + backend="mjx" +) + +def nan2inf(x): + return jnp.where(jnp.isnan(x), -jnp.inf, x) + +# create a workflow +workflow = workflows.StdWorkflow( + algorithm=evox_algorithm, + problem=problem, + candidate_transforms=[jax.jit(jax.vmap(evox_algorithm.transform))], + fitness_transforms=[nan2inf], + monitors=[monitor], + opt_direction="max", +) + +# init the workflow +state = workflow.init(workflow_key) +# state = workflow.enable_multi_devices(state) +# run the workflow for 100 steps +import time + +for i in range(100): + tic = time.time() + train_info, state = workflow.step(state) + monitor.show() \ No newline at end of file diff --git a/examples/with_evox/ray_test.py b/examples/with_evox/ray_test.py deleted file mode 100644 index 0a6e8c9..0000000 --- a/examples/with_evox/ray_test.py +++ /dev/null @@ -1,6 +0,0 @@ -import ray - -ray.init(num_gpus=2) - -available_resources = ray.available_resources() -print("Available resources:", available_resources) diff --git a/examples/with_evox/tensorneat_monitor.py b/examples/with_evox/tensorneat_monitor.py new file mode 100644 index 0000000..05261fe --- /dev/null +++ b/examples/with_evox/tensorneat_monitor.py @@ -0,0 +1,133 @@ +import warnings +import os +import time +import numpy as np + +import jax +from jax.experimental import io_callback +from evox import Monitor +from evox import State as EvoXState + +from tensorneat.algorithm import BaseAlgorithm as TensorNEATAlgorithm +from tensorneat.common import State as TensorNEATState + + +class TensorNEATMonitor(Monitor): + + def __init__( + self, + neat_algorithm: TensorNEATAlgorithm, + save_dir: str = None, + is_save: bool = False, + ): + super().__init__() + self.neat_algorithm = neat_algorithm + + self.generation_timestamp = time.time() + self.alg_state: TensorNEATState = None + self.fitness = None + self.best_fitness = -np.inf + self.best_genome = None + + self.is_save = is_save + + if is_save: + if save_dir is None: + now = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + self.save_dir = f"./{self.__class__.__name__} {now}" + else: + self.save_dir = save_dir + print(f"save to {self.save_dir}") + if not os.path.exists(self.save_dir): + os.makedirs(self.save_dir) + self.genome_dir = os.path.join(self.save_dir, "genomes") + if not os.path.exists(self.genome_dir): + os.makedirs(self.genome_dir) + + def hooks(self): + return ["pre_tell"] + + def pre_tell(self, state: EvoXState, cand_sol, transformed_cand_sol, fitness, transformed_fitness): + io_callback( + self.store_info, + None, + state, + transformed_fitness, + ) + + def store_info(self, state: EvoXState, fitness): + self.alg_state: TensorNEATState = state.query_state("algorithm").alg_state + self.fitness = jax.device_get(fitness) + + def show(self): + pop = self.neat_algorithm.ask(self.alg_state) + valid_fitnesses = self.fitness[~np.isinf(self.fitness)] + + max_f, min_f, mean_f, std_f = ( + max(valid_fitnesses), + min(valid_fitnesses), + np.mean(valid_fitnesses), + np.std(valid_fitnesses), + ) + + new_timestamp = time.time() + + cost_time = new_timestamp - self.generation_timestamp + self.generation_timestamp = new_timestamp + + max_idx = np.argmax(self.fitness) + if self.fitness[max_idx] > self.best_fitness: + self.best_fitness = self.fitness[max_idx] + self.best_genome = pop[0][max_idx], pop[1][max_idx] + + if self.is_save: + best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx])) + with open( + os.path.join( + self.genome_dir, + f"{int(self.neat_algorithm.generation(self.alg_state))}.npz", + ), + "wb", + ) as f: + np.savez( + f, + nodes=best_genome[0], + conns=best_genome[1], + fitness=self.best_fitness, + ) + + # save best if save path is not None + member_count = jax.device_get(self.neat_algorithm.member_count(self.alg_state)) + species_sizes = [int(i) for i in member_count if i > 0] + + pop = jax.device_get(pop) + pop_nodes, pop_conns = pop # (P, N, NL), (P, C, CL) + nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1) # (P,) + conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1) # (P,) + + max_node_cnt, min_node_cnt, mean_node_cnt = ( + max(nodes_cnt), + min(nodes_cnt), + np.mean(nodes_cnt), + ) + + max_conn_cnt, min_conn_cnt, mean_conn_cnt = ( + max(conns_cnt), + min(conns_cnt), + np.mean(conns_cnt), + ) + + print( + f"Generation: {self.neat_algorithm.generation(self.alg_state)}, Cost time: {cost_time * 1000:.2f}ms\n", + f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n", + f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n", + f"\tspecies: {len(species_sizes)}, {species_sizes}\n", + f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n", + ) + + # append log + if self.is_save: + with open(os.path.join(self.save_dir, "log.txt"), "a") as f: + f.write( + f"{self.neat_algorithm.generation(self.alg_state)},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n" + )