From 075460f896c495453d5c0b3cacec9f757d5473c7 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Thu, 20 Jun 2024 16:32:52 +0800 Subject: [PATCH] fix bugs --- tensorneat/algorithm/neat/gene/base.py | 5 +- tensorneat/algorithm/neat/genome/base.py | 8 ++- tensorneat/algorithm/neat/genome/default.py | 34 +++++++--- tensorneat/algorithm/neat/neat.py | 63 +------------------ tensorneat/algorithm/neat/species/base.py | 3 + tensorneat/algorithm/neat/species/default.py | 53 +++++++++++++++- tensorneat/examples/func_fit/xor.py | 8 +-- tensorneat/examples/gymnax/cartpole.py | 6 +- .../interpret_visualize/genome_sympy.py | 31 ++++----- tensorneat/pipeline.py | 38 +++++++---- tensorneat/utils/__init__.py | 4 +- tensorneat/utils/activation/act_jnp.py | 13 +++- tensorneat/utils/activation/act_sympy.py | 60 ++++++++++-------- tensorneat/utils/aggregation/agg_jnp.py | 8 +-- tensorneat/utils/state.py | 3 + tensorneat/utils/stateful_class.py | 15 +++++ tensorneat/utils/tools.py | 12 ++++ 17 files changed, 224 insertions(+), 140 deletions(-) diff --git a/tensorneat/algorithm/neat/gene/base.py b/tensorneat/algorithm/neat/gene/base.py index 296a6c4..c1d89a5 100644 --- a/tensorneat/algorithm/neat/gene/base.py +++ b/tensorneat/algorithm/neat/gene/base.py @@ -1,5 +1,5 @@ import jax, jax.numpy as jnp -from utils import State, StatefulBaseClass +from utils import State, StatefulBaseClass, hash_array class BaseGene(StatefulBaseClass): @@ -43,3 +43,6 @@ class BaseGene(StatefulBaseClass): def repr(self, state, gene, precision=2): raise NotImplementedError + + def hash(self, gene): + return hash_array(gene) diff --git a/tensorneat/algorithm/neat/genome/base.py b/tensorneat/algorithm/neat/genome/base.py index aa16672..a73122d 100644 --- a/tensorneat/algorithm/neat/genome/base.py +++ b/tensorneat/algorithm/neat/genome/base.py @@ -2,7 +2,7 @@ import numpy as np import jax, jax.numpy as jnp from ..gene import BaseNodeGene, BaseConnGene from ..ga import BaseMutation, BaseCrossover -from utils import State, StatefulBaseClass, topological_sort_python +from utils import State, StatefulBaseClass, topological_sort_python, hash_array class BaseGenome(StatefulBaseClass): @@ -255,10 +255,14 @@ class BaseGenome(StatefulBaseClass): nx.draw( G, - with_labels=True, pos=rotated_pos, node_size=node_sizes, node_color=node_colors, **kwargs, ) plt.savefig(save_path, dpi=save_dpi) + + def hash(self, nodes, conns): + nodes_hashs = jax.vmap(hash_array)(nodes) + conns_hashs = jax.vmap(hash_array)(conns) + return hash_array(jnp.concatenate([nodes_hashs, conns_hashs])) diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py index 6ccb1ec..8f61cfd 100644 --- a/tensorneat/algorithm/neat/genome/default.py +++ b/tensorneat/algorithm/neat/genome/default.py @@ -210,7 +210,14 @@ class DefaultGenome(BaseGenome): new_transformed, ) - def sympy_func(self, state, network, sympy_input_transform=None, sympy_output_transform=None, backend="jax"): + def sympy_func( + self, + state, + network, + sympy_input_transform=None, + sympy_output_transform=None, + backend="jax", + ): assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'" module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP @@ -219,6 +226,10 @@ class DefaultGenome(BaseGenome): warnings.warn( "genome.input_transform is not None but sympy_input_transform is None!" ) + + if sympy_input_transform is None: + sympy_input_transform = lambda x: x + if sympy_input_transform is not None: if not isinstance(sympy_input_transform, list): sympy_input_transform = [sympy_input_transform] * self.num_inputs @@ -231,11 +242,14 @@ class DefaultGenome(BaseGenome): input_idx = self.get_input_idx() output_idx = self.get_output_idx() order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"])) - hidden_idx = [i for i in network["nodes"] if i not in input_idx and i not in output_idx] + hidden_idx = [ + i for i in network["nodes"] if i not in input_idx and i not in output_idx + ] symbols = {} for i in network["nodes"]: if i in input_idx: - symbols[i] = sp.Symbol(f"i{i - min(input_idx)}") + symbols[-i - 1] = sp.Symbol(f"i{i - min(input_idx)}") # origin_i + symbols[i] = sp.Symbol(f"norm{i - min(input_idx)}") elif i in output_idx: symbols[i] = sp.Symbol(f"o{i - min(output_idx)}") else: # hidden @@ -246,10 +260,9 @@ class DefaultGenome(BaseGenome): for i in order: if i in input_idx: - if sympy_input_transform is not None: - nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[i]) - else: - nodes_exprs[symbols[i]] = symbols[i] + nodes_exprs[symbols[-i - 1]] = symbols[-i - 1] # origin equal to its symbol + nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[-i - 1]) # normed i + else: in_conns = [c for c in network["conns"] if c[1] == i] node_inputs = [] @@ -270,12 +283,13 @@ class DefaultGenome(BaseGenome): is_output_node=(i in output_idx), ) args_symbols.update(a_s) + if i in output_idx and sympy_output_transform is not None: nodes_exprs[symbols[i]] = sympy_output_transform( nodes_exprs[symbols[i]] ) - input_symbols = [v for k, v in symbols.items() if k in input_idx] + input_symbols = [symbols[-i - 1] for i in input_idx] reduced_exprs = nodes_exprs.copy() for i in order: reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs) @@ -299,7 +313,9 @@ class DefaultGenome(BaseGenome): fixed_args_output_funcs.append(f) - forward_func = lambda inputs: jnp.array([f(inputs) for f in fixed_args_output_funcs]) + forward_func = lambda inputs: jnp.array( + [f(inputs) for f in fixed_args_output_funcs] + ) return ( symbols, diff --git a/tensorneat/algorithm/neat/neat.py b/tensorneat/algorithm/neat/neat.py index 62ad5e0..c12d49d 100644 --- a/tensorneat/algorithm/neat/neat.py +++ b/tensorneat/algorithm/neat/neat.py @@ -2,8 +2,6 @@ import jax, jax.numpy as jnp from utils import State from .. import BaseAlgorithm from .species import * -from .ga import * -from .genome import * class NEAT(BaseAlgorithm): @@ -16,28 +14,13 @@ class NEAT(BaseAlgorithm): def setup(self, state=State()): state = self.species.setup(state) - state = state.register( - generation=jnp.array(0.0), - next_node_key=jnp.array( - max(*self.genome.input_idx, *self.genome.output_idx) + 2, - dtype=jnp.float32, - ), - ) return state def ask(self, state: State): return self.species.ask(state) def tell(self, state: State, fitness): - k1, k2, randkey = jax.random.split(state.randkey, 3) - - state = state.update(generation=state.generation + 1, randkey=randkey) - - state, winner, loser, elite_mask = self.species.update_species(state, fitness) - state = self.create_next_generation(state, winner, loser, elite_mask) - state = self.species.speciate(state) - - return state + return self.species.tell(state, fitness) def transform(self, state, individual): """transform the genome into a neural network""" @@ -65,50 +48,6 @@ class NEAT(BaseAlgorithm): def pop_size(self): return self.species.pop_size - def create_next_generation(self, state, winner, loser, elite_mask): - # prepare random keys - pop_size = self.species.pop_size - new_node_keys = jnp.arange(pop_size) + state.next_node_key - - k1, k2, randkey = jax.random.split(state.randkey, 3) - crossover_randkeys = jax.random.split(k1, pop_size) - mutate_randkeys = jax.random.split(k2, pop_size) - - wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner] - lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser] - - # batch crossover - n_nodes, n_conns = jax.vmap( - self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0) - )( - state, crossover_randkeys, wpn, wpc, lpn, lpc - ) # new_nodes, new_conns - - # batch mutation - m_n_nodes, m_n_conns = jax.vmap( - self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0) - )( - state, mutate_randkeys, n_nodes, n_conns, new_node_keys - ) # mutated_new_nodes, mutated_new_conns - - # elitism don't mutate - pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes) - pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns) - - # update next node key - all_nodes_keys = pop_nodes[:, :, 0] - max_node_key = jnp.max( - jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys) - ) - next_node_key = max_node_key + 1 - - return state.update( - randkey=randkey, - pop_nodes=pop_nodes, - pop_conns=pop_conns, - next_node_key=next_node_key, - ) - def member_count(self, state: State): return state.member_count diff --git a/tensorneat/algorithm/neat/species/base.py b/tensorneat/algorithm/neat/species/base.py index 4654dba..f53b8a5 100644 --- a/tensorneat/algorithm/neat/species/base.py +++ b/tensorneat/algorithm/neat/species/base.py @@ -10,6 +10,9 @@ class BaseSpecies(StatefulBaseClass): def ask(self, state: State): raise NotImplementedError + def tell(self, state: State, fitness): + raise NotImplementedError + def update_species(self, state, fitness): raise NotImplementedError diff --git a/tensorneat/algorithm/neat/species/default.py b/tensorneat/algorithm/neat/species/default.py index 64c3d9f..809a19d 100644 --- a/tensorneat/algorithm/neat/species/default.py +++ b/tensorneat/algorithm/neat/species/default.py @@ -113,12 +113,23 @@ class DefaultSpecies(BaseSpecies): idx2species=idx2species, center_nodes=center_nodes, center_conns=center_conns, - next_species_key=jnp.array(1), # 0 is reserved for the first species + next_species_key=jnp.float32(1), # 0 is reserved for the first species + generation=jnp.float32(0), ) def ask(self, state): return state.pop_nodes, state.pop_conns + def tell(self, state, fitness): + k1, k2, randkey = jax.random.split(state.randkey, 3) + + state = state.update(generation=state.generation + 1, randkey=randkey) + state, winner, loser, elite_mask = self.update_species(state, fitness) + state = self.create_next_generation(state, winner, loser, elite_mask) + state = self.speciate(state) + + return state + def update_species(self, state, fitness): # update the fitness of each species state, species_fitness = self.update_species_fitness(state, fitness) @@ -619,3 +630,43 @@ class DefaultSpecies(BaseSpecies): val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize return val + + def create_next_generation(self, state, winner, loser, elite_mask): + + # find next node key + all_nodes_keys = state.pop_nodes[:, :, 0] + max_node_key = jnp.max(all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0) + next_node_key = max_node_key + 1 + new_node_keys = jnp.arange(self.pop_size) + next_node_key + + # prepare random keys + k1, k2, randkey = jax.random.split(state.randkey, 3) + crossover_randkeys = jax.random.split(k1, self.pop_size) + mutate_randkeys = jax.random.split(k2, self.pop_size) + + wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner] + lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser] + + # batch crossover + n_nodes, n_conns = jax.vmap( + self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0) + )( + state, crossover_randkeys, wpn, wpc, lpn, lpc + ) # new_nodes, new_conns + + # batch mutation + m_n_nodes, m_n_conns = jax.vmap( + self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0) + )( + state, mutate_randkeys, n_nodes, n_conns, new_node_keys + ) # mutated_new_nodes, mutated_new_conns + + # elitism don't mutate + pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes) + pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns) + + return state.update( + randkey=randkey, + pop_nodes=pop_nodes, + pop_conns=pop_conns, + ) \ No newline at end of file diff --git a/tensorneat/examples/func_fit/xor.py b/tensorneat/examples/func_fit/xor.py index e6038cd..e93326f 100644 --- a/tensorneat/examples/func_fit/xor.py +++ b/tensorneat/examples/func_fit/xor.py @@ -8,7 +8,7 @@ if __name__ == "__main__": pipeline = Pipeline( algorithm=NEAT( species=DefaultSpecies( - genome=DefaultGenome( + genome=DenseInitialize( num_inputs=3, num_outputs=1, max_nodes=50, @@ -21,7 +21,7 @@ if __name__ == "__main__": # aggregation_options=(Agg.sum,), aggregation_options=AGG_ALL, ), - output_transform=Act.sigmoid, # the activation function for output node + output_transform=Act.standard_sigmoid, # the activation function for output node mutation=DefaultMutation( node_add=0.1, conn_add=0.1, @@ -29,7 +29,7 @@ if __name__ == "__main__": conn_delete=0, ), ), - pop_size=100000, + pop_size=10000, species_size=20, compatibility_threshold=2, survival_threshold=0.01, # magic @@ -37,7 +37,7 @@ if __name__ == "__main__": ), problem=XOR3d(), generation_limit=10000, - fitness_target=-1e-8, + fitness_target=-1e-3, ) # initialize state diff --git a/tensorneat/examples/gymnax/cartpole.py b/tensorneat/examples/gymnax/cartpole.py index 1368ee8..a479a2a 100644 --- a/tensorneat/examples/gymnax/cartpole.py +++ b/tensorneat/examples/gymnax/cartpole.py @@ -6,7 +6,7 @@ from algorithm.neat import * from problem.rl_env import GymNaxEnv -def action_policy(forward_func, obs): +def action_policy(randkey, forward_func, obs): return jnp.argmax(forward_func(obs)) @@ -27,7 +27,9 @@ if __name__ == "__main__": species_size=10, ), ), - problem=GymNaxEnv(env_name="CartPole-v1", repeat_times=5, action_policy=action_policy), + problem=GymNaxEnv( + env_name="CartPole-v1", repeat_times=5, action_policy=action_policy + ), generation_limit=10000, fitness_target=500, ) diff --git a/tensorneat/examples/interpret_visualize/genome_sympy.py b/tensorneat/examples/interpret_visualize/genome_sympy.py index d050221..0b8363c 100644 --- a/tensorneat/examples/interpret_visualize/genome_sympy.py +++ b/tensorneat/examples/interpret_visualize/genome_sympy.py @@ -1,14 +1,14 @@ import jax, jax.numpy as jnp from algorithm.neat import * -from algorithm.neat.genome.hidden import AdvanceInitialize +from algorithm.neat.genome.dense import DenseInitialize from utils.graph import topological_sort_python +from utils import * -if __name__ == '__main__': - genome = AdvanceInitialize( - num_inputs=17, - num_outputs=6, - hidden_cnt=8, +if __name__ == "__main__": + genome = DenseInitialize( + num_inputs=3, + num_outputs=1, max_nodes=50, max_conns=500, ) @@ -19,16 +19,19 @@ if __name__ == '__main__': nodes, conns = genome.initialize(state, randkey) network = genome.network_dict(state, nodes, conns) - print(set(network["nodes"]), set(network["conns"])) - order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"])) - print(order) input_idx, output_idx = genome.get_input_idx(), genome.get_output_idx() - print(input_idx, output_idx) - print(genome.repr(state, nodes, conns)) - print(network) + res = genome.sympy_func(state, network, sympy_input_transform=lambda x: 999999999*x, sympy_output_transform=SympyStandardSigmoid) + (symbols, + args_symbols, + input_symbols, + nodes_exprs, + output_exprs, + forward_func,) = res - res = genome.sympy_func(state, network, precision=3) - print(res) + print(symbols) + print(output_exprs[0].subs(args_symbols)) + inputs = jnp.zeros(3) + print(forward_func(inputs)) diff --git a/tensorneat/pipeline.py b/tensorneat/pipeline.py index 2299df2..53d2b6b 100644 --- a/tensorneat/pipeline.py +++ b/tensorneat/pipeline.py @@ -71,6 +71,9 @@ class Pipeline(StatefulBaseClass): print(f"save to {self.save_dir}") if not os.path.exists(self.save_dir): os.makedirs(self.save_dir) + self.genome_dir = os.path.join(self.save_dir, "genomes") + if not os.path.exists(self.genome_dir): + os.makedirs(self.genome_dir) def setup(self, state=State()): print("initializing") @@ -165,6 +168,7 @@ class Pipeline(StatefulBaseClass): print("start compile") tic = time.time() compiled_step = jax.jit(self.step).lower(state).compile() + # compiled_step = self.step print( f"compile finished, cost time: {time.time() - tic:.6f}s", ) @@ -181,9 +185,21 @@ class Pipeline(StatefulBaseClass): if max(fitnesses) >= self.fitness_target: print("Fitness limit reached!") - return state, self.best_genome + break + + if self.algorithm.generation(state) >= self.generation_limit: + print("Generation limit reached!") + + if self.is_save: + best_genome = jax.device_get(self.best_genome) + with open(os.path.join(self.genome_dir, f"best_genome.npz"), "wb") as f: + np.savez( + f, + nodes=best_genome[0], + conns=best_genome[1], + fitness=self.best_fitness, + ) - print("Generation limit reached!") return state, self.best_genome def analysis(self, state, pop, fitnesses): @@ -206,15 +222,15 @@ class Pipeline(StatefulBaseClass): self.best_fitness = fitnesses[max_idx] self.best_genome = pop[0][max_idx], pop[1][max_idx] - if self.is_save: - best_genome = jax.device_get(self.best_genome) - with open(os.path.join(self.save_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f: - np.savez( - f, - nodes=best_genome[0], - conns=best_genome[1], - fitness=self.best_fitness, - ) + if self.is_save: + best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx])) + with open(os.path.join(self.genome_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f: + np.savez( + f, + nodes=best_genome[0], + conns=best_genome[1], + fitness=self.best_fitness, + ) # save best if save path is not None diff --git a/tensorneat/utils/__init__.py b/tensorneat/utils/__init__.py index fd96b5a..f61b9bd 100644 --- a/tensorneat/utils/__init__.py +++ b/tensorneat/utils/__init__.py @@ -1,5 +1,3 @@ -import jax.numpy as jnp - from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL from .tools import * from .graph import * @@ -15,7 +13,9 @@ from typing import Union name2sympy = { "sigmoid": SympySigmoid, + "standard_sigmoid": SympyStandardSigmoid, "tanh": SympyTanh, + "standard_tanh": SympyStandardTanh, "sin": SympySin, "relu": SympyRelu, "lelu": SympyLelu, diff --git a/tensorneat/utils/activation/act_jnp.py b/tensorneat/utils/activation/act_jnp.py index 218ed3c..2b042d4 100644 --- a/tensorneat/utils/activation/act_jnp.py +++ b/tensorneat/utils/activation/act_jnp.py @@ -12,19 +12,26 @@ class Act: @staticmethod def sigmoid(z): - z = jnp.clip(5 * z / sigma_3, -5, 5) + z = 5 * z / sigma_3 z = 1 / (1 + jnp.exp(-z)) return z * sigma_3 # (0, sigma_3) + @staticmethod + def standard_sigmoid(z): + z = 5 * z / sigma_3 + z = 1 / (1 + jnp.exp(-z)) + + return z # (0, 1) + @staticmethod def tanh(z): - z = jnp.clip(5 * z / sigma_3, -5, 5) + z = 5 * z / sigma_3 return jnp.tanh(z) * sigma_3 # (-sigma_3, sigma_3) @staticmethod def standard_tanh(z): - z = jnp.clip(5 * z / sigma_3, -5, 5) + z =5 * z / sigma_3 return jnp.tanh(z) # (-1, 1) @staticmethod diff --git a/tensorneat/utils/activation/act_sympy.py b/tensorneat/utils/activation/act_sympy.py index ae5f3ae..fc971ca 100644 --- a/tensorneat/utils/activation/act_sympy.py +++ b/tensorneat/utils/activation/act_sympy.py @@ -25,21 +25,16 @@ class SympyClip(sp.Function): return rf"\mathrm{{clip}}\left({sp.latex(self.args[0])}, {self.args[1]}, {self.args[2]}\right)" -class SympySigmoid(sp.Function): +class SympySigmoid_(sp.Function): @classmethod def eval(cls, z): - if z.is_Number: - z = SympyClip(5 * z / sigma_3, -5, 5) - z = 1 / (1 + sp.exp(-z)) - return z * sigma_3 - return None + z = 1 / (1 + sp.exp(-z)) + return z @staticmethod def numerical_eval(z, backend=np): - z = backend.clip(5 * z / sigma_3, -5, 5) z = 1 / (1 + backend.exp(-z)) - - return z * sigma_3 # (0, sigma_3) + return z def _sympystr(self, printer): return f"sigmoid({self.args[0]})" @@ -48,32 +43,47 @@ class SympySigmoid(sp.Function): return rf"\mathrm{{sigmoid}}\left({sp.latex(self.args[0])}\right)" +class SympySigmoid(sp.Function): + @classmethod + def eval(cls, z): + return SympySigmoid_(5 * z / sigma_3) * sigma_3 + + +class SympyStandardSigmoid(sp.Function): + @classmethod + def eval(cls, z): + return SympySigmoid_(5 * z / sigma_3) + + # @staticmethod + # def numerical_eval(z, backend=np): + # z = backend.clip(5 * z / sigma_3, -5, 5) + # z = 1 / (1 + backend.exp(-z)) + # + # return z # (0, 1) + + class SympyTanh(sp.Function): @classmethod def eval(cls, z): - if z.is_Number: - z = SympyClip(5 * z / sigma_3, -5, 5) - return sp.tanh(z) * sigma_3 - return None + z = 5 * z / sigma_3 + return sp.tanh(z) * sigma_3 - @staticmethod - def numerical_eval(z, backend=np): - z = backend.clip(5 * z / sigma_3, -5, 5) - return backend.tanh(z) * sigma_3 # (-sigma_3, sigma_3) + # @staticmethod + # def numerical_eval(z, backend=np): + # z = backend.clip(5 * z / sigma_3, -5, 5) + # return backend.tanh(z) * sigma_3 # (-sigma_3, sigma_3) class SympyStandardTanh(sp.Function): @classmethod def eval(cls, z): - if z.is_Number: - z = SympyClip(5 * z / sigma_3, -5, 5) - return sp.tanh(z) - return None + z = 5 * z / sigma_3 + return sp.tanh(z) - @staticmethod - def numerical_eval(z, backend=np): - z = backend.clip(5 * z / sigma_3, -5, 5) - return backend.tanh(z) # (-1, 1) + # @staticmethod + # def numerical_eval(z, backend=np): + # z = backend.clip(5 * z / sigma_3, -5, 5) + # return backend.tanh(z) # (-1, 1) class SympySin(sp.Function): diff --git a/tensorneat/utils/aggregation/agg_jnp.py b/tensorneat/utils/aggregation/agg_jnp.py index 3359c6b..1be1ef0 100644 --- a/tensorneat/utils/aggregation/agg_jnp.py +++ b/tensorneat/utils/aggregation/agg_jnp.py @@ -9,19 +9,19 @@ class Agg: @staticmethod def sum(z): - return jnp.sum(z, axis=0, where=~jnp.isnan(z)) + return jnp.sum(z, axis=0, where=~jnp.isnan(z), initial=0) @staticmethod def product(z): - return jnp.prod(z, axis=0, where=~jnp.isnan(z)) + return jnp.prod(z, axis=0, where=~jnp.isnan(z), initial=1) @staticmethod def max(z): - return jnp.max(z, axis=0, where=~jnp.isnan(z)) + return jnp.max(z, axis=0, where=~jnp.isnan(z), initial=-jnp.inf) @staticmethod def min(z): - return jnp.min(z, axis=0, where=~jnp.isnan(z)) + return jnp.min(z, axis=0, where=~jnp.isnan(z), initial=jnp.inf) @staticmethod def maxabs(z): diff --git a/tensorneat/utils/state.py b/tensorneat/utils/state.py index bac0a11..36bd165 100644 --- a/tensorneat/utils/state.py +++ b/tensorneat/utils/state.py @@ -36,6 +36,9 @@ class State: def __setstate__(self, state): self.__dict__["state_dict"] = state + def __contains__(self, item): + return item in self.state_dict + def tree_flatten(self): children = list(self.state_dict.values()) aux_data = list(self.state_dict.keys()) diff --git a/tensorneat/utils/stateful_class.py b/tensorneat/utils/stateful_class.py index 41ce2cc..7646493 100644 --- a/tensorneat/utils/stateful_class.py +++ b/tensorneat/utils/stateful_class.py @@ -19,6 +19,21 @@ class StatefulBaseClass: with open(path, "wb") as f: pickle.dump(self, f) + def __getstate__(self): + # only pickle the picklable attributes + state = self.__dict__.copy() + non_picklable_keys = [] + for key, value in state.items(): + try: + pickle.dumps(value) + except Exception: + non_picklable_keys.append(key) + + for key in non_picklable_keys: + state.pop(key) + + return state + def show_config(self): config = {} for key, value in self.__dict__.items(): diff --git a/tensorneat/utils/tools.py b/tensorneat/utils/tools.py index 9eceb51..d1f3f22 100644 --- a/tensorneat/utils/tools.py +++ b/tensorneat/utils/tools.py @@ -36,6 +36,7 @@ def unflatten_conns(nodes, conns): return unflatten +# TODO: strange implementation def attach_with_inf(arr, idx): expand_size = arr.ndim - idx.ndim expand_idx = jnp.expand_dims( @@ -199,3 +200,14 @@ def delete_conn_by_pos(conns, pos): Delete the connection by its idx. """ return conns.at[pos].set(jnp.nan) + + +def hash_array(arr: Array): + arr = jax.lax.bitcast_convert_type(arr, jnp.uint32) + + def update(i, hash_val): + return hash_val ^ ( + arr[i] + jnp.uint32(0x9E3779B9) + (hash_val << 6) + (hash_val >> 2) + ) + + return jax.lax.fori_loop(0, arr.size, update, jnp.uint32(0))