change a lot a lot a lot!!!!!!!

wls2002
2023-07-24 02:16:02 +08:00
parent 48f90c7eef
commit ac295c1921
49 changed files with 1138 additions and 1460 deletions


@@ -1,6 +1 @@
-from .base import BaseGene
-from .normal import NormalGene
-from .activation import Activation
-from .aggregation import Aggregation
-from .recurrent import RecurrentGene
+from .normal import NormalGene, NormalGeneConfig


@@ -1,110 +0,0 @@
-import jax.numpy as jnp
-
-
-class Activation:
-    name2func = {}
-
-    @staticmethod
-    def sigmoid_act(z):
-        z = jnp.clip(z * 5, -60, 60)
-        return 1 / (1 + jnp.exp(-z))
-
-    @staticmethod
-    def tanh_act(z):
-        z = jnp.clip(z * 2.5, -60, 60)
-        return jnp.tanh(z)
-
-    @staticmethod
-    def sin_act(z):
-        z = jnp.clip(z * 5, -60, 60)
-        return jnp.sin(z)
-
-    @staticmethod
-    def gauss_act(z):
-        z = jnp.clip(z * 5, -3.4, 3.4)
-        return jnp.exp(-z ** 2)
-
-    @staticmethod
-    def relu_act(z):
-        return jnp.maximum(z, 0)
-
-    @staticmethod
-    def elu_act(z):
-        return jnp.where(z > 0, z, jnp.exp(z) - 1)
-
-    @staticmethod
-    def lelu_act(z):
-        leaky = 0.005
-        return jnp.where(z > 0, z, leaky * z)
-
-    @staticmethod
-    def selu_act(z):
-        lam = 1.0507009873554804934193349852946
-        alpha = 1.6732632423543772848170429916717
-        return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
-
-    @staticmethod
-    def softplus_act(z):
-        z = jnp.clip(z * 5, -60, 60)
-        return 0.2 * jnp.log(1 + jnp.exp(z))
-
-    @staticmethod
-    def identity_act(z):
-        return z
-
-    @staticmethod
-    def clamped_act(z):
-        return jnp.clip(z, -1, 1)
-
-    @staticmethod
-    def inv_act(z):
-        z = jnp.maximum(z, 1e-7)
-        return 1 / z
-
-    @staticmethod
-    def log_act(z):
-        z = jnp.maximum(z, 1e-7)
-        return jnp.log(z)
-
-    @staticmethod
-    def exp_act(z):
-        z = jnp.clip(z, -60, 60)
-        return jnp.exp(z)
-
-    @staticmethod
-    def abs_act(z):
-        return jnp.abs(z)
-
-    @staticmethod
-    def hat_act(z):
-        return jnp.maximum(0, 1 - jnp.abs(z))
-
-    @staticmethod
-    def square_act(z):
-        return z ** 2
-
-    @staticmethod
-    def cube_act(z):
-        return z ** 3
-
-
-Activation.name2func = {
-    'sigmoid': Activation.sigmoid_act,
-    'tanh': Activation.tanh_act,
-    'sin': Activation.sin_act,
-    'gauss': Activation.gauss_act,
-    'relu': Activation.relu_act,
-    'elu': Activation.elu_act,
-    'lelu': Activation.lelu_act,
-    'selu': Activation.selu_act,
-    'softplus': Activation.softplus_act,
-    'identity': Activation.identity_act,
-    'clamped': Activation.clamped_act,
-    'inv': Activation.inv_act,
-    'log': Activation.log_act,
-    'exp': Activation.exp_act,
-    'abs': Activation.abs_act,
-    'hat': Activation.hat_act,
-    'square': Activation.square_act,
-    'cube': Activation.cube_act,
-}
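
For context on how this table was consumed: names were resolved to callables once at setup time, then dispatched by integer index inside traced code. A minimal sketch, assuming the same jax.lax.switch pattern used by create_forward in the diff below (the option tuple here is illustrative):

import jax
import jax.numpy as jnp

# Resolve configured names to callables once, outside jit.
funcs = [Activation.name2func[name] for name in ('sigmoid', 'tanh', 'relu')]

# Dispatch by index inside traced code; index 1 selects tanh_act here.
out = jax.lax.switch(1, funcs, jnp.asarray(0.5))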


@@ -1,63 +0,0 @@
-import jax.numpy as jnp
-
-
-class Aggregation:
-    name2func = {}
-
-    @staticmethod
-    def sum_agg(z):
-        z = jnp.where(jnp.isnan(z), 0, z)
-        return jnp.sum(z, axis=0)
-
-    @staticmethod
-    def product_agg(z):
-        z = jnp.where(jnp.isnan(z), 1, z)
-        return jnp.prod(z, axis=0)
-
-    @staticmethod
-    def max_agg(z):
-        z = jnp.where(jnp.isnan(z), -jnp.inf, z)
-        return jnp.max(z, axis=0)
-
-    @staticmethod
-    def min_agg(z):
-        z = jnp.where(jnp.isnan(z), jnp.inf, z)
-        return jnp.min(z, axis=0)
-
-    @staticmethod
-    def maxabs_agg(z):
-        z = jnp.where(jnp.isnan(z), 0, z)
-        abs_z = jnp.abs(z)
-        max_abs_index = jnp.argmax(abs_z)
-        return z[max_abs_index]
-
-    @staticmethod
-    def median_agg(z):
-        n = jnp.sum(~jnp.isnan(z), axis=0)
-        z = jnp.sort(z)  # sort
-        idx1, idx2 = (n - 1) // 2, n // 2
-        median = (z[idx1] + z[idx2]) / 2
-        return median
-
-    @staticmethod
-    def mean_agg(z):
-        aux = jnp.where(jnp.isnan(z), 0, z)
-        valid_values_sum = jnp.sum(aux, axis=0)
-        valid_values_count = jnp.sum(~jnp.isnan(z), axis=0)
-        mean_without_zeros = valid_values_sum / valid_values_count
-        return mean_without_zeros
-
-
-Aggregation.name2func = {
-    'sum': Aggregation.sum_agg,
-    'product': Aggregation.product_agg,
-    'max': Aggregation.max_agg,
-    'min': Aggregation.min_agg,
-    'maxabs': Aggregation.maxabs_agg,
-    'median': Aggregation.median_agg,
-    'mean': Aggregation.mean_agg,
-}
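
Each aggregator above treats NaN entries as missing inputs, substituting the identity element of its reduction (0 for sum, 1 for product, -inf for max, and so on) before reducing. A small illustration with arbitrary values:

import jax.numpy as jnp

z = jnp.array([1.0, jnp.nan, 3.0])
Aggregation.sum_agg(z)   # 4.0: NaN replaced by 0 before summing
Aggregation.max_agg(z)   # 3.0: NaN replaced by -inf
Aggregation.mean_agg(z)  # 2.0: sum of valid entries / count of valid entries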


@@ -1,42 +0,0 @@
-from jax import Array, numpy as jnp, vmap
-
-
-class BaseGene:
-    node_attrs = []
-    conn_attrs = []
-
-    @staticmethod
-    def setup(state, config):
-        return state
-
-    @staticmethod
-    def new_node_attrs(state):
-        return jnp.zeros(0)
-
-    @staticmethod
-    def new_conn_attrs(state):
-        return jnp.zeros(0)
-
-    @staticmethod
-    def mutate_node(state, attrs: Array, key):
-        return attrs
-
-    @staticmethod
-    def mutate_conn(state, attrs: Array, key):
-        return attrs
-
-    @staticmethod
-    def distance_node(state, node1: Array, node2: Array):
-        return node1
-
-    @staticmethod
-    def distance_conn(state, conn1: Array, conn2: Array):
-        return conn1
-
-    @staticmethod
-    def forward_transform(state, nodes, conns):
-        return nodes, conns
-
-    @staticmethod
-    def create_forward(config):
-        return None
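
BaseGene was a stateless interface: every hook is a staticmethod over plain arrays, so subclasses stay compatible with jax transformations. A hypothetical minimal subclass (the name and attribute layout are invented for illustration):

from jax import Array, numpy as jnp

class WeightOnlyGene(BaseGene):
    # one evolvable attribute per connection, none per node
    conn_attrs = ['weight']

    @staticmethod
    def new_conn_attrs(state):
        return jnp.zeros(1)  # initial weight of a fresh connection

    @staticmethod
    def distance_conn(state, conn1: Array, conn2: Array):
        # compatibility distance contributed by the weight attribute
        return jnp.abs(conn1[0] - conn2[0])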


@@ -1,45 +1,100 @@
+from dataclasses import dataclass
+from typing import Tuple
+
 import jax
 from jax import Array, numpy as jnp
 
-from .base import BaseGene
-from .activation import Activation
-from .aggregation import Aggregation
-from algorithm.utils import unflatten_connections, I_INT
-from ..genome import topological_sort
+from config import GeneConfig
+from core import Gene, Genome, State
+from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT
 
 
-class NormalGene(BaseGene):
+@dataclass(frozen=True)
+class NormalGeneConfig(GeneConfig):
+    bias_init_mean: float = 0.0
+    bias_init_std: float = 1.0
+    bias_mutate_power: float = 0.5
+    bias_mutate_rate: float = 0.7
+    bias_replace_rate: float = 0.1
+
+    response_init_mean: float = 1.0
+    response_init_std: float = 0.0
+    response_mutate_power: float = 0.5
+    response_mutate_rate: float = 0.7
+    response_replace_rate: float = 0.1
+
+    activation_default: str = 'sigmoid'
+    activation_options: Tuple[str] = ('sigmoid',)
+    activation_replace_rate: float = 0.1
+
+    aggregation_default: str = 'sum'
+    aggregation_options: Tuple[str] = ('sum',)
+    aggregation_replace_rate: float = 0.1
+
+    weight_init_mean: float = 0.0
+    weight_init_std: float = 1.0
+    weight_mutate_power: float = 0.5
+    weight_mutate_rate: float = 0.8
+    weight_replace_rate: float = 0.1
+
+    def __post_init__(self):
+        assert self.bias_init_std >= 0.0
+        assert self.bias_mutate_power >= 0.0
+        assert self.bias_mutate_rate >= 0.0
+        assert self.bias_replace_rate >= 0.0
+        assert self.response_init_std >= 0.0
+        assert self.response_mutate_power >= 0.0
+        assert self.response_mutate_rate >= 0.0
+        assert self.response_replace_rate >= 0.0
+
+        assert self.activation_default == self.activation_options[0]
+        for name in self.activation_options:
+            assert name in Activation.name2func, f"Activation function: {name} not found"
+
+        assert self.aggregation_default == self.aggregation_options[0]
+        assert self.aggregation_default in Aggregation.name2func, \
+            f"Aggregation function: {self.aggregation_default} not found"
+        for name in self.aggregation_options:
+            assert name in Aggregation.name2func, f"Aggregation function: {name} not found"
+
+
+class NormalGene(Gene):
     node_attrs = ['bias', 'response', 'aggregation', 'activation']
     conn_attrs = ['weight']
 
     @staticmethod
-    def setup(state, config):
+    def setup(config: NormalGeneConfig, state: State = State()):
         return state.update(
-            bias_init_mean=config['bias_init_mean'],
-            bias_init_std=config['bias_init_std'],
-            bias_mutate_power=config['bias_mutate_power'],
-            bias_mutate_rate=config['bias_mutate_rate'],
-            bias_replace_rate=config['bias_replace_rate'],
+            bias_init_mean=config.bias_init_mean,
+            bias_init_std=config.bias_init_std,
+            bias_mutate_power=config.bias_mutate_power,
+            bias_mutate_rate=config.bias_mutate_rate,
+            bias_replace_rate=config.bias_replace_rate,
-            response_init_mean=config['response_init_mean'],
-            response_init_std=config['response_init_std'],
-            response_mutate_power=config['response_mutate_power'],
-            response_mutate_rate=config['response_mutate_rate'],
-            response_replace_rate=config['response_replace_rate'],
+            response_init_mean=config.response_init_mean,
+            response_init_std=config.response_init_std,
+            response_mutate_power=config.response_mutate_power,
+            response_mutate_rate=config.response_mutate_rate,
+            response_replace_rate=config.response_replace_rate,
-            activation_default=config['activation_default'],
-            activation_options=config['activation_options'],
-            activation_replace_rate=config['activation_replace_rate'],
+            activation_replace_rate=config.activation_replace_rate,
+            activation_default=0,
+            activation_options=jnp.arange(len(config.activation_options)),
-            aggregation_default=config['aggregation_default'],
-            aggregation_options=config['aggregation_options'],
-            aggregation_replace_rate=config['aggregation_replace_rate'],
+            aggregation_replace_rate=config.aggregation_replace_rate,
+            aggregation_default=0,
+            aggregation_options=jnp.arange(len(config.aggregation_options)),
-            weight_init_mean=config['weight_init_mean'],
-            weight_init_std=config['weight_init_std'],
-            weight_mutate_power=config['weight_mutate_power'],
-            weight_mutate_rate=config['weight_mutate_rate'],
-            weight_replace_rate=config['weight_replace_rate'],
+            weight_init_mean=config.weight_init_mean,
+            weight_init_std=config.weight_init_std,
+            weight_mutate_power=config.weight_mutate_power,
+            weight_mutate_rate=config.weight_mutate_rate,
+            weight_replace_rate=config.weight_replace_rate,
         )
 
     @staticmethod
@@ -84,20 +139,20 @@ class NormalGene(BaseGene):
         return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3])  # enable + weight
 
     @staticmethod
-    def forward_transform(state, nodes, conns):
-        u_conns = unflatten_connections(nodes, conns)
+    def forward_transform(state: State, genome: Genome):
+        u_conns = unflatten_conns(genome.nodes, genome.conns)
 
         conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
 
         # remove enable attr
         u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
 
-        seqs = topological_sort(nodes, conn_enable)
+        seqs = topological_sort(genome.nodes, conn_enable)
 
-        return seqs, nodes, u_conns
+        return seqs, genome.nodes, u_conns
 
     @staticmethod
-    def create_forward(config):
-        config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']]
-        config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']]
+    def create_forward(state: State, config: NormalGeneConfig):
+        activation_funcs = [Activation.name2func[name] for name in config.activation_options]
+        aggregation_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
 
         def act(idx, z):
             """
@@ -105,7 +160,7 @@ class NormalGene(BaseGene):
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, config['activation_funcs'], z)
res = jax.lax.switch(idx, activation_funcs, z)
return res
def agg(idx, z):
@@ -118,14 +173,13 @@ class NormalGene(BaseGene):
                 return 0.
 
             def not_all_nan():
-                return jax.lax.switch(idx, config['aggregation_funcs'], z)
+                return jax.lax.switch(idx, aggregation_funcs, z)
 
             return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
 
-        def forward(inputs, transform) -> Array:
+        def forward(inputs, transformed) -> Array:
             """
-            jax forward for single input shaped (input_num, )
-            nodes, connections are a single genome
+            forward for single input shaped (input_num, )
 
             :argument inputs: (input_num, )
             :argument cal_seqs: (N, )
@@ -135,10 +189,10 @@ class NormalGene(BaseGene):
             :return (output_num, )
             """
-            cal_seqs, nodes, cons = transform
+            cal_seqs, nodes, cons = transformed
 
-            input_idx = config['input_idx']
-            output_idx = config['output_idx']
+            input_idx = state.input_idx
+            output_idx = state.output_idx
 
             N = nodes.shape[0]
             ini_vals = jnp.full((N,), jnp.nan)
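
Taken together, this file swaps dict-style config access for a frozen dataclass that validates itself in __post_init__. A minimal usage sketch, assuming the config and core modules introduced by this commit are importable and that the State already carries input_idx and output_idx (genome and inputs below are placeholders for an existing core.Genome and an input array):

cfg = NormalGeneConfig(activation_options=('sigmoid', 'tanh'))  # default 'sigmoid' must stay first
state = NormalGene.setup(cfg)                    # copies hyperparameters into the State
forward = NormalGene.create_forward(state, cfg)
transformed = NormalGene.forward_transform(state, genome)
outputs = forward(inputs, transformed)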


@@ -1,90 +0,0 @@
-import jax
-from jax import Array, numpy as jnp, vmap
-
-from .normal import NormalGene
-from .activation import Activation
-from .aggregation import Aggregation
-from algorithm.utils import unflatten_connections
-
-
-class RecurrentGene(NormalGene):
-
-    @staticmethod
-    def forward_transform(state, nodes, conns):
-        u_conns = unflatten_connections(nodes, conns)
-
-        # remove un-enabled connections and drop the enable attr
-        conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
-        u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
-
-        return nodes, u_conns
-
-    @staticmethod
-    def create_forward(config):
-        config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']]
-        config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']]
-
-        def act(idx, z):
-            """
-            calculate the activation function for each node
-            """
-            idx = jnp.asarray(idx, dtype=jnp.int32)
-            # change idx from float to int
-            res = jax.lax.switch(idx, config['activation_funcs'], z)
-            return res
-
-        def agg(idx, z):
-            """
-            calculate the aggregation function for the inputs of a node
-            """
-            idx = jnp.asarray(idx, dtype=jnp.int32)
-
-            def all_nan():
-                return 0.
-
-            def not_all_nan():
-                return jax.lax.switch(idx, config['aggregation_funcs'], z)
-
-            return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
-
-        batch_act, batch_agg = vmap(act), vmap(agg)
-
-        def forward(inputs, transform) -> Array:
-            """
-            jax forward for single input shaped (input_num, )
-            nodes, connections are a single genome
-
-            :argument inputs: (input_num, )
-            :argument cal_seqs: (N, )
-            :argument nodes: (N, 5)
-            :argument connections: (2, N, N)
-            :return (output_num, )
-            """
-            nodes, cons = transform
-            input_idx = config['input_idx']
-            output_idx = config['output_idx']
-
-            N = nodes.shape[0]
-            vals = jnp.full((N,), 0.)
-            weights = cons[0, :]
-
-            def body_func(i, values):
-                values = values.at[input_idx].set(inputs)
-                nodes_ins = values * weights.T
-                values = batch_agg(nodes[:, 4], nodes_ins)  # z = agg(ins)
-                values = values * nodes[:, 2] + nodes[:, 1]  # z = z * response + bias
-                values = batch_act(nodes[:, 3], values)  # z = act(z)
-                return values
-
-            # for i in range(config['activate_times']):
-            #     vals = body_func(i, vals)
-            #
-            # return vals[output_idx]
-
-            vals = jax.lax.fori_loop(0, config['activate_times'], body_func, vals)
-            return vals[output_idx]
-
-        return forward
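
The deleted recurrent variant evaluates a network not in topological order but by running a fixed number of synchronous update steps through jax.lax.fori_loop. A stripped-down sketch of that pattern, with the per-step dynamics reduced to a placeholder:

import jax
import jax.numpy as jnp

def step(_, vals):
    # placeholder for one synchronous update of every node value
    return jnp.tanh(vals + 1.0)

vals = jnp.zeros(4)                         # one slot per node
vals = jax.lax.fori_loop(0, 3, step, vals)  # 3 stands in for activate_times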