Files
tensorneat-mend/evox_adaptor/neat.py
wls2002 7bf46575f4 Using Evox to deal with RL tasks! With distributed Gym environment!
Three simple tasks in Gym[classical] are tested.
2023-07-04 15:44:08 +08:00

92 lines
2.6 KiB
Python

import jax.numpy as jnp
import evox
from algorithms import neat
from configs import Configer
@evox.jit_class
class NEAT(evox.Algorithm):
def __init__(self, config):
self.config = config # global config
self.jit_config = Configer.create_jit_config(config)
(
self.randkey,
self.pop_nodes,
self.pop_cons,
self.species_info,
self.idx2species,
self.center_nodes,
self.center_cons,
self.generation,
self.next_node_key,
self.next_species_key,
) = neat.initialize(config)
super().__init__()
def setup(self, key):
return evox.State(
randkey=self.randkey,
pop_nodes=self.pop_nodes,
pop_cons=self.pop_cons,
species_info=self.species_info,
idx2species=self.idx2species,
center_nodes=self.center_nodes,
center_cons=self.center_cons,
generation=self.generation,
next_node_key=self.next_node_key,
next_species_key=self.next_species_key,
jit_config=self.jit_config
)
def ask(self, state):
flatten_pop_nodes = state.pop_nodes.flatten()
flatten_pop_cons = state.pop_cons.flatten()
pop = jnp.concatenate([flatten_pop_nodes, flatten_pop_cons])
return pop, state
def tell(self, state, fitness):
# evox is a minimization framework, so we need to negate the fitness
fitness = -fitness
(
randkey,
pop_nodes,
pop_cons,
species_info,
idx2species,
center_nodes,
center_cons,
generation,
next_node_key,
next_species_key
) = neat.tell(
fitness,
state.randkey,
state.pop_nodes,
state.pop_cons,
state.species_info,
state.idx2species,
state.center_nodes,
state.center_cons,
state.generation,
state.next_node_key,
state.next_species_key,
state.jit_config
)
return evox.State(
randkey=randkey,
pop_nodes=pop_nodes,
pop_cons=pop_cons,
species_info=species_info,
idx2species=idx2species,
center_nodes=center_nodes,
center_cons=center_cons,
generation=generation,
next_node_key=next_node_key,
next_species_key=next_species_key,
jit_config=state.jit_config
)