Refactor gene package: replace dict-based config with a frozen NormalGeneConfig dataclass, move Activation/Aggregation/topological_sort helpers into utils, and remove the old BaseGene/RecurrentGene modules.
This commit is contained in:
@@ -1,6 +1 @@
|
||||
from .base import BaseGene
|
||||
from .normal import NormalGene
|
||||
from .activation import Activation
|
||||
from .aggregation import Aggregation
|
||||
from .recurrent import RecurrentGene
|
||||
|
||||
from .normal import NormalGene, NormalGeneConfig
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
class Activation:
    """Registry of JAX-compatible activation functions for gene nodes.

    Every activation is a ``@staticmethod`` that maps an array ``z`` to an
    array of the same shape, so the functions can be stored in ``name2func``
    and dispatched by name (or by index via ``jax.lax.switch``).  Several of
    them pre-scale and/or clip ``z`` so that ``exp``-based math stays
    numerically stable.
    """

    # Populated right after the class body (see below); maps an activation
    # name such as 'sigmoid' to its implementing function.
    name2func = {}

    @staticmethod
    def sigmoid_act(z):
        # Steepened sigmoid: scale by 5, clip to avoid overflow in exp.
        return 1 / (1 + jnp.exp(-jnp.clip(z * 5, -60, 60)))

    @staticmethod
    def tanh_act(z):
        return jnp.tanh(jnp.clip(z * 2.5, -60, 60))

    @staticmethod
    def sin_act(z):
        return jnp.sin(jnp.clip(z * 5, -60, 60))

    @staticmethod
    def gauss_act(z):
        # Narrow clip range: exp(-3.4**2) is already ~1e-5.
        return jnp.exp(-jnp.clip(z * 5, -3.4, 3.4) ** 2)

    @staticmethod
    def relu_act(z):
        return jnp.maximum(z, 0)

    @staticmethod
    def elu_act(z):
        return jnp.where(z > 0, z, jnp.exp(z) - 1)

    @staticmethod
    def lelu_act(z):
        negative_slope = 0.005
        return jnp.where(z > 0, z, negative_slope * z)

    @staticmethod
    def selu_act(z):
        # Standard SELU constants (Klambauer et al., 2017).
        scale = 1.0507009873554804934193349852946
        alpha = 1.6732632423543772848170429916717
        return jnp.where(z > 0, scale * z, scale * alpha * (jnp.exp(z) - 1))

    @staticmethod
    def softplus_act(z):
        return 0.2 * jnp.log(1 + jnp.exp(jnp.clip(z * 5, -60, 60)))

    @staticmethod
    def identity_act(z):
        return z

    @staticmethod
    def clamped_act(z):
        return jnp.clip(z, -1, 1)

    @staticmethod
    def inv_act(z):
        # Clamp away from zero so the reciprocal cannot produce inf.
        return 1 / jnp.maximum(z, 1e-7)

    @staticmethod
    def log_act(z):
        # Clamp so log never sees a non-positive argument.
        return jnp.log(jnp.maximum(z, 1e-7))

    @staticmethod
    def exp_act(z):
        return jnp.exp(jnp.clip(z, -60, 60))

    @staticmethod
    def abs_act(z):
        return jnp.abs(z)

    @staticmethod
    def hat_act(z):
        # Triangular "hat": 1 at z=0, linearly down to 0 at |z|=1.
        return jnp.maximum(0, 1 - jnp.abs(z))

    @staticmethod
    def square_act(z):
        return z ** 2

    @staticmethod
    def cube_act(z):
        return z ** 3


