complete normal neat algorithm
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
import jax
|
||||
from jax import Array, numpy as jnp
|
||||
|
||||
from . import BaseGene
|
||||
from .base import BaseGene
|
||||
from .activation import Activation
|
||||
from .aggregation import Aggregation
|
||||
from ..utils import unflatten_connections, I_INT
|
||||
from ..genome import topological_sort
|
||||
|
||||
|
||||
class NormalGene(BaseGene):
|
||||
@@ -70,18 +74,116 @@ class NormalGene(BaseGene):
|
||||
return jnp.array([weight])
|
||||
|
||||
@staticmethod
|
||||
def distance_node(state, array1: Array, array2: Array):
|
||||
def distance_node(state, node1: Array, node2: Array):
|
||||
# bias + response + activation + aggregation
|
||||
return jnp.abs(array1[1] - array2[1]) + jnp.abs(array1[2] - array2[2]) + \
|
||||
(array1[3] != array2[3]) + (array1[4] != array2[4])
|
||||
return jnp.abs(node1[1] - node2[1]) + jnp.abs(node1[2] - node2[2]) + \
|
||||
(node1[3] != node2[3]) + (node1[4] != node2[4])
|
||||
|
||||
@staticmethod
|
||||
def distance_conn(state, array1: Array, array2: Array):
|
||||
return (array1[2] != array2[2]) + jnp.abs(array1[3] - array2[3]) # enable + weight
|
||||
def distance_conn(state, con1: Array, con2: Array):
|
||||
return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight
|
||||
|
||||
@staticmethod
|
||||
def forward(state, array: Array):
|
||||
return array
|
||||
def forward_transform(nodes, conns):
|
||||
u_conns = unflatten_connections(nodes, conns)
|
||||
u_conns = jnp.where(jnp.isnan(u_conns[0, :]), jnp.nan, u_conns) # enable is false, then the connections is nan
|
||||
u_conns = u_conns[1:, :] # remove enable attr
|
||||
conn_exist = jnp.any(~jnp.isnan(u_conns), axis=0)
|
||||
seqs = topological_sort(nodes, conn_exist)
|
||||
return seqs, nodes, u_conns
|
||||
|
||||
@staticmethod
|
||||
def create_forward(config):
|
||||
config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']]
|
||||
config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']]
|
||||
|
||||
def act(idx, z):
|
||||
"""
|
||||
calculate activation function for each node
|
||||
"""
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
# change idx from float to int
|
||||
res = jax.lax.switch(idx, config['activation_funcs'], z)
|
||||
return res
|
||||
|
||||
def agg(idx, z):
|
||||
"""
|
||||
calculate activation function for inputs of node
|
||||
"""
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
|
||||
def all_nan():
|
||||
return 0.
|
||||
|
||||
def not_all_nan():
|
||||
return jax.lax.switch(idx, config['aggregation_funcs'], z)
|
||||
|
||||
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
|
||||
|
||||
def forward(inputs, transform) -> Array:
|
||||
"""
|
||||
jax forward for single input shaped (input_num, )
|
||||
nodes, connections are a single genome
|
||||
|
||||
:argument inputs: (input_num, )
|
||||
:argument cal_seqs: (N, )
|
||||
:argument nodes: (N, 5)
|
||||
:argument connections: (2, N, N)
|
||||
|
||||
:return (output_num, )
|
||||
"""
|
||||
|
||||
cal_seqs, nodes, cons = transform
|
||||
|
||||
input_idx = config['input_idx']
|
||||
output_idx = config['output_idx']
|
||||
|
||||
N = nodes.shape[0]
|
||||
ini_vals = jnp.full((N,), jnp.nan)
|
||||
ini_vals = ini_vals.at[input_idx].set(inputs)
|
||||
|
||||
weights = cons[0, :]
|
||||
|
||||
def cond_fun(carry):
|
||||
values, idx = carry
|
||||
return (idx < N) & (cal_seqs[idx] != I_INT)
|
||||
|
||||
def body_func(carry):
|
||||
values, idx = carry
|
||||
i = cal_seqs[idx]
|
||||
|
||||
def hit():
|
||||
ins = values * weights[:, i]
|
||||
z = agg(nodes[i, 4], ins) # z = agg(ins)
|
||||
z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
|
||||
z = act(nodes[i, 3], z) # z = act(z)
|
||||
|
||||
new_values = values.at[i].set(z)
|
||||
return new_values
|
||||
|
||||
def miss():
|
||||
return values
|
||||
|
||||
# the val of input nodes is obtained by the task, not by calculation
|
||||
values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
|
||||
|
||||
# if jnp.isin(i, input_idx):
|
||||
# values = miss()
|
||||
# else:
|
||||
# values = hit()
|
||||
|
||||
return values, idx + 1
|
||||
|
||||
# carry = (ini_vals, 0)
|
||||
# while cond_fun(carry):
|
||||
# carry = body_func(carry)
|
||||
# vals, _ = carry
|
||||
|
||||
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
|
||||
|
||||
return vals[output_idx]
|
||||
|
||||
return forward
|
||||
|
||||
@staticmethod
|
||||
def _mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate):
|
||||
@@ -114,3 +216,7 @@ class NormalGene(BaseGene):
|
||||
)
|
||||
|
||||
return val
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user