finish all refactoring
This commit is contained in:
2
algorithm/hyperneat/__init__.py
Normal file
2
algorithm/hyperneat/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .hyperneat import HyperNEAT
|
||||
from .substrate import BaseSubstrate, DefaultSubstrate, FullSubstrate
|
||||
116
algorithm/hyperneat/hyperneat.py
Normal file
116
algorithm/hyperneat/hyperneat.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from utils import State, Act, Agg
|
||||
from .. import BaseAlgorithm, NEAT
|
||||
from ..neat.gene import BaseNodeGene, BaseConnGene
|
||||
from ..neat.genome import RecurrentGenome
|
||||
from .substrate import *
|
||||
|
||||
|
||||
class HyperNEAT(BaseAlgorithm):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
substrate: BaseSubstrate,
|
||||
neat: NEAT,
|
||||
below_threshold: float = 0.3,
|
||||
max_weight: float = 5.,
|
||||
activation=Act.sigmoid,
|
||||
aggregation=Agg.sum,
|
||||
activate_time: int = 10,
|
||||
):
|
||||
assert substrate.query_coors.shape[1] == neat.num_inputs, \
|
||||
"Substrate input size should be equal to NEAT input size"
|
||||
|
||||
self.substrate = substrate
|
||||
self.neat = neat
|
||||
self.below_threshold = below_threshold
|
||||
self.max_weight = max_weight
|
||||
self.hyper_genome = RecurrentGenome(
|
||||
num_inputs=substrate.num_inputs,
|
||||
num_outputs=substrate.num_outputs,
|
||||
max_nodes=substrate.nodes_cnt,
|
||||
max_conns=substrate.conns_cnt,
|
||||
node_gene=HyperNodeGene(activation, aggregation),
|
||||
conn_gene=HyperNEATConnGene(),
|
||||
activate_time=activate_time,
|
||||
)
|
||||
|
||||
def setup(self, randkey):
|
||||
return State(
|
||||
neat_state=self.neat.setup(randkey)
|
||||
)
|
||||
|
||||
def ask(self, state: State):
|
||||
return self.neat.ask(state.neat_state)
|
||||
|
||||
def tell(self, state: State, fitness):
|
||||
return state.update(
|
||||
neat_state=self.neat.tell(state.neat_state, fitness)
|
||||
)
|
||||
|
||||
def transform(self, individual):
|
||||
transformed = self.neat.transform(individual)
|
||||
query_res = jax.vmap(self.neat.forward, in_axes=(0, None))(self.substrate.query_coors, transformed)
|
||||
|
||||
# mute the connection with weight below threshold
|
||||
query_res = jnp.where(
|
||||
(-self.below_threshold < query_res) & (query_res < self.below_threshold),
|
||||
0.,
|
||||
query_res
|
||||
)
|
||||
|
||||
# make query res in range [-max_weight, max_weight]
|
||||
query_res = jnp.where(query_res > 0, query_res - self.below_threshold, query_res)
|
||||
query_res = jnp.where(query_res < 0, query_res + self.below_threshold, query_res)
|
||||
query_res = query_res / (1 - self.below_threshold) * self.max_weight
|
||||
|
||||
h_nodes, h_conns = self.substrate.make_nodes(query_res), self.substrate.make_conn(query_res)
|
||||
return self.hyper_genome.transform(h_nodes, h_conns)
|
||||
|
||||
def forward(self, inputs, transformed):
|
||||
# add bias
|
||||
inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
|
||||
return self.hyper_genome.forward(inputs_with_bias, transformed)
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
return self.substrate.num_inputs - 1 # remove bias
|
||||
|
||||
@property
|
||||
def num_outputs(self):
|
||||
return self.substrate.num_outputs
|
||||
|
||||
@property
|
||||
def pop_size(self):
|
||||
return self.neat.pop_size
|
||||
|
||||
def member_count(self, state: State):
|
||||
return self.neat.member_count(state.neat_state)
|
||||
|
||||
def generation(self, state: State):
|
||||
return self.neat.generation(state.neat_state)
|
||||
|
||||
|
||||
class HyperNodeGene(BaseNodeGene):
|
||||
|
||||
def __init__(self,
|
||||
activation=Act.sigmoid,
|
||||
aggregation=Agg.sum,
|
||||
):
|
||||
super().__init__()
|
||||
self.activation = activation
|
||||
self.aggregation = aggregation
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
return self.activation(
|
||||
self.aggregation(inputs)
|
||||
)
|
||||
|
||||
|
||||
class HyperNEATConnGene(BaseConnGene):
|
||||
custom_attrs = ['weight']
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
weight = attrs[0]
|
||||
return inputs * weight
|
||||
3
algorithm/hyperneat/substrate/__init__.py
Normal file
3
algorithm/hyperneat/substrate/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base import BaseSubstrate
|
||||
from .default import DefaultSubstrate
|
||||
from .full import FullSubstrate
|
||||
27
algorithm/hyperneat/substrate/base.py
Normal file
27
algorithm/hyperneat/substrate/base.py
Normal file
@@ -0,0 +1,27 @@
|
||||
class BaseSubstrate:
|
||||
|
||||
def make_nodes(self, query_res):
|
||||
raise NotImplementedError
|
||||
|
||||
def make_conn(self, query_res):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def query_coors(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def num_outputs(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def nodes_cnt(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def conns_cnt(self):
|
||||
raise NotImplementedError
|
||||
38
algorithm/hyperneat/substrate/default.py
Normal file
38
algorithm/hyperneat/substrate/default.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import jax.numpy as jnp
|
||||
from . import BaseSubstrate
|
||||
|
||||
|
||||
class DefaultSubstrate(BaseSubstrate):
|
||||
|
||||
def __init__(self, num_inputs, num_outputs, coors, nodes, conns):
|
||||
self.inputs = num_inputs
|
||||
self.outputs = num_outputs
|
||||
self.coors = jnp.array(coors)
|
||||
self.nodes = jnp.array(nodes)
|
||||
self.conns = jnp.array(conns)
|
||||
|
||||
def make_nodes(self, query_res):
|
||||
return self.nodes
|
||||
|
||||
def make_conn(self, query_res):
|
||||
return self.conns.at[:, 3:].set(query_res) # change weight
|
||||
|
||||
@property
|
||||
def query_coors(self):
|
||||
return self.coors
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
return self.inputs
|
||||
|
||||
@property
|
||||
def num_outputs(self):
|
||||
return self.outputs
|
||||
|
||||
@property
|
||||
def nodes_cnt(self):
|
||||
return self.nodes.shape[0]
|
||||
|
||||
@property
|
||||
def conns_cnt(self):
|
||||
return self.conns.shape[0]
|
||||
76
algorithm/hyperneat/substrate/full.py
Normal file
76
algorithm/hyperneat/substrate/full.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import numpy as np
|
||||
from .default import DefaultSubstrate
|
||||
|
||||
|
||||
class FullSubstrate(DefaultSubstrate):
|
||||
|
||||
def __init__(self,
|
||||
input_coors=((-1, -1), (0, -1), (1, -1)),
|
||||
hidden_coors=((-1, 0), (0, 0), (1, 0)),
|
||||
output_coors=((0, 1),),
|
||||
):
|
||||
query_coors, nodes, conns = analysis_substrate(input_coors, output_coors, hidden_coors)
|
||||
super().__init__(
|
||||
len(input_coors),
|
||||
len(output_coors),
|
||||
query_coors,
|
||||
nodes,
|
||||
conns
|
||||
)
|
||||
|
||||
|
||||
def analysis_substrate(input_coors, output_coors, hidden_coors):
|
||||
input_coors = np.array(input_coors)
|
||||
output_coors = np.array(output_coors)
|
||||
hidden_coors = np.array(hidden_coors)
|
||||
|
||||
cd = input_coors.shape[1] # coordinate dimensions
|
||||
si = input_coors.shape[0] # input coordinate size
|
||||
so = output_coors.shape[0] # output coordinate size
|
||||
sh = hidden_coors.shape[0] # hidden coordinate size
|
||||
|
||||
input_idx = np.arange(si)
|
||||
output_idx = np.arange(si, si + so)
|
||||
hidden_idx = np.arange(si + so, si + so + sh)
|
||||
|
||||
total_conns = si * sh + sh * sh + sh * so
|
||||
query_coors = np.zeros((total_conns, cd * 2))
|
||||
correspond_keys = np.zeros((total_conns, 2))
|
||||
|
||||
# connect input to hidden
|
||||
aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, input_coors, hidden_coors)
|
||||
query_coors[0: si * sh, :] = aux_coors
|
||||
correspond_keys[0: si * sh, :] = aux_keys
|
||||
|
||||
# connect hidden to hidden
|
||||
aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, hidden_coors, hidden_coors)
|
||||
query_coors[si * sh: si * sh + sh * sh, :] = aux_coors
|
||||
correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys
|
||||
|
||||
# connect hidden to output
|
||||
aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, hidden_coors, output_coors)
|
||||
query_coors[si * sh + sh * sh:, :] = aux_coors
|
||||
correspond_keys[si * sh + sh * sh:, :] = aux_keys
|
||||
|
||||
nodes = np.concatenate((input_idx, output_idx, hidden_idx))[..., np.newaxis]
|
||||
conns = np.zeros((correspond_keys.shape[0], 4), dtype=np.float32) # input_idx, output_idx, enabled, weight
|
||||
conns[:, 0:2] = correspond_keys
|
||||
conns[:, 2] = 1 # enabled is True
|
||||
|
||||
return query_coors, nodes, conns
|
||||
|
||||
|
||||
def cartesian_product(keys1, keys2, coors1, coors2):
|
||||
len1 = keys1.shape[0]
|
||||
len2 = keys2.shape[0]
|
||||
|
||||
repeated_coors1 = np.repeat(coors1, len2, axis=0)
|
||||
repeated_keys1 = np.repeat(keys1, len2)
|
||||
|
||||
tiled_coors2 = np.tile(coors2, (len1, 1))
|
||||
tiled_keys2 = np.tile(keys2, len1)
|
||||
|
||||
new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1)
|
||||
correspond_keys = np.column_stack((repeated_keys1, tiled_keys2))
|
||||
|
||||
return new_coors, correspond_keys
|
||||
Reference in New Issue
Block a user