This commit is contained in:
wls2002
2024-06-20 16:32:52 +08:00
parent 9f72813c35
commit 075460f896
17 changed files with 224 additions and 140 deletions

View File

@@ -1,5 +1,5 @@
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import State, StatefulBaseClass from utils import State, StatefulBaseClass, hash_array
class BaseGene(StatefulBaseClass): class BaseGene(StatefulBaseClass):
@@ -43,3 +43,6 @@ class BaseGene(StatefulBaseClass):
def repr(self, state, gene, precision=2): def repr(self, state, gene, precision=2):
raise NotImplementedError raise NotImplementedError
def hash(self, gene):
    """Return the hash of a single gene array.

    The value is whatever ``hash_array`` produces for ``gene``; this method
    adds no processing of its own.
    """
    gene_hash = hash_array(gene)
    return gene_hash

View File

@@ -2,7 +2,7 @@ import numpy as np
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene from ..gene import BaseNodeGene, BaseConnGene
from ..ga import BaseMutation, BaseCrossover from ..ga import BaseMutation, BaseCrossover
from utils import State, StatefulBaseClass, topological_sort_python from utils import State, StatefulBaseClass, topological_sort_python, hash_array
class BaseGenome(StatefulBaseClass): class BaseGenome(StatefulBaseClass):
@@ -255,10 +255,14 @@ class BaseGenome(StatefulBaseClass):
nx.draw( nx.draw(
G, G,
with_labels=True,
pos=rotated_pos, pos=rotated_pos,
node_size=node_sizes, node_size=node_sizes,
node_color=node_colors, node_color=node_colors,
**kwargs, **kwargs,
) )
plt.savefig(save_path, dpi=save_dpi) plt.savefig(save_path, dpi=save_dpi)
def hash(self, nodes, conns):
    """Return one hash value summarizing a genome's nodes and connections.

    Each row of ``nodes`` and each row of ``conns`` is hashed independently
    (``hash_array`` mapped over the leading axis), the row hashes are
    concatenated with the node hashes first, and the concatenation is hashed
    once more.
    """
    hash_rows = jax.vmap(hash_array)
    row_hashes = jnp.concatenate([hash_rows(nodes), hash_rows(conns)])
    return hash_array(row_hashes)

View File

@@ -210,7 +210,14 @@ class DefaultGenome(BaseGenome):
new_transformed, new_transformed,
) )
def sympy_func(self, state, network, sympy_input_transform=None, sympy_output_transform=None, backend="jax"): def sympy_func(
self,
state,
network,
sympy_input_transform=None,
sympy_output_transform=None,
backend="jax",
):
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'" assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
@@ -219,6 +226,10 @@ class DefaultGenome(BaseGenome):
warnings.warn( warnings.warn(
"genome.input_transform is not None but sympy_input_transform is None!" "genome.input_transform is not None but sympy_input_transform is None!"
) )
if sympy_input_transform is None:
sympy_input_transform = lambda x: x
if sympy_input_transform is not None: if sympy_input_transform is not None:
if not isinstance(sympy_input_transform, list): if not isinstance(sympy_input_transform, list):
sympy_input_transform = [sympy_input_transform] * self.num_inputs sympy_input_transform = [sympy_input_transform] * self.num_inputs
@@ -231,11 +242,14 @@ class DefaultGenome(BaseGenome):
input_idx = self.get_input_idx() input_idx = self.get_input_idx()
output_idx = self.get_output_idx() output_idx = self.get_output_idx()
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"])) order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
hidden_idx = [i for i in network["nodes"] if i not in input_idx and i not in output_idx] hidden_idx = [
i for i in network["nodes"] if i not in input_idx and i not in output_idx
]
symbols = {} symbols = {}
for i in network["nodes"]: for i in network["nodes"]:
if i in input_idx: if i in input_idx:
symbols[i] = sp.Symbol(f"i{i - min(input_idx)}") symbols[-i - 1] = sp.Symbol(f"i{i - min(input_idx)}") # origin_i
symbols[i] = sp.Symbol(f"norm{i - min(input_idx)}")
elif i in output_idx: elif i in output_idx:
symbols[i] = sp.Symbol(f"o{i - min(output_idx)}") symbols[i] = sp.Symbol(f"o{i - min(output_idx)}")
else: # hidden else: # hidden
@@ -246,10 +260,9 @@ class DefaultGenome(BaseGenome):
for i in order: for i in order:
if i in input_idx: if i in input_idx:
if sympy_input_transform is not None: nodes_exprs[symbols[-i - 1]] = symbols[-i - 1] # origin equal to its symbol
nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[i]) nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[-i - 1]) # normed i
else:
nodes_exprs[symbols[i]] = symbols[i]
else: else:
in_conns = [c for c in network["conns"] if c[1] == i] in_conns = [c for c in network["conns"] if c[1] == i]
node_inputs = [] node_inputs = []
@@ -270,12 +283,13 @@ class DefaultGenome(BaseGenome):
is_output_node=(i in output_idx), is_output_node=(i in output_idx),
) )
args_symbols.update(a_s) args_symbols.update(a_s)
if i in output_idx and sympy_output_transform is not None: if i in output_idx and sympy_output_transform is not None:
nodes_exprs[symbols[i]] = sympy_output_transform( nodes_exprs[symbols[i]] = sympy_output_transform(
nodes_exprs[symbols[i]] nodes_exprs[symbols[i]]
) )
input_symbols = [v for k, v in symbols.items() if k in input_idx] input_symbols = [symbols[-i - 1] for i in input_idx]
reduced_exprs = nodes_exprs.copy() reduced_exprs = nodes_exprs.copy()
for i in order: for i in order:
reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs) reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs)
@@ -299,7 +313,9 @@ class DefaultGenome(BaseGenome):
fixed_args_output_funcs.append(f) fixed_args_output_funcs.append(f)
forward_func = lambda inputs: jnp.array([f(inputs) for f in fixed_args_output_funcs]) forward_func = lambda inputs: jnp.array(
[f(inputs) for f in fixed_args_output_funcs]
)
return ( return (
symbols, symbols,

View File

@@ -2,8 +2,6 @@ import jax, jax.numpy as jnp
from utils import State from utils import State
from .. import BaseAlgorithm from .. import BaseAlgorithm
from .species import * from .species import *
from .ga import *
from .genome import *
class NEAT(BaseAlgorithm): class NEAT(BaseAlgorithm):
@@ -16,28 +14,13 @@ class NEAT(BaseAlgorithm):
def setup(self, state=State()): def setup(self, state=State()):
state = self.species.setup(state) state = self.species.setup(state)
state = state.register(
generation=jnp.array(0.0),
next_node_key=jnp.array(
max(*self.genome.input_idx, *self.genome.output_idx) + 2,
dtype=jnp.float32,
),
)
return state return state
def ask(self, state: State): def ask(self, state: State):
return self.species.ask(state) return self.species.ask(state)
def tell(self, state: State, fitness): def tell(self, state: State, fitness):
k1, k2, randkey = jax.random.split(state.randkey, 3) return self.species.tell(state, fitness)
state = state.update(generation=state.generation + 1, randkey=randkey)
state, winner, loser, elite_mask = self.species.update_species(state, fitness)
state = self.create_next_generation(state, winner, loser, elite_mask)
state = self.species.speciate(state)
return state
def transform(self, state, individual): def transform(self, state, individual):
"""transform the genome into a neural network""" """transform the genome into a neural network"""
@@ -65,50 +48,6 @@ class NEAT(BaseAlgorithm):
def pop_size(self): def pop_size(self):
return self.species.pop_size return self.species.pop_size
def create_next_generation(self, state, winner, loser, elite_mask):
# prepare random keys
pop_size = self.species.pop_size
new_node_keys = jnp.arange(pop_size) + state.next_node_key
k1, k2, randkey = jax.random.split(state.randkey, 3)
crossover_randkeys = jax.random.split(k1, pop_size)
mutate_randkeys = jax.random.split(k2, pop_size)
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
# batch crossover
n_nodes, n_conns = jax.vmap(
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
)(
state, crossover_randkeys, wpn, wpc, lpn, lpc
) # new_nodes, new_conns
# batch mutation
m_n_nodes, m_n_conns = jax.vmap(
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
)(
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
) # mutated_new_nodes, mutated_new_conns
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
# update next node key
all_nodes_keys = pop_nodes[:, :, 0]
max_node_key = jnp.max(
jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys)
)
next_node_key = max_node_key + 1
return state.update(
randkey=randkey,
pop_nodes=pop_nodes,
pop_conns=pop_conns,
next_node_key=next_node_key,
)
def member_count(self, state: State): def member_count(self, state: State):
return state.member_count return state.member_count

View File

@@ -10,6 +10,9 @@ class BaseSpecies(StatefulBaseClass):
def ask(self, state: State): def ask(self, state: State):
raise NotImplementedError raise NotImplementedError
def tell(self, state: State, fitness):
    """Consume the fitness values for the current population.

    Abstract hook: the base class provides no implementation, so subclasses
    must override it.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
def update_species(self, state, fitness): def update_species(self, state, fitness):
raise NotImplementedError raise NotImplementedError

View File

@@ -113,12 +113,23 @@ class DefaultSpecies(BaseSpecies):
idx2species=idx2species, idx2species=idx2species,
center_nodes=center_nodes, center_nodes=center_nodes,
center_conns=center_conns, center_conns=center_conns,
next_species_key=jnp.array(1), # 0 is reserved for the first species next_species_key=jnp.float32(1), # 0 is reserved for the first species
generation=jnp.float32(0),
) )
def ask(self, state): def ask(self, state):
return state.pop_nodes, state.pop_conns return state.pop_nodes, state.pop_conns
def tell(self, state, fitness):
    """Advance the population by one generation and return the new state.

    Steps, in order: bump the generation counter and refresh the random key,
    update the species from ``fitness``, build the next generation from the
    selected winner/loser pairs, then re-speciate the new population.
    """
    # The key is split three ways and only the last part is stored back;
    # the split count is kept so the random stream stays the same.
    _, _, fresh_key = jax.random.split(state.randkey, 3)
    state = state.update(generation=state.generation + 1, randkey=fresh_key)

    state, winner, loser, elite_mask = self.update_species(state, fitness)
    state = self.create_next_generation(state, winner, loser, elite_mask)
    return self.speciate(state)
def update_species(self, state, fitness): def update_species(self, state, fitness):
# update the fitness of each species # update the fitness of each species
state, species_fitness = self.update_species_fitness(state, fitness) state, species_fitness = self.update_species_fitness(state, fitness)
@@ -619,3 +630,43 @@ class DefaultSpecies(BaseSpecies):
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
return val return val
def create_next_generation(self, state, winner, loser, elite_mask):
    """Build the next population by batched crossover followed by mutation.

    Args:
        state: current state; ``pop_nodes``, ``pop_conns`` and ``randkey``
            are read from it.
        winner: indices into the population selecting the first parent of
            each child.
        loser: indices into the population selecting the second parent of
            each child.
        elite_mask: per-individual boolean mask; where it is true the
            crossover result is kept without mutation.

    Returns:
        The state updated with the new ``pop_nodes``, ``pop_conns`` and
        ``randkey``.
    """
    # Each individual gets its own fresh node key, starting right after the
    # largest non-NaN key currently present in the population.
    existing_keys = state.pop_nodes[:, :, 0]
    largest_key = jnp.max(
        existing_keys, where=~jnp.isnan(existing_keys), initial=0
    )
    first_free_key = largest_key + 1
    new_node_keys = jnp.arange(self.pop_size) + first_free_key

    # One random key per individual for crossover and for mutation.
    crossover_key, mutation_key, next_randkey = jax.random.split(state.randkey, 3)
    crossover_randkeys = jax.random.split(crossover_key, self.pop_size)
    mutate_randkeys = jax.random.split(mutation_key, self.pop_size)

    # Gather both parents of every child.
    winner_nodes = state.pop_nodes[winner]
    winner_conns = state.pop_conns[winner]
    loser_nodes = state.pop_nodes[loser]
    loser_conns = state.pop_conns[loser]

    # Batched crossover: state is shared, everything else is per individual.
    batched_crossover = jax.vmap(
        self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
    )
    child_nodes, child_conns = batched_crossover(
        state,
        crossover_randkeys,
        winner_nodes,
        winner_conns,
        loser_nodes,
        loser_conns,
    )

    # Batched mutation of the crossover results.
    batched_mutation = jax.vmap(
        self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
    )
    mutated_nodes, mutated_conns = batched_mutation(
        state, mutate_randkeys, child_nodes, child_conns, new_node_keys
    )

    # Elites keep their unmutated crossover result.
    keep_unmutated = elite_mask[:, None, None]
    return state.update(
        randkey=next_randkey,
        pop_nodes=jnp.where(keep_unmutated, child_nodes, mutated_nodes),
        pop_conns=jnp.where(keep_unmutated, child_conns, mutated_conns),
    )

View File

@@ -8,7 +8,7 @@ if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
algorithm=NEAT( algorithm=NEAT(
species=DefaultSpecies( species=DefaultSpecies(
genome=DefaultGenome( genome=DenseInitialize(
num_inputs=3, num_inputs=3,
num_outputs=1, num_outputs=1,
max_nodes=50, max_nodes=50,
@@ -21,7 +21,7 @@ if __name__ == "__main__":
# aggregation_options=(Agg.sum,), # aggregation_options=(Agg.sum,),
aggregation_options=AGG_ALL, aggregation_options=AGG_ALL,
), ),
output_transform=Act.sigmoid, # the activation function for output node output_transform=Act.standard_sigmoid, # the activation function for output node
mutation=DefaultMutation( mutation=DefaultMutation(
node_add=0.1, node_add=0.1,
conn_add=0.1, conn_add=0.1,
@@ -29,7 +29,7 @@ if __name__ == "__main__":
conn_delete=0, conn_delete=0,
), ),
), ),
pop_size=100000, pop_size=10000,
species_size=20, species_size=20,
compatibility_threshold=2, compatibility_threshold=2,
survival_threshold=0.01, # magic survival_threshold=0.01, # magic
@@ -37,7 +37,7 @@ if __name__ == "__main__":
), ),
problem=XOR3d(), problem=XOR3d(),
generation_limit=10000, generation_limit=10000,
fitness_target=-1e-8, fitness_target=-1e-3,
) )
# initialize state # initialize state

View File

@@ -6,7 +6,7 @@ from algorithm.neat import *
from problem.rl_env import GymNaxEnv from problem.rl_env import GymNaxEnv
def action_policy(forward_func, obs): def action_policy(randkey, forward_func, obs):
return jnp.argmax(forward_func(obs)) return jnp.argmax(forward_func(obs))
@@ -27,7 +27,9 @@ if __name__ == "__main__":
species_size=10, species_size=10,
), ),
), ),
problem=GymNaxEnv(env_name="CartPole-v1", repeat_times=5, action_policy=action_policy), problem=GymNaxEnv(
env_name="CartPole-v1", repeat_times=5, action_policy=action_policy
),
generation_limit=10000, generation_limit=10000,
fitness_target=500, fitness_target=500,
) )

View File

@@ -1,14 +1,14 @@
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from algorithm.neat import * from algorithm.neat import *
from algorithm.neat.genome.hidden import AdvanceInitialize from algorithm.neat.genome.dense import DenseInitialize
from utils.graph import topological_sort_python from utils.graph import topological_sort_python
from utils import *
if __name__ == '__main__': if __name__ == "__main__":
genome = AdvanceInitialize( genome = DenseInitialize(
num_inputs=17, num_inputs=3,
num_outputs=6, num_outputs=1,
hidden_cnt=8,
max_nodes=50, max_nodes=50,
max_conns=500, max_conns=500,
) )
@@ -19,16 +19,19 @@ if __name__ == '__main__':
nodes, conns = genome.initialize(state, randkey) nodes, conns = genome.initialize(state, randkey)
network = genome.network_dict(state, nodes, conns) network = genome.network_dict(state, nodes, conns)
print(set(network["nodes"]), set(network["conns"]))
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
print(order)
input_idx, output_idx = genome.get_input_idx(), genome.get_output_idx() input_idx, output_idx = genome.get_input_idx(), genome.get_output_idx()
print(input_idx, output_idx)
print(genome.repr(state, nodes, conns)) res = genome.sympy_func(state, network, sympy_input_transform=lambda x: 999999999*x, sympy_output_transform=SympyStandardSigmoid)
print(network) (symbols,
args_symbols,
input_symbols,
nodes_exprs,
output_exprs,
forward_func,) = res
res = genome.sympy_func(state, network, precision=3) print(symbols)
print(res) print(output_exprs[0].subs(args_symbols))
inputs = jnp.zeros(3)
print(forward_func(inputs))

View File

@@ -71,6 +71,9 @@ class Pipeline(StatefulBaseClass):
print(f"save to {self.save_dir}") print(f"save to {self.save_dir}")
if not os.path.exists(self.save_dir): if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir) os.makedirs(self.save_dir)
self.genome_dir = os.path.join(self.save_dir, "genomes")
if not os.path.exists(self.genome_dir):
os.makedirs(self.genome_dir)
def setup(self, state=State()): def setup(self, state=State()):
print("initializing") print("initializing")
@@ -165,6 +168,7 @@ class Pipeline(StatefulBaseClass):
print("start compile") print("start compile")
tic = time.time() tic = time.time()
compiled_step = jax.jit(self.step).lower(state).compile() compiled_step = jax.jit(self.step).lower(state).compile()
# compiled_step = self.step
print( print(
f"compile finished, cost time: {time.time() - tic:.6f}s", f"compile finished, cost time: {time.time() - tic:.6f}s",
) )
@@ -181,9 +185,21 @@ class Pipeline(StatefulBaseClass):
if max(fitnesses) >= self.fitness_target: if max(fitnesses) >= self.fitness_target:
print("Fitness limit reached!") print("Fitness limit reached!")
return state, self.best_genome break
if self.algorithm.generation(state) >= self.generation_limit:
print("Generation limit reached!")
if self.is_save:
best_genome = jax.device_get(self.best_genome)
with open(os.path.join(self.genome_dir, f"best_genome.npz"), "wb") as f:
np.savez(
f,
nodes=best_genome[0],
conns=best_genome[1],
fitness=self.best_fitness,
)
print("Generation limit reached!")
return state, self.best_genome return state, self.best_genome
def analysis(self, state, pop, fitnesses): def analysis(self, state, pop, fitnesses):
@@ -206,15 +222,15 @@ class Pipeline(StatefulBaseClass):
self.best_fitness = fitnesses[max_idx] self.best_fitness = fitnesses[max_idx]
self.best_genome = pop[0][max_idx], pop[1][max_idx] self.best_genome = pop[0][max_idx], pop[1][max_idx]
if self.is_save: if self.is_save:
best_genome = jax.device_get(self.best_genome) best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx]))
with open(os.path.join(self.save_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f: with open(os.path.join(self.genome_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f:
np.savez( np.savez(
f, f,
nodes=best_genome[0], nodes=best_genome[0],
conns=best_genome[1], conns=best_genome[1],
fitness=self.best_fitness, fitness=self.best_fitness,
) )
# save best if save path is not None # save best if save path is not None

View File

@@ -1,5 +1,3 @@
import jax.numpy as jnp
from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
from .tools import * from .tools import *
from .graph import * from .graph import *
@@ -15,7 +13,9 @@ from typing import Union
name2sympy = { name2sympy = {
"sigmoid": SympySigmoid, "sigmoid": SympySigmoid,
"standard_sigmoid": SympyStandardSigmoid,
"tanh": SympyTanh, "tanh": SympyTanh,
"standard_tanh": SympyStandardTanh,
"sin": SympySin, "sin": SympySin,
"relu": SympyRelu, "relu": SympyRelu,
"lelu": SympyLelu, "lelu": SympyLelu,

View File

@@ -12,19 +12,26 @@ class Act:
@staticmethod @staticmethod
def sigmoid(z): def sigmoid(z):
z = jnp.clip(5 * z / sigma_3, -5, 5) z = 5 * z / sigma_3
z = 1 / (1 + jnp.exp(-z)) z = 1 / (1 + jnp.exp(-z))
return z * sigma_3 # (0, sigma_3) return z * sigma_3 # (0, sigma_3)
@staticmethod
def standard_sigmoid(z):
    """Logistic sigmoid of ``5 * z / sigma_3``; the result lies in (0, 1)."""
    scaled = 5 * z / sigma_3
    return 1 / (1 + jnp.exp(-scaled))
@staticmethod @staticmethod
def tanh(z): def tanh(z):
z = jnp.clip(5 * z / sigma_3, -5, 5) z = 5 * z / sigma_3
return jnp.tanh(z) * sigma_3 # (-sigma_3, sigma_3) return jnp.tanh(z) * sigma_3 # (-sigma_3, sigma_3)
@staticmethod @staticmethod
def standard_tanh(z): def standard_tanh(z):
z = jnp.clip(5 * z / sigma_3, -5, 5) z =5 * z / sigma_3
return jnp.tanh(z) # (-1, 1) return jnp.tanh(z) # (-1, 1)
@staticmethod @staticmethod

View File

@@ -25,21 +25,16 @@ class SympyClip(sp.Function):
return rf"\mathrm{{clip}}\left({sp.latex(self.args[0])}, {self.args[1]}, {self.args[2]}\right)" return rf"\mathrm{{clip}}\left({sp.latex(self.args[0])}, {self.args[1]}, {self.args[2]}\right)"
class SympySigmoid(sp.Function): class SympySigmoid_(sp.Function):
@classmethod @classmethod
def eval(cls, z): def eval(cls, z):
if z.is_Number: z = 1 / (1 + sp.exp(-z))
z = SympyClip(5 * z / sigma_3, -5, 5) return z
z = 1 / (1 + sp.exp(-z))
return z * sigma_3
return None
@staticmethod @staticmethod
def numerical_eval(z, backend=np): def numerical_eval(z, backend=np):
z = backend.clip(5 * z / sigma_3, -5, 5)
z = 1 / (1 + backend.exp(-z)) z = 1 / (1 + backend.exp(-z))
return z
return z * sigma_3 # (0, sigma_3)
def _sympystr(self, printer): def _sympystr(self, printer):
return f"sigmoid({self.args[0]})" return f"sigmoid({self.args[0]})"
@@ -48,32 +43,47 @@ class SympySigmoid(sp.Function):
return rf"\mathrm{{sigmoid}}\left({sp.latex(self.args[0])}\right)" return rf"\mathrm{{sigmoid}}\left({sp.latex(self.args[0])}\right)"
class SympySigmoid(sp.Function):
    """Symbolic sigmoid with input scaled by ``5 / sigma_3`` and output scaled by ``sigma_3``."""

    @classmethod
    def eval(cls, z):
        # Built from the plain SympySigmoid_ so that printing and numerical
        # evaluation are handled in one place.
        scaled_input = 5 * z / sigma_3
        return SympySigmoid_(scaled_input) * sigma_3
class SympyStandardSigmoid(sp.Function):
    """Symbolic sigmoid with input scaled by ``5 / sigma_3`` and no output scaling."""

    @classmethod
    def eval(cls, z):
        # Delegates to the plain SympySigmoid_ applied to the scaled input.
        scaled_input = 5 * z / sigma_3
        return SympySigmoid_(scaled_input)
class SympyTanh(sp.Function): class SympyTanh(sp.Function):
@classmethod @classmethod
def eval(cls, z): def eval(cls, z):
if z.is_Number: z = 5 * z / sigma_3
z = SympyClip(5 * z / sigma_3, -5, 5) return sp.tanh(z) * sigma_3
return sp.tanh(z) * sigma_3
return None
@staticmethod # @staticmethod
def numerical_eval(z, backend=np): # def numerical_eval(z, backend=np):
z = backend.clip(5 * z / sigma_3, -5, 5) # z = backend.clip(5 * z / sigma_3, -5, 5)
return backend.tanh(z) * sigma_3 # (-sigma_3, sigma_3) # return backend.tanh(z) * sigma_3 # (-sigma_3, sigma_3)
class SympyStandardTanh(sp.Function): class SympyStandardTanh(sp.Function):
@classmethod @classmethod
def eval(cls, z): def eval(cls, z):
if z.is_Number: z = 5 * z / sigma_3
z = SympyClip(5 * z / sigma_3, -5, 5) return sp.tanh(z)
return sp.tanh(z)
return None
@staticmethod # @staticmethod
def numerical_eval(z, backend=np): # def numerical_eval(z, backend=np):
z = backend.clip(5 * z / sigma_3, -5, 5) # z = backend.clip(5 * z / sigma_3, -5, 5)
return backend.tanh(z) # (-1, 1) # return backend.tanh(z) # (-1, 1)
class SympySin(sp.Function): class SympySin(sp.Function):

View File

@@ -9,19 +9,19 @@ class Agg:
@staticmethod @staticmethod
def sum(z): def sum(z):
return jnp.sum(z, axis=0, where=~jnp.isnan(z)) return jnp.sum(z, axis=0, where=~jnp.isnan(z), initial=0)
@staticmethod @staticmethod
def product(z): def product(z):
return jnp.prod(z, axis=0, where=~jnp.isnan(z)) return jnp.prod(z, axis=0, where=~jnp.isnan(z), initial=1)
@staticmethod @staticmethod
def max(z): def max(z):
return jnp.max(z, axis=0, where=~jnp.isnan(z)) return jnp.max(z, axis=0, where=~jnp.isnan(z), initial=-jnp.inf)
@staticmethod @staticmethod
def min(z): def min(z):
return jnp.min(z, axis=0, where=~jnp.isnan(z)) return jnp.min(z, axis=0, where=~jnp.isnan(z), initial=jnp.inf)
@staticmethod @staticmethod
def maxabs(z): def maxabs(z):

View File

@@ -36,6 +36,9 @@ class State:
def __setstate__(self, state): def __setstate__(self, state):
self.__dict__["state_dict"] = state self.__dict__["state_dict"] = state
def __contains__(self, item):
    """Support ``item in state`` by testing membership in ``state_dict``."""
    registered = self.state_dict
    return item in registered
def tree_flatten(self): def tree_flatten(self):
children = list(self.state_dict.values()) children = list(self.state_dict.values())
aux_data = list(self.state_dict.keys()) aux_data = list(self.state_dict.keys())

View File

@@ -19,6 +19,21 @@ class StatefulBaseClass:
with open(path, "wb") as f: with open(path, "wb") as f:
pickle.dump(self, f) pickle.dump(self, f)
def __getstate__(self):
    """Return the instance attributes that can be pickled.

    Every attribute is trial-pickled on its own; attributes whose pickling
    raises are left out of the returned dict, so pickling the object does
    not fail because of them.
    """
    picklable = {}
    # Iterate over a snapshot so the trial pickling cannot disturb the loop.
    for name, value in dict(self.__dict__).items():
        try:
            pickle.dumps(value)
        except Exception:
            # Deliberate best-effort: any failure just drops this attribute.
            continue
        picklable[name] = value
    return picklable
def show_config(self): def show_config(self):
config = {} config = {}
for key, value in self.__dict__.items(): for key, value in self.__dict__.items():

View File

@@ -36,6 +36,7 @@ def unflatten_conns(nodes, conns):
return unflatten return unflatten
# TODO: strange implementation
def attach_with_inf(arr, idx): def attach_with_inf(arr, idx):
expand_size = arr.ndim - idx.ndim expand_size = arr.ndim - idx.ndim
expand_idx = jnp.expand_dims( expand_idx = jnp.expand_dims(
@@ -199,3 +200,14 @@ def delete_conn_by_pos(conns, pos):
Delete the connection by its idx. Delete the connection by its idx.
""" """
return conns.at[pos].set(jnp.nan) return conns.at[pos].set(jnp.nan)
def hash_array(arr: Array):
    """Fold the raw 32-bit patterns of ``arr`` into a single uint32 hash.

    The elements are reinterpreted (not numerically converted) as uint32
    words and combined one by one, starting from 0, with the
    ``seed ^= word + 0x9E3779B9 + (seed << 6) + (seed >> 2)`` mixing step.
    All arithmetic wraps modulo 2**32.

    The words are flattened in row-major order before hashing, so inputs of
    any rank are accepted; a 1-D input hashes exactly as its elements in
    order.

    Args:
        arr: array whose dtype can be bitcast to uint32 (e.g. float32,
            int32, uint32).

    Returns:
        A uint32 scalar.
    """
    # Flatten after the bitcast: the loop below runs over ``size`` elements,
    # so it must index a 1-D view rather than the leading axis only.
    words = jnp.ravel(jax.lax.bitcast_convert_type(arr, jnp.uint32))

    def update(i, hash_val):
        return hash_val ^ (
            words[i] + jnp.uint32(0x9E3779B9) + (hash_val << 6) + (hash_val >> 2)
        )

    return jax.lax.fori_loop(0, words.size, update, jnp.uint32(0))