new architecture
This commit is contained in:
@@ -1,3 +1,3 @@
|
||||
from .normal import NormalGene, NormalGeneConfig
|
||||
from .recurrent import RecurrentGene, RecurrentGeneConfig
|
||||
|
||||
from .base import BaseGene
|
||||
from .conn import *
|
||||
from .node import *
|
||||
|
||||
23
algorithm/neat/gene/base.py
Normal file
23
algorithm/neat/gene/base.py
Normal file
@@ -0,0 +1,23 @@
|
||||
class BaseGene:
|
||||
"Base class for node genes or connection genes."
|
||||
fixed_attrs = []
|
||||
custom_attrs = []
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def new_custom_attrs(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def mutate(self, randkey, gene):
|
||||
raise NotImplementedError
|
||||
|
||||
def distance(self, gene1, gene2):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def length(self):
|
||||
return len(self.fixed_attrs) + len(self.custom_attrs)
|
||||
2
algorithm/neat/gene/conn/__init__.py
Normal file
2
algorithm/neat/gene/conn/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .base import BaseConnGene
|
||||
from .default import DefaultConnGene
|
||||
12
algorithm/neat/gene/conn/base.py
Normal file
12
algorithm/neat/gene/conn/base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from .. import BaseGene
|
||||
|
||||
|
||||
class BaseConnGene(BaseGene):
|
||||
"Base class for connection genes."
|
||||
fixed_attrs = ['input_index', 'output_index', 'enabled']
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
raise NotImplementedError
|
||||
51
algorithm/neat/gene/conn/default.py
Normal file
51
algorithm/neat/gene/conn/default.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from utils import mutate_float
|
||||
from . import BaseConnGene
|
||||
|
||||
|
||||
class DefaultConnGene(BaseConnGene):
|
||||
"Default connection gene, with the same behavior as in NEAT-python."
|
||||
|
||||
fixed_attrs = ['input_index', 'output_index', 'enabled']
|
||||
attrs = ['weight']
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_init_mean: float = 0.0,
|
||||
weight_init_std: float = 1.0,
|
||||
weight_mutate_power: float = 0.5,
|
||||
weight_mutate_rate: float = 0.8,
|
||||
weight_replace_rate: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
self.weight_init_mean = weight_init_mean
|
||||
self.weight_init_std = weight_init_std
|
||||
self.weight_mutate_power = weight_mutate_power
|
||||
self.weight_mutate_rate = weight_mutate_rate
|
||||
self.weight_replace_rate = weight_replace_rate
|
||||
|
||||
def new_custom_attrs(self):
|
||||
return jnp.array([self.weight_init_mean])
|
||||
|
||||
def mutate(self, key, conn):
|
||||
input_index = conn[0]
|
||||
output_index = conn[1]
|
||||
enabled = conn[2]
|
||||
weight = mutate_float(key,
|
||||
conn[3],
|
||||
self.weight_init_mean,
|
||||
self.weight_init_std,
|
||||
self.weight_mutate_power,
|
||||
self.weight_mutate_rate,
|
||||
self.weight_replace_rate
|
||||
)
|
||||
|
||||
return jnp.array([input_index, output_index, enabled, weight])
|
||||
|
||||
def distance(self, attrs1, attrs2):
|
||||
return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
weight = attrs[0]
|
||||
return inputs * weight
|
||||
2
algorithm/neat/gene/node/__init__.py
Normal file
2
algorithm/neat/gene/node/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .base import BaseNodeGene
|
||||
from .default import DefaultNodeGene
|
||||
12
algorithm/neat/gene/node/base.py
Normal file
12
algorithm/neat/gene/node/base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from .. import BaseGene
|
||||
|
||||
|
||||
class BaseNodeGene(BaseGene):
|
||||
"Base class for node genes."
|
||||
fixed_attrs = ["index"]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
raise NotImplementedError
|
||||
96
algorithm/neat/gene/node/default.py
Normal file
96
algorithm/neat/gene/node/default.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from typing import Tuple
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from utils import Act, Agg, act, agg, mutate_int, mutate_float
|
||||
from . import BaseNodeGene
|
||||
|
||||
|
||||
class DefaultNodeGene(BaseNodeGene):
|
||||
"Default node gene, with the same behavior as in NEAT-python."
|
||||
|
||||
fixed_attrs = ['index']
|
||||
custom_attrs = ['bias', 'response', 'aggregation', 'activation']
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bias_init_mean: float = 0.0,
|
||||
bias_init_std: float = 1.0,
|
||||
bias_mutate_power: float = 0.5,
|
||||
bias_mutate_rate: float = 0.7,
|
||||
bias_replace_rate: float = 0.1,
|
||||
|
||||
response_init_mean: float = 1.0,
|
||||
response_init_std: float = 0.0,
|
||||
response_mutate_power: float = 0.5,
|
||||
response_mutate_rate: float = 0.7,
|
||||
response_replace_rate: float = 0.1,
|
||||
|
||||
activation_default: callable = Act.sigmoid,
|
||||
activation_options: Tuple = (Act.sigmoid,),
|
||||
activation_replace_rate: float = 0.1,
|
||||
|
||||
aggregation_default: callable = Agg.sum,
|
||||
aggregation_options: Tuple = (Agg.sum,),
|
||||
aggregation_replace_rate: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
self.bias_init_mean = bias_init_mean
|
||||
self.bias_init_std = bias_init_std
|
||||
self.bias_mutate_power = bias_mutate_power
|
||||
self.bias_mutate_rate = bias_mutate_rate
|
||||
self.bias_replace_rate = bias_replace_rate
|
||||
|
||||
self.response_init_mean = response_init_mean
|
||||
self.response_init_std = response_init_std
|
||||
self.response_mutate_power = response_mutate_power
|
||||
self.response_mutate_rate = response_mutate_rate
|
||||
self.response_replace_rate = response_replace_rate
|
||||
|
||||
self.activation_default = activation_options.index(activation_default)
|
||||
self.activation_options = activation_options
|
||||
self.activation_indices = jnp.arange(len(activation_options))
|
||||
self.activation_replace_rate = activation_replace_rate
|
||||
|
||||
self.aggregation_default = aggregation_options.index(aggregation_default)
|
||||
self.aggregation_options = aggregation_options
|
||||
self.aggregation_indices = jnp.arange(len(aggregation_options))
|
||||
self.aggregation_replace_rate = aggregation_replace_rate
|
||||
|
||||
def new_custom_attrs(self):
|
||||
return jnp.array(
|
||||
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
|
||||
)
|
||||
|
||||
def mutate(self, key, node):
|
||||
k1, k2, k3, k4 = jax.random.split(key, num=4)
|
||||
index = node[0]
|
||||
|
||||
bias = mutate_float(k1, node[1], self.bias_init_mean, self.bias_init_std,
|
||||
self.bias_mutate_power, self.bias_mutate_rate, self.bias_replace_rate)
|
||||
|
||||
res = mutate_float(k2, node[2], self.response_init_mean, self.response_init_std,
|
||||
self.response_mutate_power, self.response_mutate_rate, self.response_replace_rate)
|
||||
|
||||
act = mutate_int(k3, node[3], self.activation_indices, self.activation_replace_rate)
|
||||
|
||||
agg = mutate_int(k4, node[4], self.aggregation_indices, self.aggregation_replace_rate)
|
||||
|
||||
return jnp.array([index, bias, res, act, agg])
|
||||
|
||||
def distance(self, node1, node2):
|
||||
return (
|
||||
jnp.abs(node1[1] - node2[1]) +
|
||||
jnp.abs(node1[2] - node2[2]) +
|
||||
node1[3] != node2[3] +
|
||||
node1[4] != node2[4]
|
||||
)
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
bias, res, act_idx, agg_idx = attrs
|
||||
|
||||
z = agg(agg_idx, inputs, self.aggregation_options)
|
||||
z = bias + res * z
|
||||
z = act(act_idx, z, self.activation_options)
|
||||
|
||||
return z
|
||||
@@ -1,210 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import jax
|
||||
from jax import Array, numpy as jnp
|
||||
|
||||
from config import GeneConfig
|
||||
from core import Gene, Genome, State
|
||||
from utils import Act, Agg, unflatten_conns, topological_sort, I_INT, act, agg
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NormalGeneConfig(GeneConfig):
|
||||
bias_init_mean: float = 0.0
|
||||
bias_init_std: float = 1.0
|
||||
bias_mutate_power: float = 0.5
|
||||
bias_mutate_rate: float = 0.7
|
||||
bias_replace_rate: float = 0.1
|
||||
|
||||
response_init_mean: float = 1.0
|
||||
response_init_std: float = 0.0
|
||||
response_mutate_power: float = 0.5
|
||||
response_mutate_rate: float = 0.7
|
||||
response_replace_rate: float = 0.1
|
||||
|
||||
activation_default: callable = Act.sigmoid
|
||||
activation_options: Tuple = (Act.sigmoid, )
|
||||
activation_replace_rate: float = 0.1
|
||||
|
||||
aggregation_default: callable = Agg.sum
|
||||
aggregation_options: Tuple = (Agg.sum, )
|
||||
aggregation_replace_rate: float = 0.1
|
||||
|
||||
weight_init_mean: float = 0.0
|
||||
weight_init_std: float = 1.0
|
||||
weight_mutate_power: float = 0.5
|
||||
weight_mutate_rate: float = 0.8
|
||||
weight_replace_rate: float = 0.1
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.bias_init_std >= 0.0
|
||||
assert self.bias_mutate_power >= 0.0
|
||||
assert self.bias_mutate_rate >= 0.0
|
||||
assert self.bias_replace_rate >= 0.0
|
||||
|
||||
assert self.response_init_std >= 0.0
|
||||
assert self.response_mutate_power >= 0.0
|
||||
assert self.response_mutate_rate >= 0.0
|
||||
assert self.response_replace_rate >= 0.0
|
||||
|
||||
assert self.activation_default == self.activation_options[0]
|
||||
assert self.aggregation_default == self.aggregation_options[0]
|
||||
|
||||
|
||||
class NormalGene(Gene):
|
||||
node_attrs = ['bias', 'response', 'aggregation', 'activation']
|
||||
conn_attrs = ['weight']
|
||||
|
||||
def __init__(self, config: NormalGeneConfig = NormalGeneConfig()):
|
||||
self.config = config
|
||||
|
||||
def setup(self, state: State = State()):
|
||||
return state.update(
|
||||
bias_init_mean=self.config.bias_init_mean,
|
||||
bias_init_std=self.config.bias_init_std,
|
||||
bias_mutate_power=self.config.bias_mutate_power,
|
||||
bias_mutate_rate=self.config.bias_mutate_rate,
|
||||
bias_replace_rate=self.config.bias_replace_rate,
|
||||
|
||||
response_init_mean=self.config.response_init_mean,
|
||||
response_init_std=self.config.response_init_std,
|
||||
response_mutate_power=self.config.response_mutate_power,
|
||||
response_mutate_rate=self.config.response_mutate_rate,
|
||||
response_replace_rate=self.config.response_replace_rate,
|
||||
|
||||
activation_replace_rate=self.config.activation_replace_rate,
|
||||
activation_default=0,
|
||||
activation_options=jnp.arange(len(self.config.activation_options)),
|
||||
|
||||
aggregation_replace_rate=self.config.aggregation_replace_rate,
|
||||
aggregation_default=0,
|
||||
aggregation_options=jnp.arange(len(self.config.aggregation_options)),
|
||||
|
||||
weight_init_mean=self.config.weight_init_mean,
|
||||
weight_init_std=self.config.weight_init_std,
|
||||
weight_mutate_power=self.config.weight_mutate_power,
|
||||
weight_mutate_rate=self.config.weight_mutate_rate,
|
||||
weight_replace_rate=self.config.weight_replace_rate,
|
||||
)
|
||||
|
||||
def update(self, state):
|
||||
return state
|
||||
|
||||
def new_node_attrs(self, state):
|
||||
return jnp.array([state.bias_init_mean, state.response_init_mean,
|
||||
state.activation_default, state.aggregation_default])
|
||||
|
||||
def new_conn_attrs(self, state):
|
||||
return jnp.array([state.weight_init_mean])
|
||||
|
||||
def mutate_node(self, state, key, attrs: Array):
|
||||
k1, k2, k3, k4 = jax.random.split(key, num=4)
|
||||
|
||||
bias = NormalGene._mutate_float(k1, attrs[0], state.bias_init_mean, state.bias_init_std,
|
||||
state.bias_mutate_power, state.bias_mutate_rate, state.bias_replace_rate)
|
||||
res = NormalGene._mutate_float(k2, attrs[1], state.response_init_mean, state.response_init_std,
|
||||
state.response_mutate_power, state.response_mutate_rate,
|
||||
state.response_replace_rate)
|
||||
act = NormalGene._mutate_int(k3, attrs[2], state.activation_options, state.activation_replace_rate)
|
||||
agg = NormalGene._mutate_int(k4, attrs[3], state.aggregation_options, state.aggregation_replace_rate)
|
||||
|
||||
return jnp.array([bias, res, act, agg])
|
||||
|
||||
def mutate_conn(self, state, key, attrs: Array):
|
||||
weight = NormalGene._mutate_float(key, attrs[0], state.weight_init_mean, state.weight_init_std,
|
||||
state.weight_mutate_power, state.weight_mutate_rate,
|
||||
state.weight_replace_rate)
|
||||
|
||||
return jnp.array([weight])
|
||||
|
||||
def distance_node(self, state, node1: Array, node2: Array):
|
||||
# bias + response + activation + aggregation
|
||||
return jnp.abs(node1[1] - node2[1]) + jnp.abs(node1[2] - node2[2]) + \
|
||||
(node1[3] != node2[3]) + (node1[4] != node2[4])
|
||||
|
||||
def distance_conn(self, state, con1: Array, con2: Array):
|
||||
return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight
|
||||
|
||||
def forward_transform(self, state: State, genome: Genome):
|
||||
u_conns = unflatten_conns(genome.nodes, genome.conns)
|
||||
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
||||
|
||||
# remove enable attr
|
||||
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||
seqs = topological_sort(genome.nodes, conn_enable)
|
||||
|
||||
return seqs, genome.nodes, u_conns
|
||||
|
||||
def forward(self, state: State, inputs, transformed):
|
||||
cal_seqs, nodes, cons = transformed
|
||||
|
||||
input_idx = state.input_idx
|
||||
output_idx = state.output_idx
|
||||
|
||||
N = nodes.shape[0]
|
||||
ini_vals = jnp.full((N,), jnp.nan)
|
||||
ini_vals = ini_vals.at[input_idx].set(inputs)
|
||||
|
||||
weights = cons[0, :]
|
||||
|
||||
def cond_fun(carry):
|
||||
values, idx = carry
|
||||
return (idx < N) & (cal_seqs[idx] != I_INT)
|
||||
|
||||
def body_func(carry):
|
||||
values, idx = carry
|
||||
i = cal_seqs[idx]
|
||||
|
||||
def hit():
|
||||
ins = values * weights[:, i]
|
||||
z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
|
||||
z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
|
||||
z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
|
||||
|
||||
new_values = values.at[i].set(z)
|
||||
return new_values
|
||||
|
||||
def miss():
|
||||
return values
|
||||
|
||||
# the val of input nodes is obtained by the task, not by calculation
|
||||
values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
|
||||
|
||||
return values, idx + 1
|
||||
|
||||
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
|
||||
|
||||
return vals[output_idx]
|
||||
|
||||
@staticmethod
|
||||
def _mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate):
|
||||
k1, k2, k3 = jax.random.split(key, num=3)
|
||||
noise = jax.random.normal(k1, ()) * mutate_power
|
||||
replace = jax.random.normal(k2, ()) * init_std + init_mean
|
||||
r = jax.random.uniform(k3, ())
|
||||
|
||||
val = jnp.where(
|
||||
r < mutate_rate,
|
||||
val + noise,
|
||||
jnp.where(
|
||||
(mutate_rate < r) & (r < mutate_rate + replace_rate),
|
||||
replace,
|
||||
val
|
||||
)
|
||||
)
|
||||
|
||||
return val
|
||||
|
||||
@staticmethod
|
||||
def _mutate_int(key, val, options, replace_rate):
|
||||
k1, k2 = jax.random.split(key, num=2)
|
||||
r = jax.random.uniform(k1, ())
|
||||
|
||||
val = jnp.where(
|
||||
r < replace_rate,
|
||||
jax.random.choice(k2, options),
|
||||
val
|
||||
)
|
||||
|
||||
return val
|
||||
@@ -1,57 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp, vmap
|
||||
|
||||
from .normal import NormalGene, NormalGeneConfig
|
||||
from core import State, Genome
|
||||
from utils import unflatten_conns, act, agg
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RecurrentGeneConfig(NormalGeneConfig):
|
||||
activate_times: int = 10
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
assert self.activate_times > 0
|
||||
|
||||
|
||||
class RecurrentGene(NormalGene):
|
||||
|
||||
def __init__(self, config: RecurrentGeneConfig = RecurrentGeneConfig()):
|
||||
self.config = config
|
||||
super().__init__(config)
|
||||
|
||||
def forward_transform(self, state: State, genome: Genome):
|
||||
u_conns = unflatten_conns(genome.nodes, genome.conns)
|
||||
|
||||
# remove un-enable connections and remove enable attr
|
||||
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
||||
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||
|
||||
return genome.nodes, u_conns
|
||||
|
||||
def forward(self, state: State, inputs, transformed):
|
||||
nodes, conns = transformed
|
||||
|
||||
batch_act, batch_agg = vmap(act, in_axes=(0, 0, None)), vmap(agg, in_axes=(0, 0, None))
|
||||
|
||||
input_idx = state.input_idx
|
||||
output_idx = state.output_idx
|
||||
|
||||
N = nodes.shape[0]
|
||||
vals = jnp.full((N,), 0.)
|
||||
|
||||
weights = conns[0, :]
|
||||
|
||||
def body_func(i, values):
|
||||
values = values.at[input_idx].set(inputs)
|
||||
nodes_ins = values * weights.T
|
||||
values = batch_agg(nodes[:, 4], nodes_ins, self.config.aggregation_options) # z = agg(ins)
|
||||
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
|
||||
values = batch_act(nodes[:, 3], values, self.config.activation_options) # z = act(z)
|
||||
return values
|
||||
|
||||
vals = jax.lax.fori_loop(0, self.config.activate_times, body_func, vals)
|
||||
return vals[output_idx]
|
||||
Reference in New Issue
Block a user