hyper neat
This commit is contained in:
@@ -7,7 +7,7 @@ import numpy as np
|
||||
from config import Config
|
||||
from core import Algorithm, State, Gene, Genome
|
||||
from .ga import crossover, create_mutate
|
||||
from .species import update_species, create_speciate
|
||||
from .species import SpeciesInfo, update_species, create_speciate
|
||||
|
||||
|
||||
class NEAT(Algorithm):
|
||||
@@ -22,9 +22,9 @@ class NEAT(Algorithm):
|
||||
def setup(self, randkey, state: State = State()):
|
||||
"""initialize the state of the algorithm"""
|
||||
|
||||
input_idx = np.arange(self.config.basic.num_inputs)
|
||||
output_idx = np.arange(self.config.basic.num_inputs,
|
||||
self.config.basic.num_inputs + self.config.basic.num_outputs)
|
||||
input_idx = np.arange(self.config.neat.inputs)
|
||||
output_idx = np.arange(self.config.neat.inputs,
|
||||
self.config.neat.inputs + self.config.neat.outputs)
|
||||
|
||||
state = state.update(
|
||||
P=self.config.basic.pop_size,
|
||||
@@ -49,22 +49,13 @@ class NEAT(Algorithm):
|
||||
state = self.gene_type.setup(self.config.gene, state)
|
||||
pop_genomes = self._initialize_genomes(state)
|
||||
|
||||
species_keys = np.full((state.S,), np.nan, dtype=np.float32)
|
||||
best_fitness = np.full((state.S,), np.nan, dtype=np.float32)
|
||||
last_improved = np.full((state.S,), np.nan, dtype=np.float32)
|
||||
member_count = np.full((state.S,), np.nan, dtype=np.float32)
|
||||
species_info = SpeciesInfo.initialize(state)
|
||||
idx2species = jnp.zeros(state.P, dtype=jnp.float32)
|
||||
|
||||
species_keys[0] = 0
|
||||
best_fitness[0] = -np.inf
|
||||
last_improved[0] = 0
|
||||
member_count[0] = state.P
|
||||
|
||||
center_nodes = jnp.full((state.S, state.N, state.NL), jnp.nan, dtype=jnp.float32)
|
||||
center_conns = jnp.full((state.S, state.C, state.CL), jnp.nan, dtype=jnp.float32)
|
||||
center_nodes = center_nodes.at[0, :, :].set(pop_genomes.nodes[0, :, :])
|
||||
center_conns = center_conns.at[0, :, :].set(pop_genomes.conns[0, :, :])
|
||||
center_genomes = vmap(Genome)(center_nodes, center_conns)
|
||||
center_genomes = Genome(center_nodes, center_conns)
|
||||
center_genomes = center_genomes.set(0, pop_genomes[0])
|
||||
|
||||
generation = 0
|
||||
next_node_key = max(*state.input_idx, *state.output_idx) + 2
|
||||
@@ -73,10 +64,7 @@ class NEAT(Algorithm):
|
||||
state = state.update(
|
||||
randkey=randkey,
|
||||
pop_genomes=pop_genomes,
|
||||
species_keys=species_keys,
|
||||
best_fitness=best_fitness,
|
||||
last_improved=last_improved,
|
||||
member_count=member_count,
|
||||
species_info=species_info,
|
||||
idx2species=idx2species,
|
||||
center_genomes=center_genomes,
|
||||
|
||||
@@ -135,7 +123,7 @@ class NEAT(Algorithm):
|
||||
pop_nodes = np.tile(o_nodes, (state.P, 1, 1))
|
||||
pop_conns = np.tile(o_conns, (state.P, 1, 1))
|
||||
|
||||
return vmap(Genome)(pop_nodes, pop_conns)
|
||||
return Genome(pop_nodes, pop_conns)
|
||||
|
||||
def _create_tell(self):
|
||||
mutate = create_mutate(self.config.neat, self.gene_type)
|
||||
|
||||
Reference in New Issue
Block a user