hyper neat
This commit is contained in:
2
algorithm/hyper_neat/__init__.py
Normal file
2
algorithm/hyper_neat/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .hyper_neat import HyperNEAT
|
||||
from .substrate import NormalSubstrate, NormalSubstrateConfig
|
||||
122
algorithm/hyper_neat/hyper_neat.py
Normal file
122
algorithm/hyper_neat/hyper_neat.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from typing import Type
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp, Array, vmap
|
||||
import numpy as np
|
||||
|
||||
from config import Config, HyperNeatConfig
|
||||
from core import Algorithm, Substrate, State, Genome
|
||||
from utils import Activation, Aggregation
|
||||
from algorithm.neat import NEAT
|
||||
from .substrate import analysis_substrate
|
||||
|
||||
class HyperNEAT(Algorithm):
|
||||
|
||||
def __init__(self, config: Config, neat: NEAT, substrate: Type[Substrate]):
|
||||
self.config = config
|
||||
self.neat = neat
|
||||
self.substrate = substrate
|
||||
|
||||
self.forward_func = None
|
||||
|
||||
def setup(self, randkey, state=State()):
|
||||
neat_key, randkey = jax.random.split(randkey)
|
||||
state = state.update(
|
||||
below_threshold=self.config.hyper_neat.below_threshold,
|
||||
max_weight=self.config.hyper_neat.max_weight,
|
||||
)
|
||||
state = self.neat.setup(neat_key, state)
|
||||
state = self.substrate.setup(self.config.substrate, state)
|
||||
|
||||
assert self.config.hyper_neat.inputs + 1 == state.input_coors.shape[0] # +1 for bias
|
||||
assert self.config.hyper_neat.outputs == state.output_coors.shape[0]
|
||||
|
||||
h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state)
|
||||
h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis]
|
||||
h_conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32)
|
||||
h_conns[:, 0:2] = correspond_keys
|
||||
|
||||
state = state.update(
|
||||
h_input_idx=h_input_idx,
|
||||
h_output_idx=h_output_idx,
|
||||
h_hidden_idx=h_hidden_idx,
|
||||
h_nodes=h_nodes,
|
||||
h_conns=h_conns,
|
||||
query_coors=query_coors,
|
||||
)
|
||||
|
||||
self.forward_func = HyperNEATGene.create_forward(self.config.hyper_neat, state)
|
||||
|
||||
return state
|
||||
def ask(self, state: State):
|
||||
return state.pop_genomes
|
||||
|
||||
def tell(self, state: State, fitness):
|
||||
return self.neat.tell(state, fitness)
|
||||
|
||||
def forward(self, inputs: Array, transformed: Array):
|
||||
return self.forward_func(inputs, transformed)
|
||||
|
||||
def forward_transform(self, state: State, genome: Genome):
|
||||
t = self.neat.forward_transform(state, genome)
|
||||
query_res = vmap(self.neat.forward, in_axes=(0, None))(state.query_coors, t)
|
||||
|
||||
# mute the connection with weight below threshold
|
||||
query_res = jnp.where((-state.below_threshold < query_res) & (query_res < state.below_threshold), 0., query_res)
|
||||
|
||||
# make query res in range [-max_weight, max_weight]
|
||||
query_res = jnp.where(query_res > 0, query_res - state.below_threshold, query_res)
|
||||
query_res = jnp.where(query_res < 0, query_res + state.below_threshold, query_res)
|
||||
query_res = query_res / (1 - state.below_threshold) * state.max_weight
|
||||
|
||||
h_conns = state.h_conns.at[:, 2:].set(query_res)
|
||||
return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns))
|
||||
|
||||
|
||||
class HyperNEATGene:
|
||||
node_attrs = [] # no node attributes
|
||||
conn_attrs = ['weight']
|
||||
|
||||
@staticmethod
|
||||
def forward_transform(genome: Genome):
|
||||
N = genome.nodes.shape[0]
|
||||
u_conns = jnp.zeros((N, N), dtype=jnp.float32)
|
||||
|
||||
in_keys = jnp.asarray(genome.conns[:, 0], jnp.int32)
|
||||
out_keys = jnp.asarray(genome.conns[:, 1], jnp.int32)
|
||||
weights = genome.conns[:, 2]
|
||||
|
||||
u_conns = u_conns.at[in_keys, out_keys].set(weights)
|
||||
return genome.nodes, u_conns
|
||||
|
||||
@staticmethod
|
||||
def create_forward(config: HyperNeatConfig, state: State):
|
||||
|
||||
act = Activation.name2func[config.activation]
|
||||
agg = Aggregation.name2func[config.aggregation]
|
||||
|
||||
batch_act, batch_agg = jax.vmap(act), jax.vmap(agg)
|
||||
|
||||
def forward(inputs, transform):
|
||||
|
||||
inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0)
|
||||
nodes, weights = transform
|
||||
|
||||
input_idx = state.h_input_idx
|
||||
output_idx = state.h_output_idx
|
||||
|
||||
N = nodes.shape[0]
|
||||
vals = jnp.full((N,), 0.)
|
||||
|
||||
def body_func(i, values):
|
||||
values = values.at[input_idx].set(inputs_with_bias)
|
||||
nodes_ins = values * weights.T
|
||||
values = batch_agg(nodes_ins) # z = agg(ins)
|
||||
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
|
||||
values = batch_act(values) # z = act(z)
|
||||
return values
|
||||
|
||||
vals = jax.lax.fori_loop(0, config.activate_times, body_func, vals)
|
||||
return vals[output_idx]
|
||||
|
||||
return forward
|
||||
2
algorithm/hyper_neat/substrate/__init__.py
Normal file
2
algorithm/hyper_neat/substrate/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .normal import NormalSubstrate, NormalSubstrateConfig
|
||||
from .tools import analysis_substrate
|
||||
25
algorithm/hyper_neat/substrate/normal.py
Normal file
25
algorithm/hyper_neat/substrate/normal.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from core import Substrate, State
|
||||
from config import SubstrateConfig
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NormalSubstrateConfig(SubstrateConfig):
|
||||
input_coors: Tuple[Tuple[float]] = ((-1, -1), (0, -1), (1, -1))
|
||||
hidden_coors: Tuple[Tuple[float]] = ((-1, 0), (0, 0), (1, 0))
|
||||
output_coors: Tuple[Tuple[float]] = ((0, 1), )
|
||||
|
||||
|
||||
class NormalSubstrate(Substrate):
|
||||
|
||||
@staticmethod
|
||||
def setup(config: NormalSubstrateConfig, state: State = State()):
|
||||
return state.update(
|
||||
input_coors=np.asarray(config.input_coors, dtype=np.float32),
|
||||
output_coors=np.asarray(config.output_coors, dtype=np.float32),
|
||||
hidden_coors=np.asarray(config.hidden_coors, dtype=np.float32),
|
||||
)
|
||||
50
algorithm/hyper_neat/substrate/tools.py
Normal file
50
algorithm/hyper_neat/substrate/tools.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing import Type
|
||||
|
||||
import numpy as np
|
||||
|
||||
def analysis_substrate(state):
|
||||
cd = state.input_coors.shape[1] # coordinate dimensions
|
||||
si = state.input_coors.shape[0] # input coordinate size
|
||||
so = state.output_coors.shape[0] # output coordinate size
|
||||
sh = state.hidden_coors.shape[0] # hidden coordinate size
|
||||
|
||||
input_idx = np.arange(si)
|
||||
output_idx = np.arange(si, si + so)
|
||||
hidden_idx = np.arange(si + so, si + so + sh)
|
||||
|
||||
total_conns = si * sh + sh * sh + sh * so
|
||||
query_coors = np.zeros((total_conns, cd * 2))
|
||||
correspond_keys = np.zeros((total_conns, 2))
|
||||
|
||||
# connect input to hidden
|
||||
aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, state.input_coors, state.hidden_coors)
|
||||
query_coors[0: si * sh, :] = aux_coors
|
||||
correspond_keys[0: si * sh, :] = aux_keys
|
||||
|
||||
# connect hidden to hidden
|
||||
aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors)
|
||||
query_coors[si * sh: si * sh + sh * sh, :] = aux_coors
|
||||
correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys
|
||||
|
||||
# connect hidden to output
|
||||
aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, state.hidden_coors, state.output_coors)
|
||||
query_coors[si * sh + sh * sh:, :] = aux_coors
|
||||
correspond_keys[si * sh + sh * sh:, :] = aux_keys
|
||||
|
||||
return input_idx, output_idx, hidden_idx, query_coors, correspond_keys
|
||||
|
||||
|
||||
def cartesian_product(keys1, keys2, coors1, coors2):
|
||||
len1 = keys1.shape[0]
|
||||
len2 = keys2.shape[0]
|
||||
|
||||
repeated_coors1 = np.repeat(coors1, len2, axis=0)
|
||||
repeated_keys1 = np.repeat(keys1, len2)
|
||||
|
||||
tiled_coors2 = np.tile(coors2, (len1, 1))
|
||||
tiled_keys2 = np.tile(keys2, len1)
|
||||
|
||||
new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1)
|
||||
correspond_keys = np.column_stack((repeated_keys1, tiled_keys2))
|
||||
|
||||
return new_coors, correspond_keys
|
||||
Reference in New Issue
Block a user