remove create_func....

This commit is contained in:
wls2002
2023-08-04 17:29:36 +08:00
parent c7fb1ddabe
commit 0e44b13291
29 changed files with 591 additions and 259 deletions

View File

@@ -5,30 +5,30 @@ from jax import numpy as jnp, Array, vmap
import numpy as np
from config import Config, HyperNeatConfig
from core import Algorithm, Substrate, State, Genome
from core import Algorithm, Substrate, State, Genome, Gene
from utils import Activation, Aggregation
from algorithm.neat import NEAT
from .substrate import analysis_substrate
from algorithm import NEAT
class HyperNEAT(Algorithm):
def __init__(self, config: Config, neat: NEAT, substrate: Type[Substrate]):
def __init__(self, config: Config, gene: Type[Gene], substrate: Type[Substrate]):
self.config = config
self.neat = neat
self.neat = NEAT(config, gene)
self.substrate = substrate
def setup(self, randkey, state=State()):
neat_key, randkey = jax.random.split(randkey)
state = state.update(
below_threshold=self.config.hyper_neat.below_threshold,
max_weight=self.config.hyper_neat.max_weight,
below_threshold=self.config.hyperneat.below_threshold,
max_weight=self.config.hyperneat.max_weight,
)
state = self.neat.setup(neat_key, state)
state = self.substrate.setup(self.config.substrate, state)
assert self.config.hyper_neat.inputs + 1 == state.input_coors.shape[0] # +1 for bias
assert self.config.hyper_neat.outputs == state.output_coors.shape[0]
assert self.config.hyperneat.inputs + 1 == state.input_coors.shape[0] # +1 for bias
assert self.config.hyperneat.outputs == state.output_coors.shape[0]
h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state)
h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis]
@@ -53,7 +53,7 @@ class HyperNEAT(Algorithm):
return self.neat.tell(state, fitness)
def forward(self, state, inputs: Array, transformed: Array):
return HyperNEATGene.forward(self.config.hyper_neat, state, inputs, transformed)
return HyperNEATGene.forward(self.config.hyperneat, state, inputs, transformed)
def forward_transform(self, state: State, genome: Genome):
t = self.neat.forward_transform(state, genome)
@@ -68,6 +68,7 @@ class HyperNEAT(Algorithm):
query_res = query_res / (1 - state.below_threshold) * state.max_weight
h_conns = state.h_conns.at[:, 2:].set(query_res)
return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns))