change a lot

This commit is contained in:
wls2002
2023-07-17 19:59:46 +08:00
parent f4763ebcea
commit 40cf0b6fbe
8 changed files with 248 additions and 18 deletions

View File

@@ -1,6 +1,8 @@
import jax
from algorithm.state import State
from .gene import *
from .genome import initialize_genomes
from .genome import initialize_genomes, create_mutate, create_distance, crossover
class NEAT:
@@ -11,6 +13,10 @@ class NEAT:
else:
raise NotImplementedError
self.mutate = jax.jit(create_mutate(config, self.gene_type))
self.distance = jax.jit(create_distance(config, self.gene_type))
self.crossover = jax.jit(crossover)
def setup(self, randkey):
state = State(
@@ -25,6 +31,8 @@ class NEAT:
output_idx=self.config['output_idx']
)
state = self.gene_type.setup(state, self.config)
pop_nodes, pop_conns = initialize_genomes(state, self.gene_type)
next_node_key = max(*state.input_idx, *state.output_idx) + 2
state = state.update(