remove create_func....
This commit is contained in:
@@ -5,30 +5,30 @@ from jax import numpy as jnp, Array, vmap
|
||||
import numpy as np
|
||||
|
||||
from config import Config, HyperNeatConfig
|
||||
from core import Algorithm, Substrate, State, Genome
|
||||
from core import Algorithm, Substrate, State, Genome, Gene
|
||||
from utils import Activation, Aggregation
|
||||
from algorithm.neat import NEAT
|
||||
from .substrate import analysis_substrate
|
||||
from algorithm import NEAT
|
||||
|
||||
|
||||
class HyperNEAT(Algorithm):
|
||||
|
||||
def __init__(self, config: Config, neat: NEAT, substrate: Type[Substrate]):
|
||||
def __init__(self, config: Config, gene: Type[Gene], substrate: Type[Substrate]):
|
||||
self.config = config
|
||||
self.neat = neat
|
||||
self.neat = NEAT(config, gene)
|
||||
self.substrate = substrate
|
||||
|
||||
def setup(self, randkey, state=State()):
|
||||
neat_key, randkey = jax.random.split(randkey)
|
||||
state = state.update(
|
||||
below_threshold=self.config.hyper_neat.below_threshold,
|
||||
max_weight=self.config.hyper_neat.max_weight,
|
||||
below_threshold=self.config.hyperneat.below_threshold,
|
||||
max_weight=self.config.hyperneat.max_weight,
|
||||
)
|
||||
state = self.neat.setup(neat_key, state)
|
||||
state = self.substrate.setup(self.config.substrate, state)
|
||||
|
||||
assert self.config.hyper_neat.inputs + 1 == state.input_coors.shape[0] # +1 for bias
|
||||
assert self.config.hyper_neat.outputs == state.output_coors.shape[0]
|
||||
assert self.config.hyperneat.inputs + 1 == state.input_coors.shape[0] # +1 for bias
|
||||
assert self.config.hyperneat.outputs == state.output_coors.shape[0]
|
||||
|
||||
h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state)
|
||||
h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis]
|
||||
@@ -53,7 +53,7 @@ class HyperNEAT(Algorithm):
|
||||
return self.neat.tell(state, fitness)
|
||||
|
||||
def forward(self, state, inputs: Array, transformed: Array):
|
||||
return HyperNEATGene.forward(self.config.hyper_neat, state, inputs, transformed)
|
||||
return HyperNEATGene.forward(self.config.hyperneat, state, inputs, transformed)
|
||||
|
||||
def forward_transform(self, state: State, genome: Genome):
|
||||
t = self.neat.forward_transform(state, genome)
|
||||
@@ -68,6 +68,7 @@ class HyperNEAT(Algorithm):
|
||||
query_res = query_res / (1 - state.below_threshold) * state.max_weight
|
||||
|
||||
h_conns = state.h_conns.at[:, 2:].set(query_res)
|
||||
|
||||
return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns))
|
||||
|
||||
|
||||
|
||||
@@ -9,9 +9,9 @@ from config import SubstrateConfig
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NormalSubstrateConfig(SubstrateConfig):
|
||||
input_coors: Tuple[Tuple[float]] = ((-1, -1), (0, -1), (1, -1))
|
||||
hidden_coors: Tuple[Tuple[float]] = ((-1, 0), (0, 0), (1, 0))
|
||||
output_coors: Tuple[Tuple[float]] = ((0, 1),)
|
||||
input_coors: Tuple = ((-1, -1), (0, -1), (1, -1))
|
||||
hidden_coors: Tuple = ((-1, 0), (0, 0), (1, 0))
|
||||
output_coors: Tuple = ((0, 1),)
|
||||
|
||||
|
||||
class NormalSubstrate(Substrate):
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
from .neat import NEAT
|
||||
from .gene import *
|
||||
|
||||
@@ -66,7 +66,7 @@ class NormalGene(Gene):
|
||||
node_attrs = ['bias', 'response', 'aggregation', 'activation']
|
||||
conn_attrs = ['weight']
|
||||
|
||||
def __init__(self, config: NormalGeneConfig):
|
||||
def __init__(self, config: NormalGeneConfig = NormalGeneConfig()):
|
||||
self.config = config
|
||||
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
|
||||
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
|
||||
@@ -101,7 +101,7 @@ class NormalGene(Gene):
|
||||
)
|
||||
|
||||
def update(self, state):
|
||||
pass
|
||||
return state
|
||||
|
||||
def new_node_attrs(self, state):
|
||||
return jnp.array([state.bias_init_mean, state.response_init_mean,
|
||||
|
||||
@@ -19,7 +19,7 @@ class RecurrentGeneConfig(NormalGeneConfig):
|
||||
|
||||
class RecurrentGene(NormalGene):
|
||||
|
||||
def __init__(self, config: RecurrentGeneConfig):
|
||||
def __init__(self, config: RecurrentGeneConfig = RecurrentGeneConfig()):
|
||||
self.config = config
|
||||
super().__init__(config)
|
||||
|
||||
|
||||
@@ -28,9 +28,9 @@ class NEAT(Algorithm):
|
||||
|
||||
state = state.update(
|
||||
P=self.config.basic.pop_size,
|
||||
N=self.config.neat.maximum_nodes,
|
||||
C=self.config.neat.maximum_conns,
|
||||
S=self.config.neat.maximum_species,
|
||||
N=self.config.neat.max_nodes,
|
||||
C=self.config.neat.max_conns,
|
||||
S=self.config.neat.max_species,
|
||||
NL=1 + len(self.gene.node_attrs), # node length = (key) + attributes
|
||||
CL=3 + len(self.gene.conn_attrs), # conn length = (in, out, key) + attributes
|
||||
max_stagnation=self.config.neat.max_stagnation,
|
||||
@@ -80,6 +80,8 @@ class NEAT(Algorithm):
|
||||
return state.pop_genomes
|
||||
|
||||
def tell_algorithm(self, state: State, fitness):
|
||||
state = self.gene.update(state)
|
||||
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
|
||||
state = state.update(
|
||||
|
||||
Reference in New Issue
Block a user