92 lines
2.6 KiB
Python
92 lines
2.6 KiB
Python
import jax.numpy as jnp
|
|
|
|
import evox
|
|
from algorithms import neat
|
|
from configs import Configer
|
|
|
|
|
|
@evox.jit_class
|
|
class NEAT(evox.Algorithm):
|
|
def __init__(self, config):
|
|
self.config = config # global config
|
|
self.jit_config = Configer.create_jit_config(config)
|
|
(
|
|
self.randkey,
|
|
self.pop_nodes,
|
|
self.pop_cons,
|
|
self.species_info,
|
|
self.idx2species,
|
|
self.center_nodes,
|
|
self.center_cons,
|
|
self.generation,
|
|
self.next_node_key,
|
|
self.next_species_key,
|
|
) = neat.initialize(config)
|
|
super().__init__()
|
|
|
|
def setup(self, key):
|
|
return evox.State(
|
|
randkey=self.randkey,
|
|
pop_nodes=self.pop_nodes,
|
|
pop_cons=self.pop_cons,
|
|
species_info=self.species_info,
|
|
idx2species=self.idx2species,
|
|
center_nodes=self.center_nodes,
|
|
center_cons=self.center_cons,
|
|
generation=self.generation,
|
|
next_node_key=self.next_node_key,
|
|
next_species_key=self.next_species_key,
|
|
jit_config=self.jit_config
|
|
)
|
|
|
|
def ask(self, state):
|
|
flatten_pop_nodes = state.pop_nodes.flatten()
|
|
flatten_pop_cons = state.pop_cons.flatten()
|
|
pop = jnp.concatenate([flatten_pop_nodes, flatten_pop_cons])
|
|
return pop, state
|
|
|
|
def tell(self, state, fitness):
|
|
|
|
# evox is a minimization framework, so we need to negate the fitness
|
|
fitness = -fitness
|
|
|
|
(
|
|
randkey,
|
|
pop_nodes,
|
|
pop_cons,
|
|
species_info,
|
|
idx2species,
|
|
center_nodes,
|
|
center_cons,
|
|
generation,
|
|
next_node_key,
|
|
next_species_key
|
|
) = neat.tell(
|
|
fitness,
|
|
state.randkey,
|
|
state.pop_nodes,
|
|
state.pop_cons,
|
|
state.species_info,
|
|
state.idx2species,
|
|
state.center_nodes,
|
|
state.center_cons,
|
|
state.generation,
|
|
state.next_node_key,
|
|
state.next_species_key,
|
|
state.jit_config
|
|
)
|
|
|
|
return evox.State(
|
|
randkey=randkey,
|
|
pop_nodes=pop_nodes,
|
|
pop_cons=pop_cons,
|
|
species_info=species_info,
|
|
idx2species=idx2species,
|
|
center_nodes=center_nodes,
|
|
center_cons=center_cons,
|
|
generation=generation,
|
|
next_node_key=next_node_key,
|
|
next_species_key=next_species_key,
|
|
jit_config=state.jit_config
|
|
)
|