hyper neat

This commit is contained in:
wls2002
2023-07-24 19:25:02 +08:00
parent ac295c1921
commit ebad574431
24 changed files with 542 additions and 103 deletions

View File

@@ -7,7 +7,7 @@ import numpy as np
from config import Config
from core import Algorithm, State, Gene, Genome
from .ga import crossover, create_mutate
from .species import update_species, create_speciate
from .species import SpeciesInfo, update_species, create_speciate
class NEAT(Algorithm):
@@ -22,9 +22,9 @@ class NEAT(Algorithm):
def setup(self, randkey, state: State = State()):
"""initialize the state of the algorithm"""
input_idx = np.arange(self.config.basic.num_inputs)
output_idx = np.arange(self.config.basic.num_inputs,
self.config.basic.num_inputs + self.config.basic.num_outputs)
input_idx = np.arange(self.config.neat.inputs)
output_idx = np.arange(self.config.neat.inputs,
self.config.neat.inputs + self.config.neat.outputs)
state = state.update(
P=self.config.basic.pop_size,
@@ -49,22 +49,13 @@ class NEAT(Algorithm):
state = self.gene_type.setup(self.config.gene, state)
pop_genomes = self._initialize_genomes(state)
species_keys = np.full((state.S,), np.nan, dtype=np.float32)
best_fitness = np.full((state.S,), np.nan, dtype=np.float32)
last_improved = np.full((state.S,), np.nan, dtype=np.float32)
member_count = np.full((state.S,), np.nan, dtype=np.float32)
species_info = SpeciesInfo.initialize(state)
idx2species = jnp.zeros(state.P, dtype=jnp.float32)
species_keys[0] = 0
best_fitness[0] = -np.inf
last_improved[0] = 0
member_count[0] = state.P
center_nodes = jnp.full((state.S, state.N, state.NL), jnp.nan, dtype=jnp.float32)
center_conns = jnp.full((state.S, state.C, state.CL), jnp.nan, dtype=jnp.float32)
center_nodes = center_nodes.at[0, :, :].set(pop_genomes.nodes[0, :, :])
center_conns = center_conns.at[0, :, :].set(pop_genomes.conns[0, :, :])
center_genomes = vmap(Genome)(center_nodes, center_conns)
center_genomes = Genome(center_nodes, center_conns)
center_genomes = center_genomes.set(0, pop_genomes[0])
generation = 0
next_node_key = max(*state.input_idx, *state.output_idx) + 2
@@ -73,10 +64,7 @@ class NEAT(Algorithm):
state = state.update(
randkey=randkey,
pop_genomes=pop_genomes,
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
species_info=species_info,
idx2species=idx2species,
center_genomes=center_genomes,
@@ -135,7 +123,7 @@ class NEAT(Algorithm):
pop_nodes = np.tile(o_nodes, (state.P, 1, 1))
pop_conns = np.tile(o_conns, (state.P, 1, 1))
return vmap(Genome)(pop_nodes, pop_conns)
return Genome(pop_nodes, pop_conns)
def _create_tell(self):
mutate = create_mutate(self.config.neat, self.gene_type)