import jax

from algorithm.state import State
from .gene import *
from .genome import initialize_genomes, create_mutate, create_distance, crossover


class NEAT:
    """Entry point object for the NEAT algorithm.

    Holds the user configuration, the selected gene class, and jit-compiled
    versions of the genome operators (mutate, distance, crossover).
    """

    def __init__(self, config):
        """Store ``config``, pick the gene class, and jit the genome operators.

        Args:
            config: mapping read with string keys; ``'gene_type'`` is
                consulted here, the remaining keys are read in ``setup``.

        Raises:
            NotImplementedError: if ``config['gene_type']`` is anything other
                than ``'normal'``.
        """
        self.config = config

        # Only the 'normal' gene type is wired up; reject everything else early.
        if self.config['gene_type'] != 'normal':
            raise NotImplementedError
        self.gene_type = NormalGene

        # Operators are built from the config/gene type and then compiled.
        self.mutate = jax.jit(create_mutate(config, self.gene_type))
        self.distance = jax.jit(create_distance(config, self.gene_type))
        self.crossover = jax.jit(crossover)

    def setup(self, randkey):
        """Build the initial algorithm state.

        Args:
            randkey: random key stored in the state under ``randkey``.

        Returns:
            A ``State`` carrying the size constants, the input/output indices,
            whatever ``gene_type.setup`` adds, the initial population
            (``pop_nodes``, ``pop_conns``) and ``next_node_key``.
        """
        cfg = self.config

        state = State(
            randkey=randkey,
            P=cfg['pop_size'],
            N=cfg['maximum_nodes'],
            C=cfg['maximum_connections'],
            S=cfg['maximum_species'],
            # node length = (key) + attributes
            NL=1 + len(self.gene_type.node_attrs),
            # conn length = (in, out, key) + attributes
            CL=3 + len(self.gene_type.conn_attrs),
            input_idx=cfg['input_idx'],
            output_idx=cfg['output_idx'],
        )

        # Let the gene class extend the state with its own entries.
        state = self.gene_type.setup(state, cfg)

        nodes, conns = initialize_genomes(state, self.gene_type)

        # NOTE(review): the offset of 2 presumably skips a node key that
        # initialize_genomes already uses (max + 1) -- confirm against genome.py.
        largest_io_key = max(*state.input_idx, *state.output_idx)
        fresh_node_key = largest_io_key + 2

        return state.update(
            pop_nodes=nodes,
            pop_conns=conns,
            next_node_key=fresh_node_key,
        )

    def tell(self, state, fitness):
        """Placeholder: ignores its arguments and returns an empty ``State``."""
        return State()

    def ask(self, state):
        """Placeholder: ignores ``state`` and returns an empty ``State``."""
        return State()