remove create_func....

This commit is contained in:
wls2002
2023-08-04 17:29:36 +08:00
parent c7fb1ddabe
commit 0e44b13291
29 changed files with 591 additions and 259 deletions

View File

@@ -5,30 +5,30 @@ from jax import numpy as jnp, Array, vmap
import numpy as np
from config import Config, HyperNeatConfig
from core import Algorithm, Substrate, State, Genome
from core import Algorithm, Substrate, State, Genome, Gene
from utils import Activation, Aggregation
from algorithm.neat import NEAT
from .substrate import analysis_substrate
from algorithm import NEAT
class HyperNEAT(Algorithm):
def __init__(self, config: Config, neat: NEAT, substrate: Type[Substrate]):
def __init__(self, config: Config, gene: Type[Gene], substrate: Type[Substrate]):
self.config = config
self.neat = neat
self.neat = NEAT(config, gene)
self.substrate = substrate
def setup(self, randkey, state=State()):
neat_key, randkey = jax.random.split(randkey)
state = state.update(
below_threshold=self.config.hyper_neat.below_threshold,
max_weight=self.config.hyper_neat.max_weight,
below_threshold=self.config.hyperneat.below_threshold,
max_weight=self.config.hyperneat.max_weight,
)
state = self.neat.setup(neat_key, state)
state = self.substrate.setup(self.config.substrate, state)
assert self.config.hyper_neat.inputs + 1 == state.input_coors.shape[0] # +1 for bias
assert self.config.hyper_neat.outputs == state.output_coors.shape[0]
assert self.config.hyperneat.inputs + 1 == state.input_coors.shape[0] # +1 for bias
assert self.config.hyperneat.outputs == state.output_coors.shape[0]
h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state)
h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis]
@@ -53,7 +53,7 @@ class HyperNEAT(Algorithm):
return self.neat.tell(state, fitness)
def forward(self, state, inputs: Array, transformed: Array):
return HyperNEATGene.forward(self.config.hyper_neat, state, inputs, transformed)
return HyperNEATGene.forward(self.config.hyperneat, state, inputs, transformed)
def forward_transform(self, state: State, genome: Genome):
t = self.neat.forward_transform(state, genome)
@@ -68,6 +68,7 @@ class HyperNEAT(Algorithm):
query_res = query_res / (1 - state.below_threshold) * state.max_weight
h_conns = state.h_conns.at[:, 2:].set(query_res)
return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns))

View File

@@ -9,9 +9,9 @@ from config import SubstrateConfig
@dataclass(frozen=True)
class NormalSubstrateConfig(SubstrateConfig):
input_coors: Tuple[Tuple[float]] = ((-1, -1), (0, -1), (1, -1))
hidden_coors: Tuple[Tuple[float]] = ((-1, 0), (0, 0), (1, 0))
output_coors: Tuple[Tuple[float]] = ((0, 1),)
input_coors: Tuple = ((-1, -1), (0, -1), (1, -1))
hidden_coors: Tuple = ((-1, 0), (0, 0), (1, 0))
output_coors: Tuple = ((0, 1),)
class NormalSubstrate(Substrate):

View File

@@ -1 +1,2 @@
from .neat import NEAT
from .gene import *

View File

@@ -66,7 +66,7 @@ class NormalGene(Gene):
node_attrs = ['bias', 'response', 'aggregation', 'activation']
conn_attrs = ['weight']
def __init__(self, config: NormalGeneConfig):
def __init__(self, config: NormalGeneConfig = NormalGeneConfig()):
self.config = config
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
@@ -101,7 +101,7 @@ class NormalGene(Gene):
)
def update(self, state):
pass
return state
def new_node_attrs(self, state):
return jnp.array([state.bias_init_mean, state.response_init_mean,

View File

@@ -19,7 +19,7 @@ class RecurrentGeneConfig(NormalGeneConfig):
class RecurrentGene(NormalGene):
def __init__(self, config: RecurrentGeneConfig):
def __init__(self, config: RecurrentGeneConfig = RecurrentGeneConfig()):
self.config = config
super().__init__(config)

View File

@@ -28,9 +28,9 @@ class NEAT(Algorithm):
state = state.update(
P=self.config.basic.pop_size,
N=self.config.neat.maximum_nodes,
C=self.config.neat.maximum_conns,
S=self.config.neat.maximum_species,
N=self.config.neat.max_nodes,
C=self.config.neat.max_conns,
S=self.config.neat.max_species,
NL=1 + len(self.gene.node_attrs), # node length = (key) + attributes
CL=3 + len(self.gene.conn_attrs), # conn length = (in, out, key) + attributes
max_stagnation=self.config.neat.max_stagnation,
@@ -80,6 +80,8 @@ class NEAT(Algorithm):
return state.pop_genomes
def tell_algorithm(self, state: State, fitness):
state = self.gene.update(state)
k1, k2, randkey = jax.random.split(state.randkey, 3)
state = state.update(