change a lot
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import jax
|
||||
|
||||
from algorithm.state import State
|
||||
from .gene import *
|
||||
from .genome import initialize_genomes
|
||||
from .genome import initialize_genomes, create_mutate, create_distance, crossover
|
||||
|
||||
|
||||
class NEAT:
|
||||
@@ -11,6 +13,10 @@ class NEAT:
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
self.mutate = jax.jit(create_mutate(config, self.gene_type))
|
||||
self.distance = jax.jit(create_distance(config, self.gene_type))
|
||||
self.crossover = jax.jit(crossover)
|
||||
|
||||
def setup(self, randkey):
|
||||
|
||||
state = State(
|
||||
@@ -25,6 +31,8 @@ class NEAT:
|
||||
output_idx=self.config['output_idx']
|
||||
)
|
||||
|
||||
state = self.gene_type.setup(state, self.config)
|
||||
|
||||
pop_nodes, pop_conns = initialize_genomes(state, self.gene_type)
|
||||
next_node_key = max(*state.input_idx, *state.output_idx) + 2
|
||||
state = state.update(
|
||||
|
||||
Reference in New Issue
Block a user