change a lot

This commit is contained in:
wls2002
2023-07-17 17:39:12 +08:00
parent a0a1ef6c58
commit f4763ebcea
21 changed files with 1060 additions and 4 deletions

View File

@@ -0,0 +1,2 @@
from .base import BaseGene
from .normal import NormalGene

View File

@@ -0,0 +1,108 @@
import jax.numpy as jnp
class Activation:
@staticmethod
def sigmoid_act(z):
z = jnp.clip(z * 5, -60, 60)
return 1 / (1 + jnp.exp(-z))
@staticmethod
def tanh_act(z):
z = jnp.clip(z * 2.5, -60, 60)
return jnp.tanh(z)
@staticmethod
def sin_act(z):
z = jnp.clip(z * 5, -60, 60)
return jnp.sin(z)
@staticmethod
def gauss_act(z):
z = jnp.clip(z * 5, -3.4, 3.4)
return jnp.exp(-z ** 2)
@staticmethod
def relu_act(z):
return jnp.maximum(z, 0)
@staticmethod
def elu_act(z):
return jnp.where(z > 0, z, jnp.exp(z) - 1)
@staticmethod
def lelu_act(z):
leaky = 0.005
return jnp.where(z > 0, z, leaky * z)
@staticmethod
def selu_act(z):
lam = 1.0507009873554804934193349852946
alpha = 1.6732632423543772848170429916717
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
@staticmethod
def softplus_act(z):
z = jnp.clip(z * 5, -60, 60)
return 0.2 * jnp.log(1 + jnp.exp(z))
@staticmethod
def identity_act(z):
return z
@staticmethod
def clamped_act(z):
return jnp.clip(z, -1, 1)
@staticmethod
def inv_act(z):
z = jnp.maximum(z, 1e-7)
return 1 / z
@staticmethod
def log_act(z):
z = jnp.maximum(z, 1e-7)
return jnp.log(z)
@staticmethod
def exp_act(z):
z = jnp.clip(z, -60, 60)
return jnp.exp(z)
@staticmethod
def abs_act(z):
return jnp.abs(z)
@staticmethod
def hat_act(z):
return jnp.maximum(0, 1 - jnp.abs(z))
@staticmethod
def square_act(z):
return z ** 2
@staticmethod
def cube_act(z):
return z ** 3
name2func = {
'sigmoid': sigmoid_act,
'tanh': tanh_act,
'sin': sin_act,
'gauss': gauss_act,
'relu': relu_act,
'elu': elu_act,
'lelu': lelu_act,
'selu': selu_act,
'softplus': softplus_act,
'identity': identity_act,
'clamped': clamped_act,
'inv': inv_act,
'log': log_act,
'exp': exp_act,
'abs': abs_act,
'hat': hat_act,
'square': square_act,
'cube': cube_act,
}

View File

@@ -0,0 +1,60 @@
import jax.numpy as jnp
class Aggregation:
@staticmethod
def sum_agg(z):
z = jnp.where(jnp.isnan(z), 0, z)
return jnp.sum(z, axis=0)
@staticmethod
def product_agg(z):
z = jnp.where(jnp.isnan(z), 1, z)
return jnp.prod(z, axis=0)
@staticmethod
def max_agg(z):
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
return jnp.max(z, axis=0)
@staticmethod
def min_agg(z):
z = jnp.where(jnp.isnan(z), jnp.inf, z)
return jnp.min(z, axis=0)
@staticmethod
def maxabs_agg(z):
z = jnp.where(jnp.isnan(z), 0, z)
abs_z = jnp.abs(z)
max_abs_index = jnp.argmax(abs_z)
return z[max_abs_index]
@staticmethod
def median_agg(z):
n = jnp.sum(~jnp.isnan(z), axis=0)
z = jnp.sort(z) # sort
idx1, idx2 = (n - 1) // 2, n // 2
median = (z[idx1] + z[idx2]) / 2
return median
@staticmethod
def mean_agg(z):
aux = jnp.where(jnp.isnan(z), 0, z)
valid_values_sum = jnp.sum(aux, axis=0)
valid_values_count = jnp.sum(~jnp.isnan(z), axis=0)
mean_without_zeros = valid_values_sum / valid_values_count
return mean_without_zeros
name2func = {
'sum': sum_agg,
'product': product_agg,
'max': max_agg,
'min': min_agg,
'maxabs': maxabs_agg,
'median': median_agg,
'mean': mean_agg,
}

View File

@@ -0,0 +1,38 @@
from jax import Array, numpy as jnp
class BaseGene:
node_attrs = []
conn_attrs = []
@staticmethod
def setup(state, config):
return state
@staticmethod
def new_node_attrs(state):
return jnp.zeros(0)
@staticmethod
def new_conn_attrs(state):
return jnp.zeros(0)
@staticmethod
def mutate_node(state, attrs: Array, key):
return attrs
@staticmethod
def mutate_conn(state, attrs: Array, key):
return attrs
@staticmethod
def distance_node(state, array: Array):
return array
@staticmethod
def distance_conn(state, array: Array):
return array
@staticmethod
def forward(state, array: Array):
return array

View File

@@ -0,0 +1,40 @@
from jax import Array, numpy as jnp
from . import BaseGene
class NormalGene(BaseGene):
node_attrs = ['bias', 'response', 'aggregation', 'activation']
conn_attrs = ['weight']
@staticmethod
def setup(state, config):
return state
@staticmethod
def new_node_attrs(state):
return jnp.array([0, 0, 0, 0])
@staticmethod
def new_conn_attrs(state):
return jnp.array([0])
@staticmethod
def mutate_node(state, attrs: Array, key):
return attrs
@staticmethod
def mutate_conn(state, attrs: Array, key):
return attrs
@staticmethod
def distance_node(state, array: Array):
return array
@staticmethod
def distance_conn(state, array: Array):
return array
@staticmethod
def forward(state, array: Array):
return array