# Assigned outside the class body so the entries are resolved, callable
# functions rather than staticmethod descriptors.
Activation.name2func = {
    'sigmoid': Activation.sigmoid_act,
    'tanh': Activation.tanh_act,
    'sin': Activation.sin_act,
    'gauss': Activation.gauss_act,
    'relu': Activation.relu_act,
    'elu': Activation.elu_act,
    'lelu': Activation.lelu_act,
    'selu': Activation.selu_act,
    'softplus': Activation.softplus_act,
    'identity': Activation.identity_act,
    'clamped': Activation.clamped_act,
    'inv': Activation.inv_act,
    'log': Activation.log_act,
    'exp': Activation.exp_act,
    'abs': Activation.abs_act,
    'hat': Activation.hat_act,
    'square': Activation.square_act,
    'cube': Activation.cube_act,
}
|
||||
@@ -1,63 +0,0 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
class Aggregation:
    """JAX-compatible aggregation functions for combining a node's inputs.

    Each function is a ``@staticmethod`` that reduces its input ``z`` along
    axis 0, treating NaN entries as "missing": they are replaced by the
    reduction's identity element (or skipped, for median/mean).  Functions
    are looked up by name through ``name2func``.

    NOTE(review): ``maxabs_agg`` uses a flat ``argmax`` and ``median_agg``
    sorts along the last axis, so both assume ``z`` is 1-D — consistent
    with how they are called (``vmap`` over per-node input vectors);
    confirm before using on higher-rank input.
    """

    # Populated right after the class body; maps names to functions.
    name2func = {}

    @staticmethod
    def sum_agg(z):
        # NaN -> 0, the identity of addition.
        return jnp.sum(jnp.where(jnp.isnan(z), 0, z), axis=0)

    @staticmethod
    def product_agg(z):
        # NaN -> 1, the identity of multiplication.
        return jnp.prod(jnp.where(jnp.isnan(z), 1, z), axis=0)

    @staticmethod
    def max_agg(z):
        # NaN -> -inf so missing values can never win the max.
        return jnp.max(jnp.where(jnp.isnan(z), -jnp.inf, z), axis=0)

    @staticmethod
    def min_agg(z):
        # NaN -> +inf so missing values can never win the min.
        return jnp.min(jnp.where(jnp.isnan(z), jnp.inf, z), axis=0)

    @staticmethod
    def maxabs_agg(z):
        # Return the signed element whose absolute value is largest.
        cleaned = jnp.where(jnp.isnan(z), 0, z)
        return cleaned[jnp.argmax(jnp.abs(cleaned))]

    @staticmethod
    def median_agg(z):
        # Median over the non-NaN entries: jnp.sort pushes NaNs to the
        # end, so the first ``valid`` sorted values are the real ones.
        valid = jnp.sum(~jnp.isnan(z), axis=0)
        ordered = jnp.sort(z)
        lo, hi = (valid - 1) // 2, valid // 2
        return (ordered[lo] + ordered[hi]) / 2

    @staticmethod
    def mean_agg(z):
        # Mean over the non-NaN entries only.  With zero valid entries
        # this divides by zero; callers guard the all-NaN case upstream
        # with jax.lax.cond.
        total = jnp.sum(jnp.where(jnp.isnan(z), 0, z), axis=0)
        count = jnp.sum(~jnp.isnan(z), axis=0)
        return total / count


# Assigned after the class body so each entry is a plain function.
Aggregation.name2func = {
    'sum': Aggregation.sum_agg,
    'product': Aggregation.product_agg,
    'max': Aggregation.max_agg,
    'min': Aggregation.min_agg,
    'maxabs': Aggregation.maxabs_agg,
    'median': Aggregation.median_agg,
    'mean': Aggregation.mean_agg,
}
|
||||
@@ -1,42 +0,0 @@
|
||||
from jax import Array, numpy as jnp, vmap
|
||||
|
||||
|
||||
class BaseGene:
    """Abstract interface for gene types.

    Subclasses declare which attributes live on nodes and connections and
    override the static hooks below.  The defaults here are deliberate
    no-ops / placeholders so a minimal gene type works out of the box.
    """

    # Names of the per-node attributes a subclass stores (order matters).
    node_attrs = []
    # Names of the per-connection attributes a subclass stores.
    conn_attrs = []

    @staticmethod
    def setup(state, config):
        """One-time initialization hook; returns the (possibly updated) state."""
        return state

    @staticmethod
    def new_node_attrs(state):
        """Initial attribute values for a freshly created node (none by default)."""
        return jnp.zeros(0)

    @staticmethod
    def new_conn_attrs(state):
        """Initial attribute values for a freshly created connection (none by default)."""
        return jnp.zeros(0)

    @staticmethod
    def mutate_node(state, attrs: Array, key):
        """Mutate one node's attribute vector; identity by default."""
        return attrs

    @staticmethod
    def mutate_conn(state, attrs: Array, key):
        """Mutate one connection's attribute vector; identity by default."""
        return attrs

    @staticmethod
    def distance_node(state, node1: Array, node2: Array):
        """Distance contribution between two nodes (placeholder default)."""
        return node1

    @staticmethod
    def distance_conn(state, conn1: Array, conn2: Array):
        """Distance contribution between two connections (placeholder default)."""
        return conn1

    @staticmethod
    def forward_transform(state, nodes, conns):
        """Pre-process genome arrays into whatever ``forward`` consumes."""
        return nodes, conns

    @staticmethod
    def create_forward(config):
        """Build and return the forward function for this gene type."""
        return None
