change a lot a lot a lot!!!!!!!
This commit is contained in:
@@ -1,45 +1,100 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import jax
|
||||
from jax import Array, numpy as jnp
|
||||
|
||||
from .base import BaseGene
|
||||
from .activation import Activation
|
||||
from .aggregation import Aggregation
|
||||
from algorithm.utils import unflatten_connections, I_INT
|
||||
from ..genome import topological_sort
|
||||
from config import GeneConfig
|
||||
from core import Gene, Genome, State
|
||||
from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT
|
||||
|
||||
|
||||
class NormalGene(BaseGene):
|
||||
@dataclass(frozen=True)
|
||||
class NormalGeneConfig(GeneConfig):
|
||||
bias_init_mean: float = 0.0
|
||||
bias_init_std: float = 1.0
|
||||
bias_mutate_power: float = 0.5
|
||||
bias_mutate_rate: float = 0.7
|
||||
bias_replace_rate: float = 0.1
|
||||
|
||||
response_init_mean: float = 1.0
|
||||
response_init_std: float = 0.0
|
||||
response_mutate_power: float = 0.5
|
||||
response_mutate_rate: float = 0.7
|
||||
response_replace_rate: float = 0.1
|
||||
|
||||
activation_default: str = 'sigmoid'
|
||||
activation_options: Tuple[str] = ('sigmoid',)
|
||||
activation_replace_rate: float = 0.1
|
||||
|
||||
aggregation_default: str = 'sum'
|
||||
aggregation_options: Tuple[str] = ('sum',)
|
||||
aggregation_replace_rate: float = 0.1
|
||||
|
||||
weight_init_mean: float = 0.0
|
||||
weight_init_std: float = 1.0
|
||||
weight_mutate_power: float = 0.5
|
||||
weight_mutate_rate: float = 0.8
|
||||
weight_replace_rate: float = 0.1
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.bias_init_std >= 0.0
|
||||
assert self.bias_mutate_power >= 0.0
|
||||
assert self.bias_mutate_rate >= 0.0
|
||||
assert self.bias_replace_rate >= 0.0
|
||||
|
||||
assert self.response_init_std >= 0.0
|
||||
assert self.response_mutate_power >= 0.0
|
||||
assert self.response_mutate_rate >= 0.0
|
||||
assert self.response_replace_rate >= 0.0
|
||||
|
||||
assert self.activation_default == self.activation_options[0]
|
||||
|
||||
for name in self.activation_options:
|
||||
assert name in Activation.name2func, f"Activation function: {name} not found"
|
||||
|
||||
assert self.aggregation_default == self.aggregation_options[0]
|
||||
|
||||
assert self.aggregation_default in Aggregation.name2func, \
|
||||
f"Aggregation function: {self.aggregation_default} not found"
|
||||
|
||||
for name in self.aggregation_options:
|
||||
assert name in Aggregation.name2func, f"Aggregation function: {name} not found"
|
||||
|
||||
|
||||
class NormalGene(Gene):
|
||||
node_attrs = ['bias', 'response', 'aggregation', 'activation']
|
||||
conn_attrs = ['weight']
|
||||
|
||||
@staticmethod
|
||||
def setup(state, config):
|
||||
def setup(config: NormalGeneConfig, state: State = State()):
|
||||
|
||||
return state.update(
|
||||
bias_init_mean=config['bias_init_mean'],
|
||||
bias_init_std=config['bias_init_std'],
|
||||
bias_mutate_power=config['bias_mutate_power'],
|
||||
bias_mutate_rate=config['bias_mutate_rate'],
|
||||
bias_replace_rate=config['bias_replace_rate'],
|
||||
bias_init_mean=config.bias_init_mean,
|
||||
bias_init_std=config.bias_init_std,
|
||||
bias_mutate_power=config.bias_mutate_power,
|
||||
bias_mutate_rate=config.bias_mutate_rate,
|
||||
bias_replace_rate=config.bias_replace_rate,
|
||||
|
||||
response_init_mean=config['response_init_mean'],
|
||||
response_init_std=config['response_init_std'],
|
||||
response_mutate_power=config['response_mutate_power'],
|
||||
response_mutate_rate=config['response_mutate_rate'],
|
||||
response_replace_rate=config['response_replace_rate'],
|
||||
response_init_mean=config.response_init_mean,
|
||||
response_init_std=config.response_init_std,
|
||||
response_mutate_power=config.response_mutate_power,
|
||||
response_mutate_rate=config.response_mutate_rate,
|
||||
response_replace_rate=config.response_replace_rate,
|
||||
|
||||
activation_default=config['activation_default'],
|
||||
activation_options=config['activation_options'],
|
||||
activation_replace_rate=config['activation_replace_rate'],
|
||||
activation_replace_rate=config.activation_replace_rate,
|
||||
activation_default=0,
|
||||
activation_options=jnp.arange(len(config.activation_options)),
|
||||
|
||||
aggregation_default=config['aggregation_default'],
|
||||
aggregation_options=config['aggregation_options'],
|
||||
aggregation_replace_rate=config['aggregation_replace_rate'],
|
||||
aggregation_replace_rate=config.aggregation_replace_rate,
|
||||
aggregation_default=0,
|
||||
aggregation_options=jnp.arange(len(config.aggregation_options)),
|
||||
|
||||
weight_init_mean=config['weight_init_mean'],
|
||||
weight_init_std=config['weight_init_std'],
|
||||
weight_mutate_power=config['weight_mutate_power'],
|
||||
weight_mutate_rate=config['weight_mutate_rate'],
|
||||
weight_replace_rate=config['weight_replace_rate'],
|
||||
weight_init_mean=config.weight_init_mean,
|
||||
weight_init_std=config.weight_init_std,
|
||||
weight_mutate_power=config.weight_mutate_power,
|
||||
weight_mutate_rate=config.weight_mutate_rate,
|
||||
weight_replace_rate=config.weight_replace_rate,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -84,20 +139,20 @@ class NormalGene(BaseGene):
|
||||
return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight
|
||||
|
||||
@staticmethod
|
||||
def forward_transform(state, nodes, conns):
|
||||
u_conns = unflatten_connections(nodes, conns)
|
||||
def forward_transform(state: State, genome: Genome):
|
||||
u_conns = unflatten_conns(genome.nodes, genome.conns)
|
||||
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
||||
|
||||
# remove enable attr
|
||||
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||
seqs = topological_sort(nodes, conn_enable)
|
||||
seqs = topological_sort(genome.nodes, conn_enable)
|
||||
|
||||
return seqs, nodes, u_conns
|
||||
return seqs, genome.nodes, u_conns
|
||||
|
||||
@staticmethod
|
||||
def create_forward(config):
|
||||
config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']]
|
||||
config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']]
|
||||
def create_forward(state: State, config: NormalGeneConfig):
|
||||
activation_funcs = [Activation.name2func[name] for name in config.activation_options]
|
||||
aggregation_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
|
||||
|
||||
def act(idx, z):
|
||||
"""
|
||||
@@ -105,7 +160,7 @@ class NormalGene(BaseGene):
|
||||
"""
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
# change idx from float to int
|
||||
res = jax.lax.switch(idx, config['activation_funcs'], z)
|
||||
res = jax.lax.switch(idx, activation_funcs, z)
|
||||
return res
|
||||
|
||||
def agg(idx, z):
|
||||
@@ -118,14 +173,13 @@ class NormalGene(BaseGene):
|
||||
return 0.
|
||||
|
||||
def not_all_nan():
|
||||
return jax.lax.switch(idx, config['aggregation_funcs'], z)
|
||||
return jax.lax.switch(idx, aggregation_funcs, z)
|
||||
|
||||
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
|
||||
|
||||
def forward(inputs, transform) -> Array:
|
||||
def forward(inputs, transformed) -> Array:
|
||||
"""
|
||||
jax forward for single input shaped (input_num, )
|
||||
nodes, connections are a single genome
|
||||
forward for single input shaped (input_num, )
|
||||
|
||||
:argument inputs: (input_num, )
|
||||
:argument cal_seqs: (N, )
|
||||
@@ -135,10 +189,10 @@ class NormalGene(BaseGene):
|
||||
:return (output_num, )
|
||||
"""
|
||||
|
||||
cal_seqs, nodes, cons = transform
|
||||
cal_seqs, nodes, cons = transformed
|
||||
|
||||
input_idx = config['input_idx']
|
||||
output_idx = config['output_idx']
|
||||
input_idx = state.input_idx
|
||||
output_idx = state.output_idx
|
||||
|
||||
N = nodes.shape[0]
|
||||
ini_vals = jnp.full((N,), jnp.nan)
|
||||
|
||||
Reference in New Issue
Block a user