From 7bf46575f454fd2a802bdc549f85d77a0f395213 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Tue, 4 Jul 2023 15:44:08 +0800 Subject: [PATCH] Using Evox to deal with RL tasks! With distributed Gym environment! Three simple tasks in Gym[classical] are tested. --- algorithms/neat/genome/forward.py | 9 ++- algorithms/neat/genome/graph.py | 2 +- configs/configer.py | 5 -- configs/default_config.ini | 2 +- evox_adaptor/__init__.py | 2 + evox_adaptor/gym_no_distribution.py | 83 ++++++++++++++++++++++++++ evox_adaptor/neat.py | 91 +++++++++++++++++++++++++++++ examples/evox_/__init__.py | 0 examples/evox_/acrobot.ini | 22 +++++++ examples/evox_/acrobot.py | 62 ++++++++++++++++++++ examples/evox_/bipedalwalker.ini | 22 +++++++ examples/evox_/bipedalwalker.py | 62 ++++++++++++++++++++ examples/evox_/cartpole.ini | 11 ++++ examples/evox_/cartpole.py | 62 ++++++++++++++++++++ examples/evox_/mountain_car.ini | 22 +++++++ examples/evox_/mountain_car.py | 62 ++++++++++++++++++++ examples/xor3d.ini | 2 +- pipeline.py | 69 ++++++++++++---------- 18 files changed, 547 insertions(+), 43 deletions(-) create mode 100644 evox_adaptor/__init__.py create mode 100644 evox_adaptor/gym_no_distribution.py create mode 100644 evox_adaptor/neat.py create mode 100644 examples/evox_/__init__.py create mode 100644 examples/evox_/acrobot.ini create mode 100644 examples/evox_/acrobot.py create mode 100644 examples/evox_/bipedalwalker.ini create mode 100644 examples/evox_/bipedalwalker.py create mode 100644 examples/evox_/cartpole.ini create mode 100644 examples/evox_/cartpole.py create mode 100644 examples/evox_/mountain_car.ini create mode 100644 examples/evox_/mountain_car.py diff --git a/algorithms/neat/genome/forward.py b/algorithms/neat/genome/forward.py index a0f26b7..bc37bcb 100644 --- a/algorithms/neat/genome/forward.py +++ b/algorithms/neat/genome/forward.py @@ -2,12 +2,16 @@ import jax from jax import Array, numpy as jnp, jit, vmap from .utils import I_INT +from .activations import 
act_name2func +from .aggregations import agg_name2func def create_forward_function(config): """ meta method to create forward function """ + config['activation_funcs'] = [act_name2func[name] for name in config['activation_option_names']] + config['aggregation_funcs'] = [agg_name2func[name] for name in config['aggregation_option_names']] def act(idx, z): """ @@ -92,12 +96,11 @@ def create_forward_function(config): common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0)) if config['forward_way'] == 'single': - return jit(batch_forward) + return jit(forward) + # return jit(batch_forward) elif config['forward_way'] == 'pop': return jit(pop_batch_forward) elif config['forward_way'] == 'common': return jit(common_forward) - - return jit(forward) diff --git a/algorithms/neat/genome/graph.py b/algorithms/neat/genome/graph.py index b37a12b..0dda1ee 100644 --- a/algorithms/neat/genome/graph.py +++ b/algorithms/neat/genome/graph.py @@ -1,5 +1,5 @@ """ -Some graph algorithms implemented in jax. +Some graph algorithms implemented in JAX. Only used in feed-forward networks. """ diff --git a/configs/configer.py b/configs/configer.py index 79f60b3..4b8946b 100644 --- a/configs/configer.py +++ b/configs/configer.py @@ -4,9 +4,6 @@ import configparser import numpy as np -from algorithms.neat.genome.activations import act_name2func -from algorithms.neat.genome.aggregations import agg_name2func - # Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX. 
jit_config_keys = [ "input_idx", @@ -108,13 +105,11 @@ class Configer: def refactor_activation(cls, config): config['activation_default'] = 0 config['activation_options'] = np.arange(len(config['activation_option_names'])) - config['activation_funcs'] = [act_name2func[name] for name in config['activation_option_names']] @classmethod def refactor_aggregation(cls, config): config['aggregation_default'] = 0 config['aggregation_options'] = np.arange(len(config['aggregation_option_names'])) - config['aggregation_funcs'] = [agg_name2func[name] for name in config['aggregation_option_names']] @classmethod def create_jit_config(cls, config): diff --git a/configs/default_config.ini b/configs/default_config.ini index 5378872..e8c3be4 100644 --- a/configs/default_config.ini +++ b/configs/default_config.ini @@ -12,7 +12,7 @@ random_seed = 0 fitness_threshold = 3.99999 generation_limit = 1000 fitness_criterion = "max" -pop_size = 100000 +pop_size = 10000 [genome] compatibility_disjoint = 1.0 diff --git a/evox_adaptor/__init__.py b/evox_adaptor/__init__.py new file mode 100644 index 0000000..43d3342 --- /dev/null +++ b/evox_adaptor/__init__.py @@ -0,0 +1,2 @@ +from .neat import NEAT +from .gym_no_distribution import Gym diff --git a/evox_adaptor/gym_no_distribution.py b/evox_adaptor/gym_no_distribution.py new file mode 100644 index 0000000..ad7365f --- /dev/null +++ b/evox_adaptor/gym_no_distribution.py @@ -0,0 +1,83 @@ +from typing import Callable + +import gym +import jax +import jax.numpy as jnp +import numpy as np + +from evox import Problem, State + + +class Gym(Problem): + def __init__( + self, + pop_size: int, + policy: Callable, + env_name: str = "CartPole-v1", + env_options: dict = None, + batch_policy: bool = True, + ): + self.pop_size = pop_size + self.env_name = env_name + self.policy = policy + self.env_options = env_options or {} + self.batch_policy = batch_policy + assert batch_policy, "Only batch policy is supported for now" + + self.envs = [gym.make(env_name, 
**self.env_options) for _ in range(self.pop_size)] + + super().__init__() + + def setup(self, key): + return State(key=key) + + def evaluate(self, state, pop): + key = state.key + # key, subkey = jax.random.split(state.key) + + # generate a list of seeds for gym + # seeds = jax.random.randint( + # subkey, (self.pop_size,), 0, jnp.iinfo(jnp.int32).max + # ) + + # currently use fixed seed for debugging + seeds = jax.random.randint( + key, (self.pop_size,), 0, jnp.iinfo(jnp.int32).max + ) + + seeds = seeds.tolist() # seed must be a python int, not numpy array + + fitnesses = self.__rollout(seeds, pop) + print("fitnesses info: ") + print(f"max: {np.max(fitnesses)}, min: {np.min(fitnesses)}, mean: {np.mean(fitnesses)}, std: {np.std(fitnesses)}") + + # evox uses negative fitness for minimization + return -fitnesses, State(key=key) + + def __rollout(self, seeds, pop): + observations, infos = zip( + *[env.reset(seed=seed) for env, seed in zip(self.envs, seeds)] + ) + terminates, truncates = np.zeros((2, self.pop_size), dtype=bool) + fitnesses, rewards = np.zeros((2, self.pop_size)) + + while not np.all(terminates | truncates): + observations = np.asarray(observations) + actions = self.policy(pop, observations) + actions = jax.device_get(actions) + + for i, (action, terminate, truncate, env) in enumerate(zip(actions, terminates, truncates, self.envs)): + if terminate | truncate: + observation = np.zeros(env.observation_space.shape) + reward = 0 + else: + observation, reward, terminate, truncate, info = env.step(action) + + observations[i] = observation + rewards[i] = reward + terminates[i] = terminate + truncates[i] = truncate + + fitnesses += rewards + + return fitnesses diff --git a/evox_adaptor/neat.py b/evox_adaptor/neat.py new file mode 100644 index 0000000..55a6f12 --- /dev/null +++ b/evox_adaptor/neat.py @@ -0,0 +1,91 @@ +import jax.numpy as jnp + +import evox +from algorithms import neat +from configs import Configer + + +@evox.jit_class +class NEAT(evox.Algorithm): 
+ def __init__(self, config): + self.config = config # global config + self.jit_config = Configer.create_jit_config(config) + ( + self.randkey, + self.pop_nodes, + self.pop_cons, + self.species_info, + self.idx2species, + self.center_nodes, + self.center_cons, + self.generation, + self.next_node_key, + self.next_species_key, + ) = neat.initialize(config) + super().__init__() + + def setup(self, key): + return evox.State( + randkey=self.randkey, + pop_nodes=self.pop_nodes, + pop_cons=self.pop_cons, + species_info=self.species_info, + idx2species=self.idx2species, + center_nodes=self.center_nodes, + center_cons=self.center_cons, + generation=self.generation, + next_node_key=self.next_node_key, + next_species_key=self.next_species_key, + jit_config=self.jit_config + ) + + def ask(self, state): + flatten_pop_nodes = state.pop_nodes.flatten() + flatten_pop_cons = state.pop_cons.flatten() + pop = jnp.concatenate([flatten_pop_nodes, flatten_pop_cons]) + return pop, state + + def tell(self, state, fitness): + + # evox is a minimization framework, so we need to negate the fitness + fitness = -fitness + + ( + randkey, + pop_nodes, + pop_cons, + species_info, + idx2species, + center_nodes, + center_cons, + generation, + next_node_key, + next_species_key + ) = neat.tell( + fitness, + state.randkey, + state.pop_nodes, + state.pop_cons, + state.species_info, + state.idx2species, + state.center_nodes, + state.center_cons, + state.generation, + state.next_node_key, + state.next_species_key, + state.jit_config + ) + + return evox.State( + randkey=randkey, + pop_nodes=pop_nodes, + pop_cons=pop_cons, + species_info=species_info, + idx2species=idx2species, + center_nodes=center_nodes, + center_cons=center_cons, + generation=generation, + next_node_key=next_node_key, + next_species_key=next_species_key, + jit_config=state.jit_config + ) diff --git a/examples/evox_/__init__.py b/examples/evox_/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
a/examples/evox_/acrobot.ini b/examples/evox_/acrobot.ini new file mode 100644 index 0000000..f80e61d --- /dev/null +++ b/examples/evox_/acrobot.ini @@ -0,0 +1,22 @@ +[basic] +num_inputs = 6 +num_outputs = 3 +maximum_nodes = 50 +maximum_connections = 50 +maximum_species = 10 +forward_way = "single" +random_seed = 42 + +[population] +pop_size = 100 + +[gene-activation] +activation_default = "sigmoid" +activation_option_names = ['sigmoid', 'tanh', 'sin', 'gauss', 'relu', 'identity', 'inv', 'log', 'exp', 'abs', 'hat', 'square'] +activation_replace_rate = 0.1 + +[gene-aggregation] +aggregation_default = "sum" +aggregation_option_names = ['sum', 'product', 'max', 'min', 'maxabs', 'median', 'mean'] +aggregation_replace_rate = 0.1 + diff --git a/examples/evox_/acrobot.py b/examples/evox_/acrobot.py new file mode 100644 index 0000000..f96dd22 --- /dev/null +++ b/examples/evox_/acrobot.py @@ -0,0 +1,62 @@ +import evox +import jax +from jax import jit, vmap, numpy as jnp + +from configs import Configer +from algorithms.neat import create_forward_function, topological_sort, unflatten_connections +from evox_adaptor import NEAT, Gym + +if __name__ == '__main__': + batch_policy = True + key = jax.random.PRNGKey(42) + + monitor = evox.monitors.StdSOMonitor() + neat_config = Configer.load_config('acrobot.ini') + origin_forward_func = create_forward_function(neat_config) + + + def neat_transform(pop): + P = neat_config['pop_size'] + N = neat_config['maximum_nodes'] + C = neat_config['maximum_connections'] + + pop_nodes = pop[:P * N * 5].reshape((P, N, 5)) + pop_cons = pop[P * N * 5:].reshape((P, C, 4)) + + u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons) + pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons) + return pop_seqs, pop_nodes, u_pop_cons + + # special policy for acrobot + def neat_forward(genome, x): + res = origin_forward_func(x, *genome) + out = jnp.argmax(res) # {0, 1, 2} + return out + + + forward_func = lambda pop, x: origin_forward_func(x, 
*pop) + + problem = Gym( + policy=jit(vmap(neat_forward)), + env_name="Acrobot-v1", + pop_size=100, + ) + + # create a pipeline + pipeline = evox.pipelines.StdPipeline( + algorithm=NEAT(neat_config), + problem=problem, + pop_transform=jit(neat_transform), + fitness_transform=monitor.record_fit, + ) + # init the pipeline + state = pipeline.init(key) + + # run the pipeline for 30 steps + for i in range(30): + state = pipeline.step(state) + print(i, monitor.get_min_fitness()) + + # obtain -62.0 + min_fitness = monitor.get_min_fitness() + print(min_fitness) diff --git a/examples/evox_/bipedalwalker.ini b/examples/evox_/bipedalwalker.ini new file mode 100644 index 0000000..9f271b5 --- /dev/null +++ b/examples/evox_/bipedalwalker.ini @@ -0,0 +1,22 @@ +[basic] +num_inputs = 24 +num_outputs = 4 +maximum_nodes = 100 +maximum_connections = 200 +maximum_species = 10 +forward_way = "single" +random_seed = 42 + +[population] +pop_size = 100 + +[gene-activation] +activation_default = "sigmoid" +activation_option_names = ['sigmoid', 'tanh', 'sin', 'gauss', 'relu', 'identity', 'inv', 'log', 'exp', 'abs', 'hat', 'square'] +activation_replace_rate = 0.1 + +[gene-aggregation] +aggregation_default = "sum" +aggregation_option_names = ['sum', 'product', 'max', 'min', 'maxabs', 'median', 'mean'] +aggregation_replace_rate = 0.1 + diff --git a/examples/evox_/bipedalwalker.py b/examples/evox_/bipedalwalker.py new file mode 100644 index 0000000..4abf1f3 --- /dev/null +++ b/examples/evox_/bipedalwalker.py @@ -0,0 +1,62 @@ +import evox +import jax +from jax import jit, vmap, numpy as jnp + +from configs import Configer +from algorithms.neat import create_forward_function, topological_sort, unflatten_connections +from evox_adaptor import NEAT, Gym + +if __name__ == '__main__': + batch_policy = True + key = jax.random.PRNGKey(42) + + monitor = evox.monitors.StdSOMonitor() + neat_config = Configer.load_config('bipedalwalker.ini') + origin_forward_func = create_forward_function(neat_config) + + + 
def neat_transform(pop): + P = neat_config['pop_size'] + N = neat_config['maximum_nodes'] + C = neat_config['maximum_connections'] + + pop_nodes = pop[:P * N * 5].reshape((P, N, 5)) + pop_cons = pop[P * N * 5:].reshape((P, C, 4)) + + u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons) + pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons) + return pop_seqs, pop_nodes, u_pop_cons + + # special policy for bipedal walker + def neat_forward(genome, x): + res = origin_forward_func(x, *genome) + out = jnp.tanh(res) # (-1, 1) + return out + + + forward_func = lambda pop, x: origin_forward_func(x, *pop) + + problem = Gym( + policy=jit(vmap(neat_forward)), + env_name="BipedalWalker-v3", + pop_size=100, + ) + + # create a pipeline + pipeline = evox.pipelines.StdPipeline( + algorithm=NEAT(neat_config), + problem=problem, + pop_transform=jit(neat_transform), + fitness_transform=monitor.record_fit, + ) + # init the pipeline + state = pipeline.init(key) + + # run the pipeline for 30 steps + for i in range(30): + state = pipeline.step(state) + print(i, monitor.get_min_fitness()) + + # obtain 98.91529684268514 + min_fitness = monitor.get_min_fitness() + print(min_fitness) diff --git a/examples/evox_/cartpole.ini b/examples/evox_/cartpole.ini new file mode 100644 index 0000000..a5ba7e4 --- /dev/null +++ b/examples/evox_/cartpole.ini @@ -0,0 +1,11 @@ +[basic] +num_inputs = 4 +num_outputs = 1 +maximum_nodes = 50 +maximum_connections = 50 +maximum_species = 10 +forward_way = "single" +random_seed = 42 + +[population] +pop_size = 40 \ No newline at end of file diff --git a/examples/evox_/cartpole.py b/examples/evox_/cartpole.py new file mode 100644 index 0000000..cec596b --- /dev/null +++ b/examples/evox_/cartpole.py @@ -0,0 +1,62 @@ +import evox +import jax +from jax import jit, vmap, numpy as jnp + +from configs import Configer +from algorithms.neat import create_forward_function, topological_sort, unflatten_connections +from evox_adaptor import NEAT, Gym + +if __name__ 
== '__main__': + batch_policy = True + key = jax.random.PRNGKey(42) + + monitor = evox.monitors.StdSOMonitor() + neat_config = Configer.load_config('cartpole.ini') + origin_forward_func = create_forward_function(neat_config) + + + def neat_transform(pop): + P = neat_config['pop_size'] + N = neat_config['maximum_nodes'] + C = neat_config['maximum_connections'] + + pop_nodes = pop[:P * N * 5].reshape((P, N, 5)) + pop_cons = pop[P * N * 5:].reshape((P, C, 4)) + + u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons) + pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons) + return pop_seqs, pop_nodes, u_pop_cons + + # special policy for cartpole + def neat_forward(genome, x): + res = origin_forward_func(x, *genome)[0] + out = jnp.where(res > 0.5, 1, 0) + return out + + + forward_func = lambda pop, x: origin_forward_func(x, *pop) + + problem = Gym( + policy=jit(vmap(neat_forward)), + env_name="CartPole-v1", + pop_size=40, + ) + + # create a pipeline + pipeline = evox.pipelines.StdPipeline( + algorithm=NEAT(neat_config), + problem=problem, + pop_transform=jit(neat_transform), + fitness_transform=monitor.record_fit, + ) + # init the pipeline + state = pipeline.init(key) + + # run the pipeline for 10 steps + for i in range(10): + state = pipeline.step(state) + print(monitor.get_min_fitness()) + + # obtain 500 + min_fitness = monitor.get_min_fitness() + print(min_fitness) diff --git a/examples/evox_/mountain_car.ini b/examples/evox_/mountain_car.ini new file mode 100644 index 0000000..21cb7d8 --- /dev/null +++ b/examples/evox_/mountain_car.ini @@ -0,0 +1,22 @@ +[basic] +num_inputs = 2 +num_outputs = 1 +maximum_nodes = 50 +maximum_connections = 50 +maximum_species = 10 +forward_way = "single" +random_seed = 42 + +[population] +pop_size = 100 + +[gene-activation] +activation_default = "sigmoid" +activation_option_names = ['sigmoid', 'tanh', 'sin', 'gauss', 'relu', 'identity', 'inv', 'log', 'exp', 'abs', 'hat', 'square'] +activation_replace_rate = 0.1 + 
+[gene-aggregation] +aggregation_default = "sum" +aggregation_option_names = ['sum', 'product', 'max', 'min', 'maxabs', 'median', 'mean'] +aggregation_replace_rate = 0.1 + diff --git a/examples/evox_/mountain_car.py b/examples/evox_/mountain_car.py new file mode 100644 index 0000000..9fcd66f --- /dev/null +++ b/examples/evox_/mountain_car.py @@ -0,0 +1,62 @@ +import evox +import jax +from jax import jit, vmap, numpy as jnp + +from configs import Configer +from algorithms.neat import create_forward_function, topological_sort, unflatten_connections +from evox_adaptor import NEAT, Gym + +if __name__ == '__main__': + batch_policy = True + key = jax.random.PRNGKey(42) + + monitor = evox.monitors.StdSOMonitor() + neat_config = Configer.load_config('mountain_car.ini') + origin_forward_func = create_forward_function(neat_config) + + + def neat_transform(pop): + P = neat_config['pop_size'] + N = neat_config['maximum_nodes'] + C = neat_config['maximum_connections'] + + pop_nodes = pop[:P * N * 5].reshape((P, N, 5)) + pop_cons = pop[P * N * 5:].reshape((P, C, 4)) + + u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons) + pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons) + return pop_seqs, pop_nodes, u_pop_cons + + # special policy for mountain car + def neat_forward(genome, x): + res = origin_forward_func(x, *genome) + out = jnp.tanh(res) # (-1, 1) + return out + + + forward_func = lambda pop, x: origin_forward_func(x, *pop) + + problem = Gym( + policy=jit(vmap(neat_forward)), + env_name="MountainCarContinuous-v0", + pop_size=100, + ) + + # create a pipeline + pipeline = evox.pipelines.StdPipeline( + algorithm=NEAT(neat_config), + problem=problem, + pop_transform=jit(neat_transform), + fitness_transform=monitor.record_fit, + ) + # init the pipeline + state = pipeline.init(key) + + # run the pipeline for 30 steps + for i in range(30): + state = pipeline.step(state) + print(i, monitor.get_min_fitness()) + + # obtain 98.91529684268514 + min_fitness = 
monitor.get_min_fitness() + print(min_fitness) diff --git a/examples/xor3d.ini b/examples/xor3d.ini index 85b41af..e883a52 100644 --- a/examples/xor3d.ini +++ b/examples/xor3d.ini @@ -12,7 +12,7 @@ random_seed = 42 fitness_threshold = 8 generation_limit = 1000 fitness_criterion = "max" -pop_size = 100000 +pop_size = 10000 [genome] compatibility_disjoint = 1.0 diff --git a/pipeline.py b/pipeline.py index 24782d7..46b0a47 100644 --- a/pipeline.py +++ b/pipeline.py @@ -27,28 +27,23 @@ class Pipeline: self.evaluate_time = 0 - - self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \ - self.center_cons, self.generation, self.next_node_key, self.next_species_key = neat.initialize(config) - + ( + self.randkey, + self.pop_nodes, + self.pop_cons, + self.species_info, + self.idx2species, + self.center_nodes, + self.center_cons, + self.generation, + self.next_node_key, + self.next_species_key, + ) = neat.initialize(config) self.forward = neat.create_forward_function(config) self.pop_unflatten_connections = jit(vmap(neat.unflatten_connections)) self.pop_topological_sort = jit(vmap(neat.topological_sort)) - # self.tell_func = neat.tell.lower(np.zeros(config['pop_size'], dtype=np.float32), - # self.randkey, - # self.pop_nodes, - # self.pop_cons, - # self.species_info, - # self.idx2species, - # self.center_nodes, - # self.center_cons, - # self.generation, - # self.next_node_key, - # self.next_species_key, - # self.jit_config).compile() - def ask(self): """ Creates a function that receives a genome and returns a forward function. 
@@ -77,21 +72,31 @@ class Pipeline: return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons) def tell(self, fitness): - - self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \ - self.center_cons, self.generation, self.next_node_key, self.next_species_key = neat.tell(fitness, - self.randkey, - self.pop_nodes, - self.pop_cons, - self.species_info, - self.idx2species, - self.center_nodes, - self.center_cons, - self.generation, - self.next_node_key, - self.next_species_key, - self.jit_config) - + ( + self.randkey, + self.pop_nodes, + self.pop_cons, + self.species_info, + self.idx2species, + self.center_nodes, + self.center_cons, + self.generation, + self.next_node_key, + self.next_species_key, + ) = neat.tell( + fitness, + self.randkey, + self.pop_nodes, + self.pop_cons, + self.species_info, + self.idx2species, + self.center_nodes, + self.center_cons, + self.generation, + self.next_node_key, + self.next_species_key, + self.jit_config + ) def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): for _ in range(self.config['generation_limit']):