change repo structure; modify readme
This commit is contained in:
5
tensorneat/utils/__init__.py
Normal file
5
tensorneat/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .activation import Act, act
|
||||
from .aggregation import Agg, agg
|
||||
from .tools import *
|
||||
from .graph import *
|
||||
from .state import State
|
||||
83
tensorneat/utils/activation.py
Normal file
83
tensorneat/utils/activation.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
class Act:
    """Element-wise activation functions implemented with jax.numpy.

    All members are static and operate on scalars or arrays. Several of them
    clip or clamp their argument first so that jit-compiled evaluation stays
    numerically stable (no overflow, no division by zero, no log of zero).
    """

    @staticmethod
    def sigmoid(z):
        # steepened logistic: scale input by 5 and clip to [-10, 10] first
        scaled = jnp.clip(5 * z, -10, 10)
        return 1 / (1 + jnp.exp(-scaled))

    @staticmethod
    def tanh(z):
        return jnp.tanh(z)

    @staticmethod
    def sin(z):
        return jnp.sin(z)

    @staticmethod
    def relu(z):
        return jnp.maximum(0, z)

    @staticmethod
    def lelu(z):
        # leaky ReLU with a fixed small negative-side slope
        slope = 0.005
        return jnp.where(z > 0, z, slope * z)

    @staticmethod
    def identity(z):
        return z

    @staticmethod
    def clamped(z):
        return jnp.clip(z, -1, 1)

    @staticmethod
    def inv(z):
        # push z away from zero (keeping its sign) before taking 1/z
        safe = jnp.where(
            z > 0,
            jnp.maximum(z, 1e-7),
            jnp.minimum(z, -1e-7),
        )
        return 1 / safe

    @staticmethod
    def log(z):
        # floor at 1e-7 so the log is always defined
        return jnp.log(jnp.maximum(z, 1e-7))

    @staticmethod
    def exp(z):
        # clip to [-10, 10] to avoid overflow
        return jnp.exp(jnp.clip(z, -10, 10))

    @staticmethod
    def abs(z):
        return jnp.abs(z)
|
||||
|
||||
|
||||
# All built-in activation functions, in the fixed order used to index them
# (e.g. by an integer gene dispatched through `act` below).
ACT_ALL = (
    Act.sigmoid,
    Act.tanh,
    Act.sin,
    Act.relu,
    Act.lelu,
    Act.identity,
    Act.clamped,
    Act.inv,
    Act.log,
    Act.exp,
    Act.abs,
)
|
||||
|
||||
|
||||
def act(idx, z, act_funcs):
    """
    Apply the activation function selected by ``idx`` to ``z``.

    :param idx: index into ``act_funcs``; may arrive as a float, so it is
        cast to int32 before dispatch
    :param z: input value(s) of the node
    :param act_funcs: sequence of candidate activation callables
    :return: ``act_funcs[idx](z)``
    """
    # lax.switch needs an integer branch selector
    branch = jnp.asarray(idx, dtype=jnp.int32)
    return jax.lax.switch(branch, act_funcs, z)
