Using Evox to deal with RL tasks! With distributed Gym environment!

Three simple tasks in Gym[classical] are tested.
This commit is contained in:
wls2002
2023-07-04 15:44:08 +08:00
parent c4d34e877b
commit 7bf46575f4
18 changed files with 547 additions and 43 deletions

View File

@@ -27,28 +27,23 @@ class Pipeline:
self.evaluate_time = 0
self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \
self.center_cons, self.generation, self.next_node_key, self.next_species_key = neat.initialize(config)
(
self.randkey,
self.pop_nodes,
self.pop_cons,
self.species_info,
self.idx2species,
self.center_nodes,
self.center_cons,
self.generation,
self.next_node_key,
self.next_species_key,
) = neat.initialize(config)
self.forward = neat.create_forward_function(config)
self.pop_unflatten_connections = jit(vmap(neat.unflatten_connections))
self.pop_topological_sort = jit(vmap(neat.topological_sort))
# self.tell_func = neat.tell.lower(np.zeros(config['pop_size'], dtype=np.float32),
# self.randkey,
# self.pop_nodes,
# self.pop_cons,
# self.species_info,
# self.idx2species,
# self.center_nodes,
# self.center_cons,
# self.generation,
# self.next_node_key,
# self.next_species_key,
# self.jit_config).compile()
def ask(self):
"""
Creates a function that receives a genome and returns a forward function.
@@ -77,21 +72,31 @@ class Pipeline:
return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons)
def tell(self, fitness):
self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \
self.center_cons, self.generation, self.next_node_key, self.next_species_key = neat.tell(fitness,
self.randkey,
self.pop_nodes,
self.pop_cons,
self.species_info,
self.idx2species,
self.center_nodes,
self.center_cons,
self.generation,
self.next_node_key,
self.next_species_key,
self.jit_config)
(
self.randkey,
self.pop_nodes,
self.pop_cons,
self.species_info,
self.idx2species,
self.center_nodes,
self.center_cons,
self.generation,
self.next_node_key,
self.next_species_key,
) = neat.tell(
fitness,
self.randkey,
self.pop_nodes,
self.pop_cons,
self.species_info,
self.idx2species,
self.center_nodes,
self.center_cons,
self.generation,
self.next_node_key,
self.next_species_key,
self.jit_config
)
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config['generation_limit']):