fix bugs
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import State, StatefulBaseClass
|
||||
from utils import State, StatefulBaseClass, hash_array
|
||||
|
||||
|
||||
class BaseGene(StatefulBaseClass):
|
||||
@@ -43,3 +43,6 @@ class BaseGene(StatefulBaseClass):
|
||||
|
||||
def repr(self, state, gene, precision=2):
|
||||
raise NotImplementedError
|
||||
|
||||
def hash(self, gene):
|
||||
return hash_array(gene)
|
||||
|
||||
@@ -2,7 +2,7 @@ import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
from ..gene import BaseNodeGene, BaseConnGene
|
||||
from ..ga import BaseMutation, BaseCrossover
|
||||
from utils import State, StatefulBaseClass, topological_sort_python
|
||||
from utils import State, StatefulBaseClass, topological_sort_python, hash_array
|
||||
|
||||
|
||||
class BaseGenome(StatefulBaseClass):
|
||||
@@ -255,10 +255,14 @@ class BaseGenome(StatefulBaseClass):
|
||||
|
||||
nx.draw(
|
||||
G,
|
||||
with_labels=True,
|
||||
pos=rotated_pos,
|
||||
node_size=node_sizes,
|
||||
node_color=node_colors,
|
||||
**kwargs,
|
||||
)
|
||||
plt.savefig(save_path, dpi=save_dpi)
|
||||
|
||||
def hash(self, nodes, conns):
|
||||
nodes_hashs = jax.vmap(hash_array)(nodes)
|
||||
conns_hashs = jax.vmap(hash_array)(conns)
|
||||
return hash_array(jnp.concatenate([nodes_hashs, conns_hashs]))
|
||||
|
||||
@@ -210,7 +210,14 @@ class DefaultGenome(BaseGenome):
|
||||
new_transformed,
|
||||
)
|
||||
|
||||
def sympy_func(self, state, network, sympy_input_transform=None, sympy_output_transform=None, backend="jax"):
|
||||
def sympy_func(
|
||||
self,
|
||||
state,
|
||||
network,
|
||||
sympy_input_transform=None,
|
||||
sympy_output_transform=None,
|
||||
backend="jax",
|
||||
):
|
||||
|
||||
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
|
||||
module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
|
||||
@@ -219,6 +226,10 @@ class DefaultGenome(BaseGenome):
|
||||
warnings.warn(
|
||||
"genome.input_transform is not None but sympy_input_transform is None!"
|
||||
)
|
||||
|
||||
if sympy_input_transform is None:
|
||||
sympy_input_transform = lambda x: x
|
||||
|
||||
if sympy_input_transform is not None:
|
||||
if not isinstance(sympy_input_transform, list):
|
||||
sympy_input_transform = [sympy_input_transform] * self.num_inputs
|
||||
@@ -231,11 +242,14 @@ class DefaultGenome(BaseGenome):
|
||||
input_idx = self.get_input_idx()
|
||||
output_idx = self.get_output_idx()
|
||||
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
|
||||
hidden_idx = [i for i in network["nodes"] if i not in input_idx and i not in output_idx]
|
||||
hidden_idx = [
|
||||
i for i in network["nodes"] if i not in input_idx and i not in output_idx
|
||||
]
|
||||
symbols = {}
|
||||
for i in network["nodes"]:
|
||||
if i in input_idx:
|
||||
symbols[i] = sp.Symbol(f"i{i - min(input_idx)}")
|
||||
symbols[-i - 1] = sp.Symbol(f"i{i - min(input_idx)}") # origin_i
|
||||
symbols[i] = sp.Symbol(f"norm{i - min(input_idx)}")
|
||||
elif i in output_idx:
|
||||
symbols[i] = sp.Symbol(f"o{i - min(output_idx)}")
|
||||
else: # hidden
|
||||
@@ -246,10 +260,9 @@ class DefaultGenome(BaseGenome):
|
||||
for i in order:
|
||||
|
||||
if i in input_idx:
|
||||
if sympy_input_transform is not None:
|
||||
nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[i])
|
||||
else:
|
||||
nodes_exprs[symbols[i]] = symbols[i]
|
||||
nodes_exprs[symbols[-i - 1]] = symbols[-i - 1] # origin equal to its symbol
|
||||
nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[-i - 1]) # normed i
|
||||
|
||||
else:
|
||||
in_conns = [c for c in network["conns"] if c[1] == i]
|
||||
node_inputs = []
|
||||
@@ -270,12 +283,13 @@ class DefaultGenome(BaseGenome):
|
||||
is_output_node=(i in output_idx),
|
||||
)
|
||||
args_symbols.update(a_s)
|
||||
|
||||
if i in output_idx and sympy_output_transform is not None:
|
||||
nodes_exprs[symbols[i]] = sympy_output_transform(
|
||||
nodes_exprs[symbols[i]]
|
||||
)
|
||||
|
||||
input_symbols = [v for k, v in symbols.items() if k in input_idx]
|
||||
input_symbols = [symbols[-i - 1] for i in input_idx]
|
||||
reduced_exprs = nodes_exprs.copy()
|
||||
for i in order:
|
||||
reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs)
|
||||
@@ -299,7 +313,9 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
fixed_args_output_funcs.append(f)
|
||||
|
||||
forward_func = lambda inputs: jnp.array([f(inputs) for f in fixed_args_output_funcs])
|
||||
forward_func = lambda inputs: jnp.array(
|
||||
[f(inputs) for f in fixed_args_output_funcs]
|
||||
)
|
||||
|
||||
return (
|
||||
symbols,
|
||||
|
||||
@@ -2,8 +2,6 @@ import jax, jax.numpy as jnp
|
||||
from utils import State
|
||||
from .. import BaseAlgorithm
|
||||
from .species import *
|
||||
from .ga import *
|
||||
from .genome import *
|
||||
|
||||
|
||||
class NEAT(BaseAlgorithm):
|
||||
@@ -16,28 +14,13 @@ class NEAT(BaseAlgorithm):
|
||||
|
||||
def setup(self, state=State()):
|
||||
state = self.species.setup(state)
|
||||
state = state.register(
|
||||
generation=jnp.array(0.0),
|
||||
next_node_key=jnp.array(
|
||||
max(*self.genome.input_idx, *self.genome.output_idx) + 2,
|
||||
dtype=jnp.float32,
|
||||
),
|
||||
)
|
||||
return state
|
||||
|
||||
def ask(self, state: State):
|
||||
return self.species.ask(state)
|
||||
|
||||
def tell(self, state: State, fitness):
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
|
||||
state = state.update(generation=state.generation + 1, randkey=randkey)
|
||||
|
||||
state, winner, loser, elite_mask = self.species.update_species(state, fitness)
|
||||
state = self.create_next_generation(state, winner, loser, elite_mask)
|
||||
state = self.species.speciate(state)
|
||||
|
||||
return state
|
||||
return self.species.tell(state, fitness)
|
||||
|
||||
def transform(self, state, individual):
|
||||
"""transform the genome into a neural network"""
|
||||
@@ -65,50 +48,6 @@ class NEAT(BaseAlgorithm):
|
||||
def pop_size(self):
|
||||
return self.species.pop_size
|
||||
|
||||
def create_next_generation(self, state, winner, loser, elite_mask):
|
||||
# prepare random keys
|
||||
pop_size = self.species.pop_size
|
||||
new_node_keys = jnp.arange(pop_size) + state.next_node_key
|
||||
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
crossover_randkeys = jax.random.split(k1, pop_size)
|
||||
mutate_randkeys = jax.random.split(k2, pop_size)
|
||||
|
||||
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
|
||||
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
|
||||
|
||||
# batch crossover
|
||||
n_nodes, n_conns = jax.vmap(
|
||||
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
|
||||
)(
|
||||
state, crossover_randkeys, wpn, wpc, lpn, lpc
|
||||
) # new_nodes, new_conns
|
||||
|
||||
# batch mutation
|
||||
m_n_nodes, m_n_conns = jax.vmap(
|
||||
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
|
||||
)(
|
||||
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
|
||||
) # mutated_new_nodes, mutated_new_conns
|
||||
|
||||
# elitism don't mutate
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
|
||||
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
|
||||
|
||||
# update next node key
|
||||
all_nodes_keys = pop_nodes[:, :, 0]
|
||||
max_node_key = jnp.max(
|
||||
jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys)
|
||||
)
|
||||
next_node_key = max_node_key + 1
|
||||
|
||||
return state.update(
|
||||
randkey=randkey,
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
next_node_key=next_node_key,
|
||||
)
|
||||
|
||||
def member_count(self, state: State):
|
||||
return state.member_count
|
||||
|
||||
|
||||
@@ -10,6 +10,9 @@ class BaseSpecies(StatefulBaseClass):
|
||||
def ask(self, state: State):
|
||||
raise NotImplementedError
|
||||
|
||||
def tell(self, state: State, fitness):
|
||||
raise NotImplementedError
|
||||
|
||||
def update_species(self, state, fitness):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -113,12 +113,23 @@ class DefaultSpecies(BaseSpecies):
|
||||
idx2species=idx2species,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
next_species_key=jnp.array(1), # 0 is reserved for the first species
|
||||
next_species_key=jnp.float32(1), # 0 is reserved for the first species
|
||||
generation=jnp.float32(0),
|
||||
)
|
||||
|
||||
def ask(self, state):
|
||||
return state.pop_nodes, state.pop_conns
|
||||
|
||||
def tell(self, state, fitness):
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
|
||||
state = state.update(generation=state.generation + 1, randkey=randkey)
|
||||
state, winner, loser, elite_mask = self.update_species(state, fitness)
|
||||
state = self.create_next_generation(state, winner, loser, elite_mask)
|
||||
state = self.speciate(state)
|
||||
|
||||
return state
|
||||
|
||||
def update_species(self, state, fitness):
|
||||
# update the fitness of each species
|
||||
state, species_fitness = self.update_species_fitness(state, fitness)
|
||||
@@ -619,3 +630,43 @@ class DefaultSpecies(BaseSpecies):
|
||||
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
|
||||
|
||||
return val
|
||||
|
||||
def create_next_generation(self, state, winner, loser, elite_mask):
|
||||
|
||||
# find next node key
|
||||
all_nodes_keys = state.pop_nodes[:, :, 0]
|
||||
max_node_key = jnp.max(all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0)
|
||||
next_node_key = max_node_key + 1
|
||||
new_node_keys = jnp.arange(self.pop_size) + next_node_key
|
||||
|
||||
# prepare random keys
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
crossover_randkeys = jax.random.split(k1, self.pop_size)
|
||||
mutate_randkeys = jax.random.split(k2, self.pop_size)
|
||||
|
||||
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
|
||||
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
|
||||
|
||||
# batch crossover
|
||||
n_nodes, n_conns = jax.vmap(
|
||||
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
|
||||
)(
|
||||
state, crossover_randkeys, wpn, wpc, lpn, lpc
|
||||
) # new_nodes, new_conns
|
||||
|
||||
# batch mutation
|
||||
m_n_nodes, m_n_conns = jax.vmap(
|
||||
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
|
||||
)(
|
||||
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
|
||||
) # mutated_new_nodes, mutated_new_conns
|
||||
|
||||
# elitism don't mutate
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
|
||||
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
|
||||
|
||||
return state.update(
|
||||
randkey=randkey,
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
)
|
||||
@@ -8,7 +8,7 @@ if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
genome=DefaultGenome(
|
||||
genome=DenseInitialize(
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
max_nodes=50,
|
||||
@@ -21,7 +21,7 @@ if __name__ == "__main__":
|
||||
# aggregation_options=(Agg.sum,),
|
||||
aggregation_options=AGG_ALL,
|
||||
),
|
||||
output_transform=Act.sigmoid, # the activation function for output node
|
||||
output_transform=Act.standard_sigmoid, # the activation function for output node
|
||||
mutation=DefaultMutation(
|
||||
node_add=0.1,
|
||||
conn_add=0.1,
|
||||
@@ -29,7 +29,7 @@ if __name__ == "__main__":
|
||||
conn_delete=0,
|
||||
),
|
||||
),
|
||||
pop_size=100000,
|
||||
pop_size=10000,
|
||||
species_size=20,
|
||||
compatibility_threshold=2,
|
||||
survival_threshold=0.01, # magic
|
||||
@@ -37,7 +37,7 @@ if __name__ == "__main__":
|
||||
),
|
||||
problem=XOR3d(),
|
||||
generation_limit=10000,
|
||||
fitness_target=-1e-8,
|
||||
fitness_target=-1e-3,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
|
||||
@@ -6,7 +6,7 @@ from algorithm.neat import *
|
||||
from problem.rl_env import GymNaxEnv
|
||||
|
||||
|
||||
def action_policy(forward_func, obs):
|
||||
def action_policy(randkey, forward_func, obs):
|
||||
return jnp.argmax(forward_func(obs))
|
||||
|
||||
|
||||
@@ -27,7 +27,9 @@ if __name__ == "__main__":
|
||||
species_size=10,
|
||||
),
|
||||
),
|
||||
problem=GymNaxEnv(env_name="CartPole-v1", repeat_times=5, action_policy=action_policy),
|
||||
problem=GymNaxEnv(
|
||||
env_name="CartPole-v1", repeat_times=5, action_policy=action_policy
|
||||
),
|
||||
generation_limit=10000,
|
||||
fitness_target=500,
|
||||
)
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from algorithm.neat import *
|
||||
from algorithm.neat.genome.hidden import AdvanceInitialize
|
||||
from algorithm.neat.genome.dense import DenseInitialize
|
||||
from utils.graph import topological_sort_python
|
||||
from utils import *
|
||||
|
||||
if __name__ == '__main__':
|
||||
genome = AdvanceInitialize(
|
||||
num_inputs=17,
|
||||
num_outputs=6,
|
||||
hidden_cnt=8,
|
||||
if __name__ == "__main__":
|
||||
genome = DenseInitialize(
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
max_nodes=50,
|
||||
max_conns=500,
|
||||
)
|
||||
@@ -19,16 +19,19 @@ if __name__ == '__main__':
|
||||
nodes, conns = genome.initialize(state, randkey)
|
||||
|
||||
network = genome.network_dict(state, nodes, conns)
|
||||
print(set(network["nodes"]), set(network["conns"]))
|
||||
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
|
||||
print(order)
|
||||
|
||||
input_idx, output_idx = genome.get_input_idx(), genome.get_output_idx()
|
||||
print(input_idx, output_idx)
|
||||
|
||||
print(genome.repr(state, nodes, conns))
|
||||
print(network)
|
||||
res = genome.sympy_func(state, network, sympy_input_transform=lambda x: 999999999*x, sympy_output_transform=SympyStandardSigmoid)
|
||||
(symbols,
|
||||
args_symbols,
|
||||
input_symbols,
|
||||
nodes_exprs,
|
||||
output_exprs,
|
||||
forward_func,) = res
|
||||
|
||||
res = genome.sympy_func(state, network, precision=3)
|
||||
print(res)
|
||||
print(symbols)
|
||||
print(output_exprs[0].subs(args_symbols))
|
||||
|
||||
inputs = jnp.zeros(3)
|
||||
print(forward_func(inputs))
|
||||
|
||||
@@ -71,6 +71,9 @@ class Pipeline(StatefulBaseClass):
|
||||
print(f"save to {self.save_dir}")
|
||||
if not os.path.exists(self.save_dir):
|
||||
os.makedirs(self.save_dir)
|
||||
self.genome_dir = os.path.join(self.save_dir, "genomes")
|
||||
if not os.path.exists(self.genome_dir):
|
||||
os.makedirs(self.genome_dir)
|
||||
|
||||
def setup(self, state=State()):
|
||||
print("initializing")
|
||||
@@ -165,6 +168,7 @@ class Pipeline(StatefulBaseClass):
|
||||
print("start compile")
|
||||
tic = time.time()
|
||||
compiled_step = jax.jit(self.step).lower(state).compile()
|
||||
# compiled_step = self.step
|
||||
print(
|
||||
f"compile finished, cost time: {time.time() - tic:.6f}s",
|
||||
)
|
||||
@@ -181,9 +185,21 @@ class Pipeline(StatefulBaseClass):
|
||||
|
||||
if max(fitnesses) >= self.fitness_target:
|
||||
print("Fitness limit reached!")
|
||||
return state, self.best_genome
|
||||
break
|
||||
|
||||
if self.algorithm.generation(state) >= self.generation_limit:
|
||||
print("Generation limit reached!")
|
||||
|
||||
if self.is_save:
|
||||
best_genome = jax.device_get(self.best_genome)
|
||||
with open(os.path.join(self.genome_dir, f"best_genome.npz"), "wb") as f:
|
||||
np.savez(
|
||||
f,
|
||||
nodes=best_genome[0],
|
||||
conns=best_genome[1],
|
||||
fitness=self.best_fitness,
|
||||
)
|
||||
|
||||
print("Generation limit reached!")
|
||||
return state, self.best_genome
|
||||
|
||||
def analysis(self, state, pop, fitnesses):
|
||||
@@ -206,15 +222,15 @@ class Pipeline(StatefulBaseClass):
|
||||
self.best_fitness = fitnesses[max_idx]
|
||||
self.best_genome = pop[0][max_idx], pop[1][max_idx]
|
||||
|
||||
if self.is_save:
|
||||
best_genome = jax.device_get(self.best_genome)
|
||||
with open(os.path.join(self.save_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f:
|
||||
np.savez(
|
||||
f,
|
||||
nodes=best_genome[0],
|
||||
conns=best_genome[1],
|
||||
fitness=self.best_fitness,
|
||||
)
|
||||
if self.is_save:
|
||||
best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx]))
|
||||
with open(os.path.join(self.genome_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f:
|
||||
np.savez(
|
||||
f,
|
||||
nodes=best_genome[0],
|
||||
conns=best_genome[1],
|
||||
fitness=self.best_fitness,
|
||||
)
|
||||
|
||||
# save best if save path is not None
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
|
||||
from .tools import *
|
||||
from .graph import *
|
||||
@@ -15,7 +13,9 @@ from typing import Union
|
||||
|
||||
name2sympy = {
|
||||
"sigmoid": SympySigmoid,
|
||||
"standard_sigmoid": SympyStandardSigmoid,
|
||||
"tanh": SympyTanh,
|
||||
"standard_tanh": SympyStandardTanh,
|
||||
"sin": SympySin,
|
||||
"relu": SympyRelu,
|
||||
"lelu": SympyLelu,
|
||||
|
||||
@@ -12,19 +12,26 @@ class Act:
|
||||
|
||||
@staticmethod
|
||||
def sigmoid(z):
|
||||
z = jnp.clip(5 * z / sigma_3, -5, 5)
|
||||
z = 5 * z / sigma_3
|
||||
z = 1 / (1 + jnp.exp(-z))
|
||||
|
||||
return z * sigma_3 # (0, sigma_3)
|
||||
|
||||
@staticmethod
|
||||
def standard_sigmoid(z):
|
||||
z = 5 * z / sigma_3
|
||||
z = 1 / (1 + jnp.exp(-z))
|
||||
|
||||
return z # (0, 1)
|
||||
|
||||
@staticmethod
|
||||
def tanh(z):
|
||||
z = jnp.clip(5 * z / sigma_3, -5, 5)
|
||||
z = 5 * z / sigma_3
|
||||
return jnp.tanh(z) * sigma_3 # (-sigma_3, sigma_3)
|
||||
|
||||
@staticmethod
|
||||
def standard_tanh(z):
|
||||
z = jnp.clip(5 * z / sigma_3, -5, 5)
|
||||
z =5 * z / sigma_3
|
||||
return jnp.tanh(z) # (-1, 1)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -25,21 +25,16 @@ class SympyClip(sp.Function):
|
||||
return rf"\mathrm{{clip}}\left({sp.latex(self.args[0])}, {self.args[1]}, {self.args[2]}\right)"
|
||||
|
||||
|
||||
class SympySigmoid(sp.Function):
|
||||
class SympySigmoid_(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
if z.is_Number:
|
||||
z = SympyClip(5 * z / sigma_3, -5, 5)
|
||||
z = 1 / (1 + sp.exp(-z))
|
||||
return z * sigma_3
|
||||
return None
|
||||
z = 1 / (1 + sp.exp(-z))
|
||||
return z
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z, backend=np):
|
||||
z = backend.clip(5 * z / sigma_3, -5, 5)
|
||||
z = 1 / (1 + backend.exp(-z))
|
||||
|
||||
return z * sigma_3 # (0, sigma_3)
|
||||
return z
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"sigmoid({self.args[0]})"
|
||||
@@ -48,32 +43,47 @@ class SympySigmoid(sp.Function):
|
||||
return rf"\mathrm{{sigmoid}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympySigmoid(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return SympySigmoid_(5 * z / sigma_3) * sigma_3
|
||||
|
||||
|
||||
class SympyStandardSigmoid(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return SympySigmoid_(5 * z / sigma_3)
|
||||
|
||||
# @staticmethod
|
||||
# def numerical_eval(z, backend=np):
|
||||
# z = backend.clip(5 * z / sigma_3, -5, 5)
|
||||
# z = 1 / (1 + backend.exp(-z))
|
||||
#
|
||||
# return z # (0, 1)
|
||||
|
||||
|
||||
class SympyTanh(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
if z.is_Number:
|
||||
z = SympyClip(5 * z / sigma_3, -5, 5)
|
||||
return sp.tanh(z) * sigma_3
|
||||
return None
|
||||
z = 5 * z / sigma_3
|
||||
return sp.tanh(z) * sigma_3
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z, backend=np):
|
||||
z = backend.clip(5 * z / sigma_3, -5, 5)
|
||||
return backend.tanh(z) * sigma_3 # (-sigma_3, sigma_3)
|
||||
# @staticmethod
|
||||
# def numerical_eval(z, backend=np):
|
||||
# z = backend.clip(5 * z / sigma_3, -5, 5)
|
||||
# return backend.tanh(z) * sigma_3 # (-sigma_3, sigma_3)
|
||||
|
||||
|
||||
class SympyStandardTanh(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
if z.is_Number:
|
||||
z = SympyClip(5 * z / sigma_3, -5, 5)
|
||||
return sp.tanh(z)
|
||||
return None
|
||||
z = 5 * z / sigma_3
|
||||
return sp.tanh(z)
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z, backend=np):
|
||||
z = backend.clip(5 * z / sigma_3, -5, 5)
|
||||
return backend.tanh(z) # (-1, 1)
|
||||
# @staticmethod
|
||||
# def numerical_eval(z, backend=np):
|
||||
# z = backend.clip(5 * z / sigma_3, -5, 5)
|
||||
# return backend.tanh(z) # (-1, 1)
|
||||
|
||||
|
||||
class SympySin(sp.Function):
|
||||
|
||||
@@ -9,19 +9,19 @@ class Agg:
|
||||
|
||||
@staticmethod
|
||||
def sum(z):
|
||||
return jnp.sum(z, axis=0, where=~jnp.isnan(z))
|
||||
return jnp.sum(z, axis=0, where=~jnp.isnan(z), initial=0)
|
||||
|
||||
@staticmethod
|
||||
def product(z):
|
||||
return jnp.prod(z, axis=0, where=~jnp.isnan(z))
|
||||
return jnp.prod(z, axis=0, where=~jnp.isnan(z), initial=1)
|
||||
|
||||
@staticmethod
|
||||
def max(z):
|
||||
return jnp.max(z, axis=0, where=~jnp.isnan(z))
|
||||
return jnp.max(z, axis=0, where=~jnp.isnan(z), initial=-jnp.inf)
|
||||
|
||||
@staticmethod
|
||||
def min(z):
|
||||
return jnp.min(z, axis=0, where=~jnp.isnan(z))
|
||||
return jnp.min(z, axis=0, where=~jnp.isnan(z), initial=jnp.inf)
|
||||
|
||||
@staticmethod
|
||||
def maxabs(z):
|
||||
|
||||
@@ -36,6 +36,9 @@ class State:
|
||||
def __setstate__(self, state):
|
||||
self.__dict__["state_dict"] = state
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self.state_dict
|
||||
|
||||
def tree_flatten(self):
|
||||
children = list(self.state_dict.values())
|
||||
aux_data = list(self.state_dict.keys())
|
||||
|
||||
@@ -19,6 +19,21 @@ class StatefulBaseClass:
|
||||
with open(path, "wb") as f:
|
||||
pickle.dump(self, f)
|
||||
|
||||
def __getstate__(self):
|
||||
# only pickle the picklable attributes
|
||||
state = self.__dict__.copy()
|
||||
non_picklable_keys = []
|
||||
for key, value in state.items():
|
||||
try:
|
||||
pickle.dumps(value)
|
||||
except Exception:
|
||||
non_picklable_keys.append(key)
|
||||
|
||||
for key in non_picklable_keys:
|
||||
state.pop(key)
|
||||
|
||||
return state
|
||||
|
||||
def show_config(self):
|
||||
config = {}
|
||||
for key, value in self.__dict__.items():
|
||||
|
||||
@@ -36,6 +36,7 @@ def unflatten_conns(nodes, conns):
|
||||
return unflatten
|
||||
|
||||
|
||||
# TODO: strange implementation
|
||||
def attach_with_inf(arr, idx):
|
||||
expand_size = arr.ndim - idx.ndim
|
||||
expand_idx = jnp.expand_dims(
|
||||
@@ -199,3 +200,14 @@ def delete_conn_by_pos(conns, pos):
|
||||
Delete the connection by its idx.
|
||||
"""
|
||||
return conns.at[pos].set(jnp.nan)
|
||||
|
||||
|
||||
def hash_array(arr: Array):
|
||||
arr = jax.lax.bitcast_convert_type(arr, jnp.uint32)
|
||||
|
||||
def update(i, hash_val):
|
||||
return hash_val ^ (
|
||||
arr[i] + jnp.uint32(0x9E3779B9) + (hash_val << 6) + (hash_val >> 2)
|
||||
)
|
||||
|
||||
return jax.lax.fori_loop(0, arr.size, update, jnp.uint32(0))
|
||||
|
||||
Reference in New Issue
Block a user