Complete the normal NEAT algorithm

wls2002
2023-07-18 23:55:36 +08:00
parent 40cf0b6fbe
commit 0a2a9fd1be
26 changed files with 880 additions and 251 deletions

View File

@@ -1,2 +1,5 @@
 from .base import BaseGene
 from .normal import NormalGene
+from .activation import Activation
+from .aggregation import Aggregation

View File

@@ -3,6 +3,8 @@ import jax.numpy as jnp
 class Activation:
+    name2func = {}
+
     @staticmethod
     def sigmoid_act(z):
         z = jnp.clip(z * 5, -60, 60)
@@ -86,23 +88,23 @@ class Activation:
     def cube_act(z):
         return z ** 3

-    name2func = {
-        'sigmoid': sigmoid_act,
-        'tanh': tanh_act,
-        'sin': sin_act,
-        'gauss': gauss_act,
-        'relu': relu_act,
-        'elu': elu_act,
-        'lelu': lelu_act,
-        'selu': selu_act,
-        'softplus': softplus_act,
-        'identity': identity_act,
-        'clamped': clamped_act,
-        'inv': inv_act,
-        'log': log_act,
-        'exp': exp_act,
-        'abs': abs_act,
-        'hat': hat_act,
-        'square': square_act,
-        'cube': cube_act,
-    }
+Activation.name2func = {
+    'sigmoid': Activation.sigmoid_act,
+    'tanh': Activation.tanh_act,
+    'sin': Activation.sin_act,
+    'gauss': Activation.gauss_act,
+    'relu': Activation.relu_act,
+    'elu': Activation.elu_act,
+    'lelu': Activation.lelu_act,
+    'selu': Activation.selu_act,
+    'softplus': Activation.softplus_act,
+    'identity': Activation.identity_act,
+    'clamped': Activation.clamped_act,
+    'inv': Activation.inv_act,
+    'log': Activation.log_act,
+    'exp': Activation.exp_act,
+    'abs': Activation.abs_act,
+    'hat': Activation.hat_act,
+    'square': Activation.square_act,
+    'cube': Activation.cube_act,
+}
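
The likely motivation for binding `name2func` after the class body rather than inside it: inside the body, a name like `sigmoid_act` refers to a `staticmethod` descriptor, which is not callable as a plain function before Python 3.10, so storing it in the dict would break later dispatch. A minimal sketch of the difference, using a hypothetical `Toy` class:

```python
import jax.numpy as jnp

class Toy:
    @staticmethod
    def double(z):
        return z * 2
    # Inside the body, `double` is a staticmethod descriptor; before
    # Python 3.10, {'double': double} would store an uncallable object.

# Attribute lookup after the class body returns the plain function:
Toy.name2func = {'double': Toy.double}

print(Toy.name2func['double'](jnp.array(3.0)))  # 6.0
```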

View File

@@ -3,6 +3,8 @@ import jax.numpy as jnp
 class Aggregation:
+    name2func = {}
+
     @staticmethod
     def sum_agg(z):
         z = jnp.where(jnp.isnan(z), 0, z)
@@ -49,12 +51,13 @@ class Aggregation:
         mean_without_zeros = valid_values_sum / valid_values_count
         return mean_without_zeros

-    name2func = {
-        'sum': sum_agg,
-        'product': product_agg,
-        'max': max_agg,
-        'min': min_agg,
-        'maxabs': maxabs_agg,
-        'median': median_agg,
-        'mean': mean_agg,
-    }
+Aggregation.name2func = {
+    'sum': Aggregation.sum_agg,
+    'product': Aggregation.product_agg,
+    'max': Aggregation.max_agg,
+    'min': Aggregation.min_agg,
+    'maxabs': Aggregation.maxabs_agg,
+    'median': Aggregation.median_agg,
+    'mean': Aggregation.mean_agg,
+}
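
Later in this commit (`create_forward` below), these tables are turned into ordered function lists and dispatched with `jax.lax.switch`. A self-contained sketch of that pattern, with simplified stand-ins for the aggregation functions:

```python
import jax
import jax.numpy as jnp

# Simplified stand-ins for two entries of Aggregation.name2func:
def sum_agg(z):
    return jnp.sum(jnp.where(jnp.isnan(z), 0.0, z))

def max_agg(z):
    return jnp.max(jnp.where(jnp.isnan(z), -jnp.inf, z))

agg_funcs = [sum_agg, max_agg]  # list order fixes each function's integer index

def aggregate(idx, z):
    # node attributes are stored as floats, so cast the index before dispatch
    return jax.lax.switch(jnp.asarray(idx, dtype=jnp.int32), agg_funcs, z)

z = jnp.array([1.0, jnp.nan, 3.0])
print(aggregate(0.0, z))  # sum_agg -> 4.0
print(aggregate(1.0, z))  # max_agg -> 3.0
```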

View File

@@ -1,4 +1,4 @@
-from jax import Array, numpy as jnp
+from jax import Array, numpy as jnp, vmap

 class BaseGene:
@@ -26,13 +26,19 @@ class BaseGene:
         return attrs

     @staticmethod
-    def distance_node(state, array1: Array, array2: Array):
-        return array1
+    def distance_node(state, node1: Array, node2: Array):
+        return node1

     @staticmethod
-    def distance_conn(state, array1: Array, array2: Array):
-        return array1
+    def distance_conn(state, conn1: Array, conn2: Array):
+        return conn1

     @staticmethod
-    def forward(state, array: Array):
-        return array
+    def forward_transform(nodes, conns):
+        return nodes, conns
+
+    @staticmethod
+    def create_forward(config):
+        return None
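
This interface change splits inference in two: `create_forward` builds the network function once from `config`, while `forward_transform` does per-genome precomputation whose result is fed to that function. A minimal sketch of a custom gene following the new contract (`IdentityGene` is hypothetical and assumes the `BaseGene` above):

```python
class IdentityGene(BaseGene):
    @staticmethod
    def forward_transform(nodes, conns):
        # Per-genome precomputation: anything that depends on the genome
        # but not on the network inputs belongs here.
        return nodes, conns

    @staticmethod
    def create_forward(config):
        # Built once from config; the returned closure is what later gets
        # jitted / vmapped over inputs and genomes.
        def forward(inputs, transform):
            nodes, conns = transform
            return inputs  # trivial pass-through, ignores the genome

        return forward
```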

View File

@@ -1,7 +1,11 @@
 import jax
 from jax import Array, numpy as jnp

-from . import BaseGene
+from .base import BaseGene
+from .activation import Activation
+from .aggregation import Aggregation
+from ..utils import unflatten_connections, I_INT
+from ..genome import topological_sort

 class NormalGene(BaseGene):
@@ -70,18 +74,116 @@ class NormalGene(BaseGene):
         return jnp.array([weight])

     @staticmethod
-    def distance_node(state, array1: Array, array2: Array):
+    def distance_node(state, node1: Array, node2: Array):
         # bias + response + activation + aggregation
-        return jnp.abs(array1[1] - array2[1]) + jnp.abs(array1[2] - array2[2]) + \
-               (array1[3] != array2[3]) + (array1[4] != array2[4])
+        return jnp.abs(node1[1] - node2[1]) + jnp.abs(node1[2] - node2[2]) + \
+               (node1[3] != node2[3]) + (node1[4] != node2[4])

     @staticmethod
-    def distance_conn(state, array1: Array, array2: Array):
-        return (array1[2] != array2[2]) + jnp.abs(array1[3] - array2[3])  # enable + weight
+    def distance_conn(state, con1: Array, con2: Array):
+        return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3])  # enable + weight

     @staticmethod
-    def forward(state, array: Array):
-        return array
+    def forward_transform(nodes, conns):
+        u_conns = unflatten_connections(nodes, conns)
+        # where the enable attribute is nan, the connection is disabled:
+        # mask all of its attributes to nan
+        u_conns = jnp.where(jnp.isnan(u_conns[0, :]), jnp.nan, u_conns)
+        u_conns = u_conns[1:, :]  # drop the enable attribute
+        conn_exist = jnp.any(~jnp.isnan(u_conns), axis=0)
+        seqs = topological_sort(nodes, conn_exist)
+        return seqs, nodes, u_conns
+    @staticmethod
+    def create_forward(config):
+        config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']]
+        config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']]
+
+        def act(idx, z):
+            """Apply the activation function selected by idx to z."""
+            idx = jnp.asarray(idx, dtype=jnp.int32)  # idx is stored as a float attribute
+            return jax.lax.switch(idx, config['activation_funcs'], z)
+
+        def agg(idx, z):
+            """Aggregate the inputs z of a node with the function selected by idx."""
+            idx = jnp.asarray(idx, dtype=jnp.int32)
+
+            def all_nan():  # no incoming values at all
+                return 0.
+
+            def not_all_nan():
+                return jax.lax.switch(idx, config['aggregation_funcs'], z)
+
+            return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
+
+        def forward(inputs, transform) -> Array:
+            """
+            JAX forward pass for a single input shaped (input_num, );
+            transform describes a single genome.
+            :param inputs: (input_num, )
+            :param transform: (cal_seqs, nodes, connections) with shapes
+                (N, ), (N, 5), (2, N, N)
+            :return: (output_num, )
+            """
+            cal_seqs, nodes, cons = transform
+            input_idx = config['input_idx']
+            output_idx = config['output_idx']
+            N = nodes.shape[0]
+
+            ini_vals = jnp.full((N,), jnp.nan)
+            ini_vals = ini_vals.at[input_idx].set(inputs)
+            weights = cons[0, :]
+
+            def cond_fun(carry):
+                values, idx = carry
+                return (idx < N) & (cal_seqs[idx] != I_INT)
+
+            def body_func(carry):
+                values, idx = carry
+                i = cal_seqs[idx]
+
+                def hit():
+                    ins = values * weights[:, i]
+                    z = agg(nodes[i, 4], ins)          # z = agg(ins)
+                    z = z * nodes[i, 2] + nodes[i, 1]  # z = z * response + bias
+                    z = act(nodes[i, 3], z)            # z = act(z)
+                    return values.at[i].set(z)
+
+                def miss():
+                    return values
+
+                # input nodes get their values from the task, not from calculation:
+                # values = miss() if i is an input node else hit()
+                values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
+                return values, idx + 1
+
+            # equivalent python loop:
+            #     carry = (ini_vals, 0)
+            #     while cond_fun(carry):
+            #         carry = body_func(carry)
+            #     vals, _ = carry
+            vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
+            return vals[output_idx]
+
+        return forward
     @staticmethod
     def _mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate):

@@ -114,3 +216,7 @@ class NormalGene(BaseGene):
         )
         return val
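
base.py now also imports `vmap`, which fits batching the closure returned by `create_forward` over inputs and over genomes. A self-contained sketch of that pattern, with a toy stand-in for `forward` (shapes and names are illustrative, not this library's API):

```python
from jax import vmap, numpy as jnp

# Toy stand-in for the closure returned by create_forward; the real
# transform is (cal_seqs, nodes, u_conns) from forward_transform.
def forward(inputs, transform):
    weights = transform
    return weights @ inputs

weights = jnp.ones((2, 3))       # one "genome"
inputs_batch = jnp.ones((8, 3))  # batch of 8 inputs

batched = vmap(forward, in_axes=(0, None))           # map over inputs
print(batched(inputs_batch, weights).shape)          # (8, 2)

pop_weights = jnp.ones((5, 2, 3))                    # population of 5 genomes
pop_batched = vmap(batched, in_axes=(None, 0))       # then over genomes
print(pop_batched(inputs_batch, pop_weights).shape)  # (5, 8, 2)
```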