|
||||
67
tensorneat/utils/aggregation.py
Normal file
67
tensorneat/utils/aggregation.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
class Agg:
    """Aggregation functions for combining a node's inputs.

    Inputs may contain NaN entries (padding for missing/disabled connections);
    each aggregation first maps NaN to that operation's neutral element so the
    padding cannot influence the result.
    """

    @staticmethod
    def sum(z):
        z = jnp.where(jnp.isnan(z), 0, z)  # NaN -> additive identity
        return jnp.sum(z, axis=0)

    @staticmethod
    def product(z):
        z = jnp.where(jnp.isnan(z), 1, z)  # NaN -> multiplicative identity
        return jnp.prod(z, axis=0)

    @staticmethod
    def max(z):
        z = jnp.where(jnp.isnan(z), -jnp.inf, z)  # NaN can never win the max
        return jnp.max(z, axis=0)

    @staticmethod
    def min(z):
        z = jnp.where(jnp.isnan(z), jnp.inf, z)  # NaN can never win the min
        return jnp.min(z, axis=0)

    @staticmethod
    def maxabs(z):
        # return the (signed) element with the largest absolute value
        z = jnp.where(jnp.isnan(z), 0, z)
        abs_z = jnp.abs(z)
        # NOTE(review): argmax here is over the flattened array (no axis kwarg),
        # unlike the axis=0 reductions above — looks like z is expected to be
        # 1-D; confirm against callers.
        max_abs_index = jnp.argmax(abs_z)
        return z[max_abs_index]

    @staticmethod
    def median(z):
        # count of valid (non-NaN) entries
        n = jnp.sum(~jnp.isnan(z), axis=0)

        z = jnp.sort(z)  # sort; NaNs are placed at the end

        # median of the first n sorted values (idx1 == idx2 when n is odd)
        idx1, idx2 = (n - 1) // 2, n // 2
        median = (z[idx1] + z[idx2]) / 2

        return median

    @staticmethod
    def mean(z):
        # mean over valid entries only: sum with NaN -> 0, divide by valid count
        aux = jnp.where(jnp.isnan(z), 0, z)
        valid_values_sum = jnp.sum(aux, axis=0)
        valid_values_count = jnp.sum(~jnp.isnan(z), axis=0)
        mean_without_zeros = valid_values_sum / valid_values_count
        return mean_without_zeros
|
||||
|
||||
|
||||
# All built-in aggregations, in the fixed order used to index them via `agg`.
AGG_ALL = (Agg.sum, Agg.product, Agg.max, Agg.min, Agg.maxabs, Agg.median, Agg.mean)
|
||||
|
||||
|
||||
def agg(idx, z, agg_funcs):
    """
    Aggregate the inputs of a node using the aggregation selected by ``idx``.

    :param idx: index into ``agg_funcs``; cast to int32 before dispatch
    :param z: input values of the node, NaN marking missing inputs
    :param agg_funcs: sequence of candidate aggregation callables
    :return: NaN when every input is NaN, otherwise ``agg_funcs[idx](z)``
    """
    branch = jnp.asarray(idx, dtype=jnp.int32)

    def no_valid_inputs():
        # every input is NaN -> the node has no signal at all
        return jnp.nan

    def aggregate():
        return jax.lax.switch(branch, agg_funcs, z)

    return jax.lax.cond(jnp.all(jnp.isnan(z)), no_valid_inputs, aggregate)
|
||||
68
tensorneat/utils/graph.py
Normal file
68
tensorneat/utils/graph.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Some graph algorithm implemented in jax.
|
||||
Only used in feed-forward networks.
|
||||
"""
|
||||
|
||||
import jax
|
||||
from jax import jit, Array, numpy as jnp
|
||||
|
||||
from .tools import fetch_first, I_INT
|
||||
|
||||
|
||||
@jit
def topological_sort(nodes: Array, conns: Array) -> Array:
    """
    a jit-able version of topological_sort!

    nodes: Array[N, ...]; a NaN in column 0 marks an empty (unused) node slot
    conns: Array[N, N]; conns[i, j] truthy means an edge i -> j

    Returns Array[N] of node indices in topological order; unused trailing
    slots keep the sentinel I_INT.
    """

    # in-degree per node (column sums); NaN for empty slots so they never
    # compare equal to 0 and are never selected
    in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(conns, axis=0))
    res = jnp.full(in_degree.shape, I_INT)

    def cond_fun(carry):
        res_, idx_, in_degree_ = carry
        # keep looping while some unprocessed node has in-degree 0
        i = fetch_first(in_degree_ == 0.)
        return i != I_INT

    def body_func(carry):
        res_, idx_, in_degree_ = carry
        i = fetch_first(in_degree_ == 0.)

        # add to res and flag it is already in it (in-degree -1 means "done")
        res_ = res_.at[idx_].set(i)
        in_degree_ = in_degree_.at[i].set(-1)

        # decrease in_degree of all its children
        children = conns[i, :]
        in_degree_ = jnp.where(children, in_degree_ - 1, in_degree_)
        return res_, idx_ + 1, in_degree_

    res, _, _ = jax.lax.while_loop(cond_fun, body_func, (res, 0, in_degree))
    return res
|
||||
|
||||
|
||||
@jit
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
    """
    Check whether a new connection (from_idx -> to_idx) will cause a cycle.

    Tentatively adds the edge, then propagates reachability from ``to_idx``;
    a cycle exists iff ``from_idx`` becomes reachable. Returns a (traced)
    boolean Array.
    """

    # tentatively insert the candidate edge
    conns = conns.at[from_idx, to_idx].set(True)

    visited = jnp.full(nodes.shape[0], False)
    new_visited = visited.at[to_idx].set(True)

    def cond_func(carry):
        visited_, new_visited_ = carry
        end_cond1 = jnp.all(visited_ == new_visited_)  # no new nodes been visited
        end_cond2 = new_visited_[from_idx]  # the starting node has been visited
        return jnp.logical_not(end_cond1 | end_cond2)

    def body_func(carry):
        _, visited_ = carry
        # one propagation step: vector-matrix product marks every node
        # reachable in one hop from the current frontier
        new_visited_ = jnp.dot(visited_, conns)
        new_visited_ = jnp.logical_or(visited_, new_visited_)
        return visited_, new_visited_

    _, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
    return visited[from_idx]
|
||||
29
tensorneat/utils/state.py
Normal file
29
tensorneat/utils/state.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from jax.tree_util import register_pytree_node_class
|
||||
|
||||
|
||||
@register_pytree_node_class
class State:
    """An immutable, jit-friendly container of named values.

    Registered as a JAX pytree: the stored values are the leaves and the keys
    are auxiliary data, so instances can flow through ``jit``/``vmap``.
    ``update`` never mutates in place — it returns a new ``State``.
    """

    def __init__(self, **kwargs):
        # write through __dict__ directly because __setattr__ is disabled
        self.__dict__['state_dict'] = kwargs

    def update(self, **kwargs):
        """Return a new State with ``kwargs`` merged over the current values."""
        return State(**{**self.state_dict, **kwargs})

    def __getattr__(self, name):
        # Raise AttributeError (not the dict's KeyError) for unknown names,
        # so hasattr/getattr-with-default and the copy/pickle protocols work.
        try:
            return self.state_dict[name]
        except KeyError:
            raise AttributeError(f"State has no attribute {name!r}") from None

    def __setattr__(self, name, value):
        raise AttributeError("State is immutable")

    def __repr__(self):
        return f"State ({self.state_dict})"

    def tree_flatten(self):
        # leaves are the stored values; keys travel as static aux data
        children = list(self.state_dict.values())
        aux_data = list(self.state_dict.keys())
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # inverse of tree_flatten: zip keys back onto leaves
        return cls(**dict(zip(aux_data, children)))
|
||||
106
tensorneat/utils/tools.py
Normal file
106
tensorneat/utils/tools.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import numpy as jnp, Array, jit, vmap
|
||||
|
||||
I_INT = np.iinfo(jnp.int32).max  # max int32, used as an "infinite"/sentinel index
|
||||
|
||||
|
||||
def unflatten_conns(nodes, conns):
    """
    Transform the flat (C, CL) connection array into a dense (CL-2, N, N) array.
    The leading 2 columns of CL are the input-node and output-node keys, which
    become the two N-sized indices; every remaining attribute (enable flag
    included) is stored per (in, out) pair. NaN marks absent connections.
    :return: Array of shape (CL-2, N, N)
    """
    N = nodes.shape[0]  # node-slot count
    CL = conns.shape[1]  # columns per connection: in-key, out-key, attributes...
    node_keys = nodes[:, 0]
    i_keys, o_keys = conns[:, 0], conns[:, 1]
    # translate node keys to row/col indices of the dense matrix
    i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
    o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
    res = jnp.full((CL - 2, N, N), jnp.nan)

    # Is interesting that jax use clip when attach data in array
    # however, it will do nothing set values in an array
    # put all attributes include enable in res
    res = res.at[:, i_idxs, o_idxs].set(conns[:, 2:].T)

    return res
|
||||
|
||||
|
||||
def key_to_indices(key, keys):
    # index of `key` within `keys`; I_INT when absent (see fetch_first)
    return fetch_first(key == keys)
|
||||
|
||||
|
||||
@jit
def fetch_first(mask, default=I_INT) -> Array:
    """
    Return the index of the first True element in ``mask``.

    :param mask: array of bool
    :param default: value returned when ``mask`` has no True element
    :return: index of the first True element, or ``default`` if there is none
    """
    # argmax returns the first maximal element, i.e. the first True; for an
    # all-False mask it returns 0, hence the validity check below
    first = jnp.argmax(mask)
    return jnp.where(mask[first], first, default)
|
||||
|
||||
|
||||
@jit
def fetch_random(rand_key, mask, default=I_INT) -> Array:
    """
    Similar to fetch_first, but fetch a uniformly random True index
    (``default`` when there is none).
    """
    total = jnp.sum(mask)
    # pick the k-th True element, with k uniform in [1, total]
    k = jax.random.randint(rand_key, shape=(), minval=1, maxval=total + 1)
    reached = jnp.cumsum(mask) >= k
    # guard: with no True element at all, keep the selection mask empty
    reached = jnp.where(total == 0, False, reached)
    return fetch_first(reached, default)
|
||||
|
||||
|
||||
@partial(jit, static_argnames=['reverse'])
def rank_elements(array, reverse=False):
    """
    Return the rank of every element of ``array``.

    By default the largest element gets rank 0 (large to small); with
    ``reverse=True`` the smallest element gets rank 0 (small to large).
    """
    keys = array if reverse else -array
    # double argsort: the inner call orders the elements, the outer call
    # recovers each element's position in that ordering, i.e. its rank
    return jnp.argsort(jnp.argsort(keys))
|
||||
|
||||
|
||||
@jit
def mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate):
    """
    Mutate a float gene.

    With probability ``mutate_rate`` add Gaussian noise scaled by
    ``mutate_power``; otherwise, with probability ``replace_rate``, replace
    the value by a fresh draw from N(init_mean, init_std); otherwise keep it.
    """
    noise_key, replace_key, roll_key = jax.random.split(key, num=3)
    perturbed = val + jax.random.normal(noise_key, ()) * mutate_power
    fresh = jax.random.normal(replace_key, ()) * init_std + init_mean
    roll = jax.random.uniform(roll_key, ())

    # select between "replace" and "keep" first, then overlay "perturb"
    replaced = jnp.where(
        (mutate_rate < roll) & (roll < mutate_rate + replace_rate),
        fresh,
        val,
    )
    return jnp.where(roll < mutate_rate, perturbed, replaced)
|
||||
|
||||
|
||||
@jit
def mutate_int(key, val, options, replace_rate):
    """
    Mutate an integer gene: with probability ``replace_rate`` replace it by a
    uniformly random element of ``options``; otherwise keep it unchanged.
    """
    roll_key, choice_key = jax.random.split(key, num=2)
    roll = jax.random.uniform(roll_key, ())

    return jnp.where(
        roll < replace_rate,
        jax.random.choice(choice_key, options),
        val,
    )
|
||||
|
||||
def argmin_with_mask(arr, mask):
    """Index of the smallest element of ``arr`` among positions where ``mask`` is True."""
    # masked-out entries become +inf so they can never be the minimum
    return jnp.argmin(jnp.where(mask, arr, jnp.inf))
|
||||
Reference in New Issue
Block a user