change a lot a lot a lot!!!!!!!

Author: wls2002
Date:   2023-07-24 02:16:02 +08:00
parent  48f90c7eef
commit  ac295c1921

49 changed files with 1138 additions and 1460 deletions


@@ -1,45 +1,100 @@
+from dataclasses import dataclass
+from typing import Tuple
 import jax
 from jax import Array, numpy as jnp
-from .base import BaseGene
-from .activation import Activation
-from .aggregation import Aggregation
-from algorithm.utils import unflatten_connections, I_INT
-from ..genome import topological_sort
+from config import GeneConfig
+from core import Gene, Genome, State
+from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT
-class NormalGene(BaseGene):
+@dataclass(frozen=True)
+class NormalGeneConfig(GeneConfig):
+    bias_init_mean: float = 0.0
+    bias_init_std: float = 1.0
+    bias_mutate_power: float = 0.5
+    bias_mutate_rate: float = 0.7
+    bias_replace_rate: float = 0.1
+    response_init_mean: float = 1.0
+    response_init_std: float = 0.0
+    response_mutate_power: float = 0.5
+    response_mutate_rate: float = 0.7
+    response_replace_rate: float = 0.1
+    activation_default: str = 'sigmoid'
+    activation_options: Tuple[str] = ('sigmoid',)
+    activation_replace_rate: float = 0.1
+    aggregation_default: str = 'sum'
+    aggregation_options: Tuple[str] = ('sum',)
+    aggregation_replace_rate: float = 0.1
+    weight_init_mean: float = 0.0
+    weight_init_std: float = 1.0
+    weight_mutate_power: float = 0.5
+    weight_mutate_rate: float = 0.8
+    weight_replace_rate: float = 0.1
+
+    def __post_init__(self):
+        assert self.bias_init_std >= 0.0
+        assert self.bias_mutate_power >= 0.0
+        assert self.bias_mutate_rate >= 0.0
+        assert self.bias_replace_rate >= 0.0
+        assert self.response_init_std >= 0.0
+        assert self.response_mutate_power >= 0.0
+        assert self.response_mutate_rate >= 0.0
+        assert self.response_replace_rate >= 0.0
+        assert self.activation_default == self.activation_options[0]
+        for name in self.activation_options:
+            assert name in Activation.name2func, f"Activation function: {name} not found"
+        assert self.aggregation_default == self.aggregation_options[0]
+        assert self.aggregation_default in Aggregation.name2func, \
+            f"Aggregation function: {self.aggregation_default} not found"
+        for name in self.aggregation_options:
+            assert name in Aggregation.name2func, f"Aggregation function: {name} not found"
+
+class NormalGene(Gene):
     node_attrs = ['bias', 'response', 'aggregation', 'activation']
     conn_attrs = ['weight']
 
     @staticmethod
-    def setup(state, config):
+    def setup(config: NormalGeneConfig, state: State = State()):
         return state.update(
-            bias_init_mean=config['bias_init_mean'],
-            bias_init_std=config['bias_init_std'],
-            bias_mutate_power=config['bias_mutate_power'],
-            bias_mutate_rate=config['bias_mutate_rate'],
-            bias_replace_rate=config['bias_replace_rate'],
+            bias_init_mean=config.bias_init_mean,
+            bias_init_std=config.bias_init_std,
+            bias_mutate_power=config.bias_mutate_power,
+            bias_mutate_rate=config.bias_mutate_rate,
+            bias_replace_rate=config.bias_replace_rate,
-            response_init_mean=config['response_init_mean'],
-            response_init_std=config['response_init_std'],
-            response_mutate_power=config['response_mutate_power'],
-            response_mutate_rate=config['response_mutate_rate'],
-            response_replace_rate=config['response_replace_rate'],
+            response_init_mean=config.response_init_mean,
+            response_init_std=config.response_init_std,
+            response_mutate_power=config.response_mutate_power,
+            response_mutate_rate=config.response_mutate_rate,
+            response_replace_rate=config.response_replace_rate,
-            activation_default=config['activation_default'],
-            activation_options=config['activation_options'],
-            activation_replace_rate=config['activation_replace_rate'],
+            activation_replace_rate=config.activation_replace_rate,
+            activation_default=0,
+            activation_options=jnp.arange(len(config.activation_options)),
-            aggregation_default=config['aggregation_default'],
-            aggregation_options=config['aggregation_options'],
-            aggregation_replace_rate=config['aggregation_replace_rate'],
+            aggregation_replace_rate=config.aggregation_replace_rate,
+            aggregation_default=0,
+            aggregation_options=jnp.arange(len(config.aggregation_options)),
-            weight_init_mean=config['weight_init_mean'],
-            weight_init_std=config['weight_init_std'],
-            weight_mutate_power=config['weight_mutate_power'],
-            weight_mutate_rate=config['weight_mutate_rate'],
-            weight_replace_rate=config['weight_replace_rate'],
+            weight_init_mean=config.weight_init_mean,
+            weight_init_std=config.weight_init_std,
+            weight_mutate_power=config.weight_mutate_power,
+            weight_mutate_rate=config.weight_mutate_rate,
+            weight_replace_rate=config.weight_replace_rate,
         )
 
     @staticmethod
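The new configuration path replaces dict lookups with a frozen dataclass validated in __post_init__, plus a State built by setup. A minimal usage sketch (an assumption based on the signatures above, not code from this commit):

    # Sketch only: the field overrides and the setup call are assumptions from the signatures above.
    config = NormalGeneConfig(
        bias_init_std=0.5,                 # override a default; __post_init__ re-validates it
        activation_options=('sigmoid',),   # first entry must match activation_default
    )
    state = NormalGene.setup(config)       # returns a State holding the hyperparameters
    # The State stores activation_default/aggregation_default as index 0 and the option
    # lists as jnp.arange indices, which keeps them jit-friendly for lax.switch below.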
@@ -84,20 +139,20 @@ class NormalGene(BaseGene):
         return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight
 
     @staticmethod
-    def forward_transform(state, nodes, conns):
-        u_conns = unflatten_connections(nodes, conns)
+    def forward_transform(state: State, genome: Genome):
+        u_conns = unflatten_conns(genome.nodes, genome.conns)
 
         conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
 
         # remove enable attr
         u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
-        seqs = topological_sort(nodes, conn_enable)
+        seqs = topological_sort(genome.nodes, conn_enable)
 
-        return seqs, nodes, u_conns
+        return seqs, genome.nodes, u_conns
 
     @staticmethod
-    def create_forward(config):
-        config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']]
-        config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']]
+    def create_forward(state: State, config: NormalGeneConfig):
+        activation_funcs = [Activation.name2func[name] for name in config.activation_options]
+        aggregation_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
 
         def act(idx, z):
             """
@@ -105,7 +160,7 @@ class NormalGene(BaseGene):
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, config['activation_funcs'], z)
res = jax.lax.switch(idx, activation_funcs, z)
return res
def agg(idx, z):
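Both act and agg pick a concrete function out of a static Python list with jax.lax.switch, keyed by the integer index stored in the State, which keeps the choice traceable under jit. A standalone sketch of the same dispatch pattern (the function list here is illustrative, not from this commit):

    import jax
    from jax import numpy as jnp

    funcs = [jnp.tanh, jax.nn.sigmoid, jax.nn.relu]    # static list of branches

    def apply_by_index(idx, z):
        idx = jnp.asarray(idx, dtype=jnp.int32)        # indices may be stored as floats
        return jax.lax.switch(idx, funcs, z)           # selects funcs[idx] inside jit

    print(apply_by_index(1, jnp.array([0.0, 2.0])))    # element-wise sigmoid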
@@ -118,14 +173,13 @@ class NormalGene(BaseGene):
                 return 0.
 
             def not_all_nan():
-                return jax.lax.switch(idx, config['aggregation_funcs'], z)
+                return jax.lax.switch(idx, aggregation_funcs, z)
 
             return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
 
-        def forward(inputs, transform) -> Array:
+        def forward(inputs, transformed) -> Array:
             """
-            jax forward for single input shaped (input_num, )
-            nodes, connections are a single genome
+            forward for single input shaped (input_num, )
 
             :argument inputs: (input_num, )
             :argument cal_seqs: (N, )
@@ -135,10 +189,10 @@ class NormalGene(BaseGene):
             :return (output_num, )
             """
-            cal_seqs, nodes, cons = transform
+            cal_seqs, nodes, cons = transformed
 
-            input_idx = config['input_idx']
-            output_idx = config['output_idx']
+            input_idx = state.input_idx
+            output_idx = state.output_idx
 
             N = nodes.shape[0]
             ini_vals = jnp.full((N,), jnp.nan)
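forward evaluates one genome on a single input: node values start as NaN in ini_vals and are filled in the topological order given by cal_seqs. Batch evaluation would normally come from vmapping over the input axis; a sketch, assuming create_forward returns this forward function as forward_fn and transformed comes from forward_transform:

    # Sketch only: forward_fn and transformed are assumed to exist as described above.
    batched_forward = jax.vmap(forward_fn, in_axes=(0, None))   # one genome, many inputs
    outputs = batched_forward(batch_inputs, transformed)        # shape (batch_size, output_num)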