hyper neat

This commit is contained in:
wls2002
2023-07-24 19:25:02 +08:00
parent ac295c1921
commit ebad574431
24 changed files with 542 additions and 103 deletions

View File

@@ -0,0 +1,2 @@
from .hyper_neat import HyperNEAT
from .substrate import NormalSubstrate, NormalSubstrateConfig

View File

@@ -0,0 +1,122 @@
from typing import Type
import jax
from jax import numpy as jnp, Array, vmap
import numpy as np
from config import Config, HyperNeatConfig
from core import Algorithm, Substrate, State, Genome
from utils import Activation, Aggregation
from algorithm.neat import NEAT
from .substrate import analysis_substrate
class HyperNEAT(Algorithm):
def __init__(self, config: Config, neat: NEAT, substrate: Type[Substrate]):
self.config = config
self.neat = neat
self.substrate = substrate
self.forward_func = None
def setup(self, randkey, state=State()):
neat_key, randkey = jax.random.split(randkey)
state = state.update(
below_threshold=self.config.hyper_neat.below_threshold,
max_weight=self.config.hyper_neat.max_weight,
)
state = self.neat.setup(neat_key, state)
state = self.substrate.setup(self.config.substrate, state)
assert self.config.hyper_neat.inputs + 1 == state.input_coors.shape[0] # +1 for bias
assert self.config.hyper_neat.outputs == state.output_coors.shape[0]
h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state)
h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis]
h_conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32)
h_conns[:, 0:2] = correspond_keys
state = state.update(
h_input_idx=h_input_idx,
h_output_idx=h_output_idx,
h_hidden_idx=h_hidden_idx,
h_nodes=h_nodes,
h_conns=h_conns,
query_coors=query_coors,
)
self.forward_func = HyperNEATGene.create_forward(self.config.hyper_neat, state)
return state
def ask(self, state: State):
return state.pop_genomes
def tell(self, state: State, fitness):
return self.neat.tell(state, fitness)
def forward(self, inputs: Array, transformed: Array):
return self.forward_func(inputs, transformed)
def forward_transform(self, state: State, genome: Genome):
t = self.neat.forward_transform(state, genome)
query_res = vmap(self.neat.forward, in_axes=(0, None))(state.query_coors, t)
# mute the connection with weight below threshold
query_res = jnp.where((-state.below_threshold < query_res) & (query_res < state.below_threshold), 0., query_res)
# make query res in range [-max_weight, max_weight]
query_res = jnp.where(query_res > 0, query_res - state.below_threshold, query_res)
query_res = jnp.where(query_res < 0, query_res + state.below_threshold, query_res)
query_res = query_res / (1 - state.below_threshold) * state.max_weight
h_conns = state.h_conns.at[:, 2:].set(query_res)
return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns))
class HyperNEATGene:
node_attrs = [] # no node attributes
conn_attrs = ['weight']
@staticmethod
def forward_transform(genome: Genome):
N = genome.nodes.shape[0]
u_conns = jnp.zeros((N, N), dtype=jnp.float32)
in_keys = jnp.asarray(genome.conns[:, 0], jnp.int32)
out_keys = jnp.asarray(genome.conns[:, 1], jnp.int32)
weights = genome.conns[:, 2]
u_conns = u_conns.at[in_keys, out_keys].set(weights)
return genome.nodes, u_conns
@staticmethod
def create_forward(config: HyperNeatConfig, state: State):
act = Activation.name2func[config.activation]
agg = Aggregation.name2func[config.aggregation]
batch_act, batch_agg = jax.vmap(act), jax.vmap(agg)
def forward(inputs, transform):
inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0)
nodes, weights = transform
input_idx = state.h_input_idx
output_idx = state.h_output_idx
N = nodes.shape[0]
vals = jnp.full((N,), 0.)
def body_func(i, values):
values = values.at[input_idx].set(inputs_with_bias)
nodes_ins = values * weights.T
values = batch_agg(nodes_ins) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(values) # z = act(z)
return values
vals = jax.lax.fori_loop(0, config.activate_times, body_func, vals)
return vals[output_idx]
return forward

View File

@@ -0,0 +1,2 @@
from .normal import NormalSubstrate, NormalSubstrateConfig
from .tools import analysis_substrate

View File

@@ -0,0 +1,25 @@
from dataclasses import dataclass
from typing import Tuple
import numpy as np
from core import Substrate, State
from config import SubstrateConfig
@dataclass(frozen=True)
class NormalSubstrateConfig(SubstrateConfig):
input_coors: Tuple[Tuple[float]] = ((-1, -1), (0, -1), (1, -1))
hidden_coors: Tuple[Tuple[float]] = ((-1, 0), (0, 0), (1, 0))
output_coors: Tuple[Tuple[float]] = ((0, 1), )
class NormalSubstrate(Substrate):
@staticmethod
def setup(config: NormalSubstrateConfig, state: State = State()):
return state.update(
input_coors=np.asarray(config.input_coors, dtype=np.float32),
output_coors=np.asarray(config.output_coors, dtype=np.float32),
hidden_coors=np.asarray(config.hidden_coors, dtype=np.float32),
)

View File

@@ -0,0 +1,50 @@
from typing import Type
import numpy as np
def analysis_substrate(state):
cd = state.input_coors.shape[1] # coordinate dimensions
si = state.input_coors.shape[0] # input coordinate size
so = state.output_coors.shape[0] # output coordinate size
sh = state.hidden_coors.shape[0] # hidden coordinate size
input_idx = np.arange(si)
output_idx = np.arange(si, si + so)
hidden_idx = np.arange(si + so, si + so + sh)
total_conns = si * sh + sh * sh + sh * so
query_coors = np.zeros((total_conns, cd * 2))
correspond_keys = np.zeros((total_conns, 2))
# connect input to hidden
aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, state.input_coors, state.hidden_coors)
query_coors[0: si * sh, :] = aux_coors
correspond_keys[0: si * sh, :] = aux_keys
# connect hidden to hidden
aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors)
query_coors[si * sh: si * sh + sh * sh, :] = aux_coors
correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys
# connect hidden to output
aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, state.hidden_coors, state.output_coors)
query_coors[si * sh + sh * sh:, :] = aux_coors
correspond_keys[si * sh + sh * sh:, :] = aux_keys
return input_idx, output_idx, hidden_idx, query_coors, correspond_keys
def cartesian_product(keys1, keys2, coors1, coors2):
len1 = keys1.shape[0]
len2 = keys2.shape[0]
repeated_coors1 = np.repeat(coors1, len2, axis=0)
repeated_keys1 = np.repeat(keys1, len2)
tiled_coors2 = np.tile(coors2, (len1, 1))
tiled_keys2 = np.tile(keys2, len1)
new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1)
correspond_keys = np.column_stack((repeated_keys1, tiled_keys2))
return new_coors, correspond_keys