from dataclasses import dataclass from typing import Tuple import jax from jax import Array, numpy as jnp from config import GeneConfig from core import Gene, Genome, State from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT, act, agg @dataclass(frozen=True) class NormalGeneConfig(GeneConfig): bias_init_mean: float = 0.0 bias_init_std: float = 1.0 bias_mutate_power: float = 0.5 bias_mutate_rate: float = 0.7 bias_replace_rate: float = 0.1 response_init_mean: float = 1.0 response_init_std: float = 0.0 response_mutate_power: float = 0.5 response_mutate_rate: float = 0.7 response_replace_rate: float = 0.1 activation_default: str = 'sigmoid' activation_options: Tuple[str] = ('sigmoid',) activation_replace_rate: float = 0.1 aggregation_default: str = 'sum' aggregation_options: Tuple[str] = ('sum',) aggregation_replace_rate: float = 0.1 weight_init_mean: float = 0.0 weight_init_std: float = 1.0 weight_mutate_power: float = 0.5 weight_mutate_rate: float = 0.8 weight_replace_rate: float = 0.1 def __post_init__(self): assert self.bias_init_std >= 0.0 assert self.bias_mutate_power >= 0.0 assert self.bias_mutate_rate >= 0.0 assert self.bias_replace_rate >= 0.0 assert self.response_init_std >= 0.0 assert self.response_mutate_power >= 0.0 assert self.response_mutate_rate >= 0.0 assert self.response_replace_rate >= 0.0 assert self.activation_default == self.activation_options[0] for name in self.activation_options: assert name in Activation.name2func, f"Activation function: {name} not found" assert self.aggregation_default == self.aggregation_options[0] assert self.aggregation_default in Aggregation.name2func, \ f"Aggregation function: {self.aggregation_default} not found" for name in self.aggregation_options: assert name in Aggregation.name2func, f"Aggregation function: {name} not found" class NormalGene(Gene): node_attrs = ['bias', 'response', 'aggregation', 'activation'] conn_attrs = ['weight'] def __init__(self, config: NormalGeneConfig): self.config = config self.act_funcs = [Activation.name2func[name] for name in config.activation_options] self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options] def setup(self, state: State = State()): return state.update( bias_init_mean=self.config.bias_init_mean, bias_init_std=self.config.bias_init_std, bias_mutate_power=self.config.bias_mutate_power, bias_mutate_rate=self.config.bias_mutate_rate, bias_replace_rate=self.config.bias_replace_rate, response_init_mean=self.config.response_init_mean, response_init_std=self.config.response_init_std, response_mutate_power=self.config.response_mutate_power, response_mutate_rate=self.config.response_mutate_rate, response_replace_rate=self.config.response_replace_rate, activation_replace_rate=self.config.activation_replace_rate, activation_default=0, activation_options=jnp.arange(len(self.config.activation_options)), aggregation_replace_rate=self.config.aggregation_replace_rate, aggregation_default=0, aggregation_options=jnp.arange(len(self.config.aggregation_options)), weight_init_mean=self.config.weight_init_mean, weight_init_std=self.config.weight_init_std, weight_mutate_power=self.config.weight_mutate_power, weight_mutate_rate=self.config.weight_mutate_rate, weight_replace_rate=self.config.weight_replace_rate, ) def update(self, state): pass def new_node_attrs(self, state): return jnp.array([state.bias_init_mean, state.response_init_mean, state.activation_default, state.aggregation_default]) def new_conn_attrs(self, state): return jnp.array([state.weight_init_mean]) def mutate_node(self, state, key, attrs: Array): k1, k2, k3, k4 = jax.random.split(key, num=4) bias = NormalGene._mutate_float(k1, attrs[0], state.bias_init_mean, state.bias_init_std, state.bias_mutate_power, state.bias_mutate_rate, state.bias_replace_rate) res = NormalGene._mutate_float(k2, attrs[1], state.response_init_mean, state.response_init_std, state.response_mutate_power, state.response_mutate_rate, state.response_replace_rate) act = NormalGene._mutate_int(k3, attrs[2], state.activation_options, state.activation_replace_rate) agg = NormalGene._mutate_int(k4, attrs[3], state.aggregation_options, state.aggregation_replace_rate) return jnp.array([bias, res, act, agg]) def mutate_conn(self, state, key, attrs: Array): weight = NormalGene._mutate_float(key, attrs[0], state.weight_init_mean, state.weight_init_std, state.weight_mutate_power, state.weight_mutate_rate, state.weight_replace_rate) return jnp.array([weight]) def distance_node(self, state, node1: Array, node2: Array): # bias + response + activation + aggregation return jnp.abs(node1[1] - node2[1]) + jnp.abs(node1[2] - node2[2]) + \ (node1[3] != node2[3]) + (node1[4] != node2[4]) def distance_conn(self, state, con1: Array, con2: Array): return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight def forward_transform(self, state: State, genome: Genome): u_conns = unflatten_conns(genome.nodes, genome.conns) conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False) # remove enable attr u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) seqs = topological_sort(genome.nodes, conn_enable) return seqs, genome.nodes, u_conns def forward(self, state: State, inputs, transformed): cal_seqs, nodes, cons = transformed input_idx = state.input_idx output_idx = state.output_idx N = nodes.shape[0] ini_vals = jnp.full((N,), jnp.nan) ini_vals = ini_vals.at[input_idx].set(inputs) weights = cons[0, :] def cond_fun(carry): values, idx = carry return (idx < N) & (cal_seqs[idx] != I_INT) def body_func(carry): values, idx = carry i = cal_seqs[idx] def hit(): ins = values * weights[:, i] z = agg(nodes[i, 4], ins, self.agg_funcs) # z = agg(ins) z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias z = act(nodes[i, 3], z, self.act_funcs) # z = act(z) new_values = values.at[i].set(z) return new_values def miss(): return values # the val of input nodes is obtained by the task, not by calculation values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit) return values, idx + 1 vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0)) return vals[output_idx] @staticmethod def _mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate): k1, k2, k3 = jax.random.split(key, num=3) noise = jax.random.normal(k1, ()) * mutate_power replace = jax.random.normal(k2, ()) * init_std + init_mean r = jax.random.uniform(k3, ()) val = jnp.where( r < mutate_rate, val + noise, jnp.where( (mutate_rate < r) & (r < mutate_rate + replace_rate), replace, val ) ) return val @staticmethod def _mutate_int(key, val, options, replace_rate): k1, k2 = jax.random.split(key, num=2) r = jax.random.uniform(k1, ()) val = jnp.where( r < replace_rate, jax.random.choice(k2, options), val ) return val