diff --git a/tensorneat/algorithm/base.py b/tensorneat/algorithm/base.py
index d9677f2..23999c4 100644
--- a/tensorneat/algorithm/base.py
+++ b/tensorneat/algorithm/base.py
@@ -19,9 +19,15 @@ class BaseAlgorithm:
         """transform the genome into a neural network"""
         raise NotImplementedError
 
+    def restore(self, state, transformed):
+        raise NotImplementedError
+
     def forward(self, state, inputs, transformed):
         raise NotImplementedError
 
+    def update_by_batch(self, state, batch_input, transformed):
+        raise NotImplementedError
+
     @property
     def num_inputs(self):
         raise NotImplementedError
diff --git a/tensorneat/algorithm/neat/ga/mutation/default.py b/tensorneat/algorithm/neat/ga/mutation/default.py
index 7bb32c8..81252bd 100644
--- a/tensorneat/algorithm/neat/ga/mutation/default.py
+++ b/tensorneat/algorithm/neat/ga/mutation/default.py
@@ -178,18 +178,25 @@ class DefaultMutation(BaseMutation):
 
         def no(key_, nodes_, conns_):
             return nodes_, conns_
 
-        nodes, conns = jax.lax.cond(
-            r1 < self.node_add, mutate_add_node, no, k1, nodes, conns
-        )
-        nodes, conns = jax.lax.cond(
-            r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns
-        )
-        nodes, conns = jax.lax.cond(
-            r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns
-        )
-        nodes, conns = jax.lax.cond(
-            r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns
-        )
+        if self.node_add > 0:
+            nodes, conns = jax.lax.cond(
+                r1 < self.node_add, mutate_add_node, no, k1, nodes, conns
+            )
+
+        if self.node_delete > 0:
+            nodes, conns = jax.lax.cond(
+                r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns
+            )
+
+        if self.conn_add > 0:
+            nodes, conns = jax.lax.cond(
+                r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns
+            )
+
+        if self.conn_delete > 0:
+            nodes, conns = jax.lax.cond(
+                r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns
+            )
 
         return nodes, conns
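
Note (not part of the patch): the new Python-level guards in DefaultMutation work because the mutation probabilities are static Python floats, so each `if self.node_add > 0` is resolved once at trace time. `jax.lax.cond`, by contrast, traces and compiles both branches regardless of the predicate, so an unguarded zero-probability mutation would still carry its branch into the compiled graph. A minimal sketch of the idea, using a made-up `maybe_double` helper:

    import jax

    def maybe_double(prob: float, r, x):
        # `prob` is a static Python float, so this `if` runs once at trace
        # time; with prob == 0 the cond below is never traced or compiled.
        if prob > 0:
            x = jax.lax.cond(r < prob, lambda v: v * 2.0, lambda v: v, x)
        return x
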
diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py
index b880fc9..a9ed7f2 100644
--- a/tensorneat/algorithm/neat/genome/default.py
+++ b/tensorneat/algorithm/neat/genome/default.py
@@ -117,7 +117,9 @@ class DefaultGenome(BaseGenome):
 
             def hit():
                 batch_ins, new_conn_attrs = jax.vmap(
-                    self.conn_gene.update_by_batch, in_axes=(None, 1, 1), out_axes=(1, 1)
+                    self.conn_gene.update_by_batch,
+                    in_axes=(None, 1, 1),
+                    out_axes=(1, 1),
                 )(state, u_conns_[:, :, i], batch_values)
                 batch_z, new_node_attrs = self.node_gene.update_by_batch(
                     state,
@@ -132,12 +134,12 @@
                     u_conns_.at[:, :, i].set(new_conn_attrs),
                 )
 
+            # the val of input nodes is obtained by the task, not by calculation
             (batch_values, nodes_attrs_, u_conns_) = jax.lax.cond(
                 jnp.isin(i, self.input_idx),
                 lambda: (batch_values, nodes_attrs_, u_conns_),
                 hit,
             )
-            # the val of input nodes is obtained by the task, not by calculation
 
             return batch_values, nodes_attrs_, u_conns_, idx + 1
 
diff --git a/tensorneat/algorithm/neat/neat.py b/tensorneat/algorithm/neat/neat.py
index f0e65e0..fefc6fa 100644
--- a/tensorneat/algorithm/neat/neat.py
+++ b/tensorneat/algorithm/neat/neat.py
@@ -44,9 +44,15 @@
         nodes, conns = individual
         return self.genome.transform(state, nodes, conns)
 
+    def restore(self, state, transformed):
+        return self.genome.restore(state, transformed)
+
     def forward(self, state, inputs, transformed):
         return self.genome.forward(state, inputs, transformed)
 
+    def update_by_batch(self, state, batch_input, transformed):
+        return self.genome.update_by_batch(state, batch_input, transformed)
+
     @property
     def num_inputs(self):
         return self.genome.num_inputs
diff --git a/tensorneat/algorithm/neat/species/default.py b/tensorneat/algorithm/neat/species/default.py
index f82141c..83226f4 100644
--- a/tensorneat/algorithm/neat/species/default.py
+++ b/tensorneat/algorithm/neat/species/default.py
@@ -113,6 +113,9 @@
         return state.pop_nodes, state.pop_conns
 
     def update_species(self, state, fitness):
+        # set nan to -inf
+        fitness = jnp.where(jnp.isnan(fitness), -jnp.inf, fitness)
+
         # update the fitness of each species
         state, species_fitness = self.update_species_fitness(state, fitness)
 
@@ -121,6 +124,7 @@
 
         # sort species_info by their fitness. (also push nan to the end)
         sort_indices = jnp.argsort(species_fitness)[::-1]
+
         state = state.update(
             species_keys=state.species_keys[sort_indices],
             best_fitness=state.best_fitness[sort_indices],
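
Note (not part of the patch): mapping NaN fitness to -inf before sorting matters because `jnp.argsort` places NaN at the end of an ascending sort, so the `[::-1]` reversal in `update_species` would otherwise put failed individuals first. Mapped to -inf, they sort to the front ascending and land at the end after the reversal, which is what the "(also push nan to the end)" comment relies on. A small sketch:

    import jax.numpy as jnp

    fitness = jnp.array([0.3, jnp.nan, 0.9])
    fitness = jnp.where(jnp.isnan(fitness), -jnp.inf, fitness)
    order = jnp.argsort(fitness)[::-1]  # values in order: 0.9, 0.3, -inf
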
diff --git a/tensorneat/examples/func_fit/xor.py b/tensorneat/examples/func_fit/xor.py
index 3ced8b6..8e8ffd8 100644
--- a/tensorneat/examples/func_fit/xor.py
+++ b/tensorneat/examples/func_fit/xor.py
@@ -21,11 +21,11 @@ if __name__ == "__main__":
             mutation=DefaultMutation(
                 node_add=0.05,
                 conn_add=0.05,
-                node_delete=0,
-                conn_delete=0,
+                node_delete=0.05,
+                conn_delete=0.05,
             ),
         ),
-        pop_size=100,
+        pop_size=1000,
         species_size=20,
         compatibility_threshold=2,
         survival_threshold=0.01,  # magic
diff --git a/tensorneat/pipeline.py b/tensorneat/pipeline.py
index b3d7990..8eca436 100644
--- a/tensorneat/pipeline.py
+++ b/tensorneat/pipeline.py
@@ -1,11 +1,11 @@
-from functools import partial
-
 import jax, jax.numpy as jnp
 import time
 import numpy as np
 
 from algorithm import BaseAlgorithm
 from problem import BaseProblem
+from problem.rl_env import RLEnv
+from problem.func_fit import FuncFit
 from utils import State
 
 
@@ -17,6 +17,8 @@
         seed: int = 42,
         fitness_target: float = 1,
         generation_limit: int = 1000,
+        pre_update: bool = False,
+        update_batch_size: int = 10000,
     ):
         assert problem.jitable, "Currently, problem must be jitable"
 
@@ -37,10 +39,30 @@
         self.best_genome = None
         self.best_fitness = float("-inf")
         self.generation_timestamp = None
+        self.pre_update = pre_update
+        self.update_batch_size = update_batch_size
+        if pre_update:
+            if isinstance(problem, RLEnv):
+                assert problem.record_episode, "record_episode must be True"
+                self.fetch_data = lambda episode: episode["obs"]
+            elif isinstance(problem, FuncFit):
+                assert problem.return_data, "return_data must be True"
+                self.fetch_data = lambda data: data
+            else:
+                raise NotImplementedError
 
     def setup(self, state=State()):
         print("initializing")
         state = state.register(randkey=jax.random.PRNGKey(self.seed))
+
+        if self.pre_update:
+            # initialize with mean = 0 and std = 1
+            state = state.register(
+                data=jax.random.normal(
+                    state.randkey, (self.update_batch_size, self.algorithm.num_inputs)
+                )
+            )
+
         state = self.algorithm.setup(state)
         state = self.problem.setup(state)
         print("initializing finished")
@@ -57,9 +79,42 @@
             state, pop
         )
 
-        fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
-            state, keys, self.algorithm.forward, pop_transformed
-        )
+        if self.pre_update:
+            # update the population
+            _, pop_transformed = jax.vmap(
+                self.algorithm.update_by_batch, in_axes=(None, None, 0)
+            )(state, state.data, pop_transformed)
+
+            # raw_data: (Pop, Batch, num_inputs)
+            fitnesses, raw_data = jax.vmap(
+                self.problem.evaluate, in_axes=(None, 0, None, 0)
+            )(state, keys, self.algorithm.forward, pop_transformed)
+
+            data = self.fetch_data(raw_data)
+            assert (
+                data.ndim == 3
+                and data.shape[0] == self.pop_size
+                and data.shape[2] == self.algorithm.num_inputs
+            )
+            # reshape to (Pop * Batch, num_inputs)
+            data = data.reshape(
+                data.shape[0] * data.shape[1], self.algorithm.num_inputs
+            )
+            # shuffle
+            data = jax.random.permutation(randkey_, data, axis=0)
+            # cutoff or expand
+            if data.shape[0] >= self.update_batch_size:
+                data = data[: self.update_batch_size]  # cutoff
+            else:
+                data = (
+                    jnp.full(state.data.shape, jnp.nan).at[: data.shape[0]].set(data)
+                )  # expand
+            state = state.update(data=data)
+
+        else:
+            fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
+                state, keys, self.algorithm.forward, pop_transformed
+            )
 
         state = self.algorithm.tell(state, fitnesses)
 
@@ -89,24 +144,18 @@
                     print("Fitness limit reached!")
                     return state, self.best_genome
 
-            # node = previous_pop[0][0][:, 0]
-            # node_count = jnp.sum(~jnp.isnan(node))
-            # conn = previous_pop[1][0][:, 0]
-            # conn_count = jnp.sum(~jnp.isnan(conn))
-            # if (w % 5 == 0):
-            #     print("node_count", node_count)
-            #     print("conn_count", conn_count)
-
         print("Generation limit reached!")
         return state, self.best_genome
 
     def analysis(self, state, pop, fitnesses):
+        valid_fitnesses = fitnesses[~np.isnan(fitnesses)]
+
         max_f, min_f, mean_f, std_f = (
-            max(fitnesses),
-            min(fitnesses),
-            np.mean(fitnesses),
-            np.std(fitnesses),
+            max(valid_fitnesses),
+            min(valid_fitnesses),
+            np.mean(valid_fitnesses),
+            np.std(valid_fitnesses),
         )
 
         new_timestamp = time.time()
@@ -122,9 +171,9 @@
         species_sizes = [int(i) for i in member_count if i > 0]
 
         print(
-            f"Generation: {self.algorithm.generation(state)}",
-            f"species: {len(species_sizes)}, {species_sizes}",
-            f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms",
+            f"Generation: {self.algorithm.generation(state)}, Cost time: {cost_time * 1000:.2f}ms\n",
+            f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
+            f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
         )
 
     def show(self, state, best, *args, **kwargs):
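
Note (not part of the patch): the `pre_update` branch keeps `state.data` at the fixed shape `(update_batch_size, num_inputs)` so the step stays jit-friendly across generations: observations are flattened over the population, shuffled, then either truncated or NaN-padded back to the buffer size. A standalone sketch of the same cutoff-or-expand logic, with a made-up `refill_buffer` helper and made-up sizes:

    import jax
    import jax.numpy as jnp

    def refill_buffer(key, collected, buffer_size):
        # collected: (N, num_inputs) observations gathered this generation
        collected = jax.random.permutation(key, collected, axis=0)
        if collected.shape[0] >= buffer_size:
            return collected[:buffer_size]  # cutoff
        padded = jnp.full((buffer_size, collected.shape[1]), jnp.nan)
        return padded.at[: collected.shape[0]].set(collected)  # expand

    key = jax.random.PRNGKey(0)
    data_key, shuffle_key = jax.random.split(key)
    buffer = refill_buffer(shuffle_key, jax.random.normal(data_key, (8, 3)), 5)
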
diff --git a/tensorneat/problem/func_fit/func_fit.py b/tensorneat/problem/func_fit/func_fit.py
index 31b5003..e6cc70d 100644
--- a/tensorneat/problem/func_fit/func_fit.py
+++ b/tensorneat/problem/func_fit/func_fit.py
@@ -49,7 +49,10 @@
             state, self.inputs, params
         )
         inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
-        loss = self.evaluate(state, randkey, act_func, params)
+        if self.return_data:
+            loss, _ = self.evaluate(state, randkey, act_func, params)
+        else:
+            loss = self.evaluate(state, randkey, act_func, params)
         loss = -loss
 
         msg = ""
diff --git a/tensorneat/problem/func_fit/xor.py b/tensorneat/problem/func_fit/xor.py
index c798b85..8f32ef2 100644
--- a/tensorneat/problem/func_fit/xor.py
+++ b/tensorneat/problem/func_fit/xor.py
@@ -4,14 +4,19 @@ from .func_fit import FuncFit
 
 
 class XOR(FuncFit):
-
     @property
     def inputs(self):
-        return np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
+        return np.array(
+            [[0, 0], [0, 1], [1, 0], [1, 1]],
+            dtype=np.float32,
+        )
 
     @property
     def targets(self):
-        return np.array([[0], [1], [1], [0]])
+        return np.array(
+            [[0], [1], [1], [0]],
+            dtype=np.float32,
+        )
 
     @property
     def input_shape(self):
diff --git a/tensorneat/problem/func_fit/xor3d.py b/tensorneat/problem/func_fit/xor3d.py
index 94807a0..ffe3b6a 100644
--- a/tensorneat/problem/func_fit/xor3d.py
+++ b/tensorneat/problem/func_fit/xor3d.py
@@ -16,12 +16,16 @@ class XOR3d(FuncFit):
                 [1, 0, 1],
                 [1, 1, 0],
                 [1, 1, 1],
-            ]
+            ],
+            dtype=np.float32,
         )
 
     @property
     def targets(self):
-        return np.array([[0], [1], [1], [0], [1], [0], [0], [1]])
+        return np.array(
+            [[0], [1], [1], [0], [1], [0], [0], [1]],
+            dtype=np.float32,
+        )
 
     @property
     def input_shape(self):
diff --git a/tensorneat/problem/rl_env/__init__.py b/tensorneat/problem/rl_env/__init__.py
index ac447c1..d473897 100644
--- a/tensorneat/problem/rl_env/__init__.py
+++ b/tensorneat/problem/rl_env/__init__.py
@@ -1,2 +1,3 @@
 from .gymnax_env import GymNaxEnv
 from .brax_env import BraxEnv
+from .rl_jit import RLEnv
diff --git a/tensorneat/utils/__init__.py b/tensorneat/utils/__init__.py
index 7de0984..2b1711d 100644
--- a/tensorneat/utils/__init__.py
+++ b/tensorneat/utils/__init__.py
@@ -1,5 +1,5 @@
-from .activation import Act, act
-from .aggregation import Agg, agg
+from .activation import Act, act, ACT_ALL
+from .aggregation import Agg, agg, AGG_ALL
 from .tools import *
 from .graph import *
 from .state import State
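
Note (not part of the patch): a driver for the new `pre_update` path might look like the sketch below. The patch only shows `setup`, the evaluation step, and `analysis`, so the import paths and constructor arguments here are assumptions, not the repository's actual entry point:

    # hypothetical driver; class names come from the patch, arguments are guessed
    from pipeline import Pipeline
    from algorithm.neat import NEAT    # assumed import path
    from problem.func_fit import XOR   # assumed import path

    algorithm = NEAT(pop_size=1000, species_size=20)  # guessed arguments
    problem = XOR(return_data=True)  # Pipeline asserts return_data for FuncFit
    pipeline = Pipeline(
        algorithm,
        problem,
        pre_update=True,
        update_batch_size=10000,
    )
    state = pipeline.setup()  # registers an N(0, 1) `data` buffer up front
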