new architecture

This commit is contained in:
wls2002
2024-01-27 00:52:39 +08:00
parent 4efe9a53c1
commit aac41a089d
65 changed files with 1651 additions and 1783 deletions

View File

@@ -1,4 +1,5 @@
from .activation import Act, act
from .aggregation import Agg, agg
from .tools import *
from .graph import *
from .graph import *
from .state import State

View File

@@ -57,10 +57,8 @@ def agg(idx, z, agg_funcs):
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
def all_nan():
return 0.
def not_all_nan():
return jax.lax.switch(idx, agg_funcs, z)
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
return jax.lax.cond(
jnp.all(jnp.isnan(z)),
lambda: jnp.nan, # all inputs are nan
lambda: jax.lax.switch(idx, agg_funcs, z) # otherwise
)

29
utils/state.py Normal file
View File

@@ -0,0 +1,29 @@
from jax.tree_util import register_pytree_node_class
@register_pytree_node_class
class State:
def __init__(self, **kwargs):
self.__dict__['state_dict'] = kwargs
def update(self, **kwargs):
return State(**{**self.state_dict, **kwargs})
def __getattr__(self, name):
return self.state_dict[name]
def __setattr__(self, name, value):
raise AttributeError("State is immutable")
def __repr__(self):
return f"State ({self.state_dict})"
def tree_flatten(self):
children = list(self.state_dict.values())
aux_data = list(self.state_dict.keys())
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(**dict(zip(aux_data, children)))

View File

@@ -5,13 +5,11 @@ import jax
from jax import numpy as jnp, Array, jit, vmap
I_INT = np.iinfo(jnp.int32).max # infinite int
EMPTY_NODE = np.full((1, 5), jnp.nan)
EMPTY_CON = np.full((1, 4), jnp.nan)
def unflatten_conns(nodes, conns):
"""
transform the (C, CL) connections to (CL-2, N, N)
transform the (C, CL) connections to (CL-2, N, N), 2 is for the input index and output index)
:return:
"""
N = nodes.shape[0]
@@ -66,4 +64,43 @@ def rank_elements(array, reverse=False):
"""
if not reverse:
array = -array
return jnp.argsort(jnp.argsort(array))
return jnp.argsort(jnp.argsort(array))
@jit
def mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate):
k1, k2, k3 = jax.random.split(key, num=3)
noise = jax.random.normal(k1, ()) * mutate_power
replace = jax.random.normal(k2, ()) * init_std + init_mean
r = jax.random.uniform(k3, ())
val = jnp.where(
r < mutate_rate,
val + noise,
jnp.where(
(mutate_rate < r) & (r < mutate_rate + replace_rate),
replace,
val
)
)
return val
@jit
def mutate_int(key, val, options, replace_rate):
k1, k2 = jax.random.split(key, num=2)
r = jax.random.uniform(k1, ())
val = jnp.where(
r < replace_rate,
jax.random.choice(k2, options),
val
)
return val
def argmin_with_mask(arr, mask):
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)
return min_idx