|
||||
@@ -1,45 +1,100 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import jax
|
||||
from jax import Array, numpy as jnp
|
||||
|
||||
from .base import BaseGene
|
||||
from .activation import Activation
|
||||
from .aggregation import Aggregation
|
||||
from algorithm.utils import unflatten_connections, I_INT
|
||||
from ..genome import topological_sort
|
||||
from config import GeneConfig
|
||||
from core import Gene, Genome, State
|
||||
from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT
|
||||
|
||||
|
||||
class NormalGene(BaseGene):
|
||||
@dataclass(frozen=True)
class NormalGeneConfig(GeneConfig):
    """Hyper-parameters for ``NormalGene``.

    Groups the init/mutate/replace settings for the bias, response and
    weight attributes, plus the available activation/aggregation options.
    ``__post_init__`` validates ranges and that every named function
    exists in the registries.
    """

    bias_init_mean: float = 0.0
    bias_init_std: float = 1.0
    bias_mutate_power: float = 0.5
    bias_mutate_rate: float = 0.7
    bias_replace_rate: float = 0.1

    response_init_mean: float = 1.0
    response_init_std: float = 0.0
    response_mutate_power: float = 0.5
    response_mutate_rate: float = 0.7
    response_replace_rate: float = 0.1

    activation_default: str = 'sigmoid'
    # FIX: was ``Tuple[str]`` — that annotates a tuple of exactly one str;
    # ``Tuple[str, ...]`` is the variable-length tuple these fields hold.
    activation_options: Tuple[str, ...] = ('sigmoid',)
    activation_replace_rate: float = 0.1

    aggregation_default: str = 'sum'
    aggregation_options: Tuple[str, ...] = ('sum',)
    aggregation_replace_rate: float = 0.1

    weight_init_mean: float = 0.0
    weight_init_std: float = 1.0
    weight_mutate_power: float = 0.5
    weight_mutate_rate: float = 0.8
    weight_replace_rate: float = 0.1

    def __post_init__(self):
        # NOTE(review): assert-based validation is skipped under
        # ``python -O``; kept as asserts for consistency with the rest of
        # the codebase.
        assert self.bias_init_std >= 0.0
        assert self.bias_mutate_power >= 0.0
        assert self.bias_mutate_rate >= 0.0
        assert self.bias_replace_rate >= 0.0

        assert self.response_init_std >= 0.0
        assert self.response_mutate_power >= 0.0
        assert self.response_mutate_rate >= 0.0
        assert self.response_replace_rate >= 0.0

        # The default must come first so index 0 in the compiled gene
        # state refers to it.
        assert self.activation_default == self.activation_options[0]

        # FIX: added the explicit default-membership check to mirror the
        # aggregation branch below (it can never fail when the loop check
        # passes, so behavior for valid configs is unchanged).
        assert self.activation_default in Activation.name2func, \
            f"Activation function: {self.activation_default} not found"

        for name in self.activation_options:
            assert name in Activation.name2func, f"Activation function: {name} not found"

        assert self.aggregation_default == self.aggregation_options[0]

        assert self.aggregation_default in Aggregation.name2func, \
            f"Aggregation function: {self.aggregation_default} not found"

        for name in self.aggregation_options:
            assert name in Aggregation.name2func, f"Aggregation function: {name} not found"
|
||||
|
||||
|
||||
class NormalGene(Gene):
|
||||
node_attrs = ['bias', 'response', 'aggregation', 'activation']
|
||||
conn_attrs = ['weight']
|
||||
|
||||
@staticmethod
|
||||
def setup(state, config):
|
||||
def setup(config: NormalGeneConfig, state: State = State()):
|
||||
|
||||
return state.update(
|
||||
bias_init_mean=config['bias_init_mean'],
|
||||
bias_init_std=config['bias_init_std'],
|
||||
bias_mutate_power=config['bias_mutate_power'],
|
||||
bias_mutate_rate=config['bias_mutate_rate'],
|
||||
bias_replace_rate=config['bias_replace_rate'],
|
||||
bias_init_mean=config.bias_init_mean,
|
||||
bias_init_std=config.bias_init_std,
|
||||
bias_mutate_power=config.bias_mutate_power,
|
||||
bias_mutate_rate=config.bias_mutate_rate,
|
||||
bias_replace_rate=config.bias_replace_rate,
|
||||
|
||||
response_init_mean=config['response_init_mean'],
|
||||
response_init_std=config['response_init_std'],
|
||||
response_mutate_power=config['response_mutate_power'],
|
||||
response_mutate_rate=config['response_mutate_rate'],
|
||||
response_replace_rate=config['response_replace_rate'],
|
||||
response_init_mean=config.response_init_mean,
|
||||
response_init_std=config.response_init_std,
|
||||
response_mutate_power=config.response_mutate_power,
|
||||
response_mutate_rate=config.response_mutate_rate,
|
||||
response_replace_rate=config.response_replace_rate,
|
||||
|
||||
activation_default=config['activation_default'],
|
||||
activation_options=config['activation_options'],
|
||||
activation_replace_rate=config['activation_replace_rate'],
|
||||
activation_replace_rate=config.activation_replace_rate,
|
||||
activation_default=0,
|
||||
activation_options=jnp.arange(len(config.activation_options)),
|
||||
|
||||
aggregation_default=config['aggregation_default'],
|
||||
aggregation_options=config['aggregation_options'],
|
||||
aggregation_replace_rate=config['aggregation_replace_rate'],
|
||||
aggregation_replace_rate=config.aggregation_replace_rate,
|
||||
aggregation_default=0,
|
||||
aggregation_options=jnp.arange(len(config.aggregation_options)),
|
||||
|
||||
weight_init_mean=config['weight_init_mean'],
|
||||
weight_init_std=config['weight_init_std'],
|
||||
weight_mutate_power=config['weight_mutate_power'],
|
||||
weight_mutate_rate=config['weight_mutate_rate'],
|
||||
weight_replace_rate=config['weight_replace_rate'],
|
||||
weight_init_mean=config.weight_init_mean,
|
||||
weight_init_std=config.weight_init_std,
|
||||
weight_mutate_power=config.weight_mutate_power,
|
||||
weight_mutate_rate=config.weight_mutate_rate,
|
||||
weight_replace_rate=config.weight_replace_rate,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -84,20 +139,20 @@ class NormalGene(BaseGene):
|
||||
return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight
|
||||
|
||||
@staticmethod
|
||||
def forward_transform(state, nodes, conns):
|
||||
u_conns = unflatten_connections(nodes, conns)
|
||||
def forward_transform(state: State, genome: Genome):
|
||||
u_conns = unflatten_conns(genome.nodes, genome.conns)
|
||||
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
||||
|
||||
# remove enable attr
|
||||
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||
seqs = topological_sort(nodes, conn_enable)
|
||||
seqs = topological_sort(genome.nodes, conn_enable)
|
||||
|
||||
return seqs, nodes, u_conns
|
||||
return seqs, genome.nodes, u_conns
|
||||
|
||||
@staticmethod
|
||||
def create_forward(config):
|
||||
config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']]
|
||||
config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']]
|
||||
def create_forward(state: State, config: NormalGeneConfig):
|
||||
activation_funcs = [Activation.name2func[name] for name in config.activation_options]
|
||||
aggregation_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
|
||||
|
||||
def act(idx, z):
|
||||
"""
|
||||
@@ -105,7 +160,7 @@ class NormalGene(BaseGene):
|
||||
"""
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
# change idx from float to int
|
||||
res = jax.lax.switch(idx, config['activation_funcs'], z)
|
||||
res = jax.lax.switch(idx, activation_funcs, z)
|
||||
return res
|
||||
|
||||
def agg(idx, z):
|
||||
@@ -118,14 +173,13 @@ class NormalGene(BaseGene):
|
||||
return 0.
|
||||
|
||||
def not_all_nan():
|
||||
return jax.lax.switch(idx, config['aggregation_funcs'], z)
|
||||
return jax.lax.switch(idx, aggregation_funcs, z)
|
||||
|
||||
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
|
||||
|
||||
def forward(inputs, transform) -> Array:
|
||||
def forward(inputs, transformed) -> Array:
|
||||
"""
|
||||
jax forward for single input shaped (input_num, )
|
||||
nodes, connections are a single genome
|
||||
forward for single input shaped (input_num, )
|
||||
|
||||
:argument inputs: (input_num, )
|
||||
:argument cal_seqs: (N, )
|
||||
@@ -135,10 +189,10 @@ class NormalGene(BaseGene):
|
||||
:return (output_num, )
|
||||
"""
|
||||
|
||||
cal_seqs, nodes, cons = transform
|
||||
cal_seqs, nodes, cons = transformed
|
||||
|
||||
input_idx = config['input_idx']
|
||||
output_idx = config['output_idx']
|
||||
input_idx = state.input_idx
|
||||
output_idx = state.output_idx
|
||||
|
||||
N = nodes.shape[0]
|
||||
ini_vals = jnp.full((N,), jnp.nan)
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
import jax
|
||||
from jax import Array, numpy as jnp, vmap
|
||||
|
||||
from .normal import NormalGene
|
||||
from .activation import Activation
|
||||
from .aggregation import Aggregation
|
||||
from algorithm.utils import unflatten_connections
|
||||
|
||||
|
||||
class RecurrentGene(NormalGene):
    """Gene type whose network is evaluated recurrently.

    Unlike ``NormalGene`` there is no topological ordering: the whole node
    vector is updated ``config['activate_times']`` times in lockstep, so
    cycles in the connection graph are allowed.
    """

    @staticmethod
    def forward_transform(state, nodes, conns):
        # Expand the flat connection list into dense per-attribute
        # matrices (docstring below says shape (2, N, N): enable + weight).
        u_conns = unflatten_connections(nodes, conns)

        # remove un-enable connections and remove enable attr
        conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
        u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)

        return nodes, u_conns

    @staticmethod
    def create_forward(config):
        # Resolve option names to callables once, up front, so
        # jax.lax.switch can dispatch on an integer index in jitted code.
        config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']]
        config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']]

        def act(idx, z):
            """
            calculate activation function for each node
            """
            idx = jnp.asarray(idx, dtype=jnp.int32)
            # change idx from float to int
            res = jax.lax.switch(idx, config['activation_funcs'], z)
            return res

        def agg(idx, z):
            """
            calculate activation function for inputs of node
            """
            idx = jnp.asarray(idx, dtype=jnp.int32)

            def all_nan():
                # A node with no enabled incoming connections aggregates
                # to 0 rather than propagating NaN.
                return 0.

            def not_all_nan():
                return jax.lax.switch(idx, config['aggregation_funcs'], z)

            return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)

        # Vectorize over all nodes of the genome at once.
        batch_act, batch_agg = vmap(act), vmap(agg)

        def forward(inputs, transform) -> Array:
            """
            jax forward for single input shaped (input_num, )
            nodes, connections are a single genome

            :argument inputs: (input_num, )
            :argument cal_seqs: (N, )
            :argument nodes: (N, 5)
            :argument connections: (2, N, N)

            :return (output_num, )
            """

            nodes, cons = transform

            input_idx = config['input_idx']
            output_idx = config['output_idx']

            N = nodes.shape[0]
            # All node values start at 0 (not NaN as in NormalGene) so the
            # first recurrent step has well-defined inputs.
            vals = jnp.full((N,), 0.)

            weights = cons[0, :]

            def body_func(i, values):
                # Re-clamp the input nodes every step, then perform one
                # synchronous update of all node values:
                values = values.at[input_idx].set(inputs)
                nodes_ins = values * weights.T
                values = batch_agg(nodes[:, 4], nodes_ins)  # z = agg(ins)
                values = values * nodes[:, 2] + nodes[:, 1]  # z = z * response + bias
                values = batch_act(nodes[:, 3], values)  # z = act(z)
                return values

            # Run the recurrent update a fixed number of times; fori_loop
            # keeps the loop inside the jit-compiled graph.
            vals = jax.lax.fori_loop(0, config['activate_times'], body_func, vals)
            return vals[output_idx]

        return forward
|
||||
Reference in New Issue
Block a user