Files
tensorneat-mend/algorithm/neat/gene/normal.py
2023-08-02 15:02:08 +08:00

223 lines
8.2 KiB
Python

from dataclasses import dataclass
from typing import Tuple
import jax
from jax import Array, numpy as jnp
from config import GeneConfig
from core import Gene, Genome, State
from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT, act, agg
@dataclass(frozen=True)
class NormalGeneConfig(GeneConfig):
    """Immutable hyperparameters for the bias/response/activation/aggregation/weight gene.

    Field groups:
      * ``bias_*``, ``response_*``, ``weight_*``: settings for a float attribute.
        ``*_init_mean`` / ``*_init_std`` parameterise the normal distribution a
        replacement value is drawn from, ``*_mutate_power`` scales the gaussian
        perturbation, and ``*_mutate_rate`` / ``*_replace_rate`` are compared
        against a uniform sample to pick perturb / replace / keep.
      * ``activation_*``, ``aggregation_*``: settings for a categorical attribute.
        ``*_options`` lists the allowed function names, ``*_default`` must be
        the first option, and ``*_replace_rate`` is compared against a uniform
        sample to decide whether a new option is drawn.

    All constraints are checked in ``__post_init__`` with ``assert``, so a bad
    value raises ``AssertionError`` (and the checks are skipped under ``python -O``).
    """

    bias_init_mean: float = 0.0
    bias_init_std: float = 1.0
    bias_mutate_power: float = 0.5
    bias_mutate_rate: float = 0.7
    bias_replace_rate: float = 0.1

    response_init_mean: float = 1.0
    response_init_std: float = 0.0
    response_mutate_power: float = 0.5
    response_mutate_rate: float = 0.7
    response_replace_rate: float = 0.1

    activation_default: str = 'sigmoid'
    activation_options: Tuple = ('sigmoid',)
    activation_replace_rate: float = 0.1

    aggregation_default: str = 'sum'
    aggregation_options: Tuple = ('sum',)
    aggregation_replace_rate: float = 0.1

    weight_init_mean: float = 0.0
    weight_init_std: float = 1.0
    weight_mutate_power: float = 0.5
    weight_mutate_rate: float = 0.8
    weight_replace_rate: float = 0.1

    def __post_init__(self):
        """Validate the configuration; raises ``AssertionError`` on the first violation."""
        # Every std / power / rate of a float attribute must be non-negative.
        # ``weight`` is checked alongside ``bias`` and ``response`` so that all
        # three float attributes are held to the same constraints.
        for group in ('bias', 'response', 'weight'):
            for suffix in ('init_std', 'mutate_power', 'mutate_rate', 'replace_rate'):
                field_name = f"{group}_{suffix}"
                assert getattr(self, field_name) >= 0.0, f"{field_name} must be >= 0.0"

        # The categorical attributes only have a replace rate.
        for field_name in ('activation_replace_rate', 'aggregation_replace_rate'):
            assert getattr(self, field_name) >= 0.0, f"{field_name} must be >= 0.0"

        # The default has to be the first option because NormalGene.setup stores
        # the default as option index 0.
        assert self.activation_options, "activation_options must not be empty"
        assert self.activation_default == self.activation_options[0], \
            "activation_default must be the first entry of activation_options"
        for name in self.activation_options:
            assert name in Activation.name2func, f"Activation function: {name} not found"

        assert self.aggregation_options, "aggregation_options must not be empty"
        assert self.aggregation_default == self.aggregation_options[0], \
            "aggregation_default must be the first entry of aggregation_options"
        # The default equals options[0], so this loop also covers the default.
        for name in self.aggregation_options:
            assert name in Aggregation.name2func, f"Aggregation function: {name} not found"
class NormalGene(Gene):
    """Gene with per-node bias/response/activation/aggregation and a per-connection weight.

    Attribute layout used throughout this class:
      * node attribute vector (``new_node_attrs`` / ``mutate_node``):
        ``[bias, response, activation index, aggregation index]``.
      * full node row (``distance_node`` / ``forward``): the same four values
        at columns 1..4. Column 0 is not read here (presumably the node key).
      * connection attribute vector: ``[weight]``.
      * full connection row (``distance_conn``): enable flag at column 2 and
        weight at column 3. Columns 0 and 1 are not read here.

    Activation and aggregation are stored as indices into ``self.act_funcs`` /
    ``self.agg_funcs``, which follow the order of the config's option tuples.
    """

    # Names listed in storage order: activation comes before aggregation, as in
    # new_node_attrs, mutate_node, distance_node and forward.
    node_attrs = ['bias', 'response', 'activation', 'aggregation']
    conn_attrs = ['weight']

    def __init__(self, config: NormalGeneConfig):
        """Store the config and resolve the option names to callables."""
        self.config = config
        self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
        self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]

    def setup(self, state: State = State()):
        """Return ``state`` updated with every hyperparameter this gene reads later.

        Activation / aggregation options are stored as integer indices
        (``arange`` over the option tuple) and the defaults as index 0.
        """
        cfg = self.config
        return state.update(
            bias_init_mean=cfg.bias_init_mean,
            bias_init_std=cfg.bias_init_std,
            bias_mutate_power=cfg.bias_mutate_power,
            bias_mutate_rate=cfg.bias_mutate_rate,
            bias_replace_rate=cfg.bias_replace_rate,

            response_init_mean=cfg.response_init_mean,
            response_init_std=cfg.response_init_std,
            response_mutate_power=cfg.response_mutate_power,
            response_mutate_rate=cfg.response_mutate_rate,
            response_replace_rate=cfg.response_replace_rate,

            activation_replace_rate=cfg.activation_replace_rate,
            activation_default=0,
            activation_options=jnp.arange(len(cfg.activation_options)),

            aggregation_replace_rate=cfg.aggregation_replace_rate,
            aggregation_default=0,
            aggregation_options=jnp.arange(len(cfg.aggregation_options)),

            weight_init_mean=cfg.weight_init_mean,
            weight_init_std=cfg.weight_init_std,
            weight_mutate_power=cfg.weight_mutate_power,
            weight_mutate_rate=cfg.weight_mutate_rate,
            weight_replace_rate=cfg.weight_replace_rate,
        )

    def update(self, state):
        """No-op: this gene has nothing to update (returns ``None``)."""
        pass

    def new_node_attrs(self, state):
        """Return the attribute vector of a fresh node.

        Layout: ``[bias, response, activation index, aggregation index]``,
        filled with the init means and the default option indices.
        """
        return jnp.array([state.bias_init_mean, state.response_init_mean,
                          state.activation_default, state.aggregation_default])

    def new_conn_attrs(self, state):
        """Return the attribute vector of a fresh connection: ``[weight_init_mean]``."""
        return jnp.array([state.weight_init_mean])

    def mutate_node(self, state, key, attrs: Array):
        """Mutate a node attribute vector and return the new vector.

        Args:
            state: provides the ``*_init_*``, ``*_mutate_*``, ``*_replace_rate``
                and ``*_options`` values written by ``setup``.
            key: JAX PRNG key; split into one sub-key per attribute.
            attrs: ``[bias, response, activation index, aggregation index]``.

        Returns:
            Array with the same layout as ``attrs``.
        """
        k_bias, k_res, k_act, k_agg = jax.random.split(key, num=4)

        bias = NormalGene._mutate_float(k_bias, attrs[0], state.bias_init_mean, state.bias_init_std,
                                        state.bias_mutate_power, state.bias_mutate_rate,
                                        state.bias_replace_rate)

        res = NormalGene._mutate_float(k_res, attrs[1], state.response_init_mean, state.response_init_std,
                                       state.response_mutate_power, state.response_mutate_rate,
                                       state.response_replace_rate)

        # Named *_idx so the module-level ``act`` / ``agg`` helpers are not shadowed.
        act_idx = NormalGene._mutate_int(k_act, attrs[2], state.activation_options,
                                         state.activation_replace_rate)
        agg_idx = NormalGene._mutate_int(k_agg, attrs[3], state.aggregation_options,
                                         state.aggregation_replace_rate)

        return jnp.array([bias, res, act_idx, agg_idx])

    def mutate_conn(self, state, key, attrs: Array):
        """Mutate a connection attribute vector ``[weight]`` and return the new vector."""
        weight = NormalGene._mutate_float(key, attrs[0], state.weight_init_mean, state.weight_init_std,
                                          state.weight_mutate_power, state.weight_mutate_rate,
                                          state.weight_replace_rate)
        return jnp.array([weight])

    def distance_node(self, state, node1: Array, node2: Array):
        """Return the distance between two full node rows.

        Sum of |bias difference| (column 1), |response difference| (column 2),
        and one unit each for a differing activation (column 3) or
        aggregation (column 4).
        """
        bias_diff = jnp.abs(node1[1] - node2[1])
        response_diff = jnp.abs(node1[2] - node2[2])
        act_differs = node1[3] != node2[3]
        agg_differs = node1[4] != node2[4]
        return bias_diff + response_diff + act_differs + agg_differs

    def distance_conn(self, state, con1: Array, con2: Array):
        """Return the distance between two full connection rows.

        One unit for a differing enable flag (column 2) plus
        |weight difference| (column 3).
        """
        enable_differs = con1[2] != con2[2]
        weight_diff = jnp.abs(con1[3] - con2[3])
        return enable_differs + weight_diff

    def forward_transform(self, state: State, genome: Genome):
        """Precompute what ``forward`` needs from a genome.

        Returns:
            ``(seqs, nodes, u_conns)``: the node order from ``topological_sort``,
            the genome's node array, and the unflattened connection attributes
            with the enable slice removed and NaN wherever no enabled
            connection exists.
        """
        u_conns = unflatten_conns(genome.nodes, genome.conns)

        # Slice 0 holds the enable attribute; an entry counts as enabled when it
        # is not NaN.
        # NOTE(review): this looks only at NaN-ness, not at the flag's value --
        # confirm unflatten_conns writes NaN for disabled connections.
        conn_enable = ~jnp.isnan(u_conns[0])

        # remove enable attr
        u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
        seqs = topological_sort(genome.nodes, conn_enable)

        return seqs, genome.nodes, u_conns

    def forward(self, state: State, inputs, transformed):
        """Evaluate the network on ``inputs`` and return the output node values.

        Args:
            state: provides ``input_idx`` and ``output_idx``.
            inputs: values written into the input node slots.
            transformed: the tuple returned by ``forward_transform``.

        Returns:
            ``vals[output_idx]`` after visiting nodes in ``cal_seqs`` order.
        """
        cal_seqs, nodes, cons = transformed

        input_idx = state.input_idx
        output_idx = state.output_idx

        N = nodes.shape[0]
        # Every node value starts as NaN; only the input slots are filled in.
        ini_vals = jnp.full((N,), jnp.nan)
        ini_vals = ini_vals.at[input_idx].set(inputs)

        # The enable slice was removed in forward_transform, so slice 0 is the weight.
        weights = cons[0, :]

        def cond_fun(carry):
            # Stop after N steps or at the first I_INT entry of the sequence.
            values, idx = carry
            return (idx < N) & (cal_seqs[idx] != I_INT)

        def body_func(carry):
            values, idx = carry
            i = cal_seqs[idx]

            def hit():
                ins = values * weights[:, i]
                z = agg(nodes[i, 4], ins, self.agg_funcs)  # z = agg(ins)
                z = z * nodes[i, 2] + nodes[i, 1]  # z = z * response + bias
                z = act(nodes[i, 3], z, self.act_funcs)  # z = act(z)

                new_values = values.at[i].set(z)
                return new_values

            def miss():
                return values

            # the val of input nodes is obtained by the task, not by calculation
            values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)

            return values, idx + 1

        vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))

        return vals[output_idx]

    @staticmethod
    def _mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate):
        """Perturb, replace or keep a float value based on one uniform sample ``r``.

        * ``r < mutate_rate``: return ``val`` plus normal noise scaled by ``mutate_power``.
        * ``mutate_rate < r < mutate_rate + replace_rate``: return a fresh draw
          from ``normal(init_mean, init_std)``.
        * otherwise: return ``val`` unchanged.
        """
        k_noise, k_replace, k_choice = jax.random.split(key, num=3)
        noise = jax.random.normal(k_noise, ()) * mutate_power
        replace = jax.random.normal(k_replace, ()) * init_std + init_mean
        r = jax.random.uniform(k_choice, ())

        in_replace_band = (mutate_rate < r) & (r < mutate_rate + replace_rate)
        kept_or_replaced = jnp.where(in_replace_band, replace, val)
        return jnp.where(r < mutate_rate, val + noise, kept_or_replaced)

    @staticmethod
    def _mutate_int(key, val, options, replace_rate):
        """Return a random entry of ``options`` when a uniform sample is below ``replace_rate``, else ``val``."""
        k_choice, k_pick = jax.random.split(key, num=2)
        r = jax.random.uniform(k_choice, ())
        return jnp.where(r < replace_rate, jax.random.choice(k_pick, options), val)