Using Evox to deal with RL tasks! With distributed Gym environment!
Three simple tasks in Gym[classical] are tested.
This commit is contained in:
@@ -2,12 +2,16 @@ import jax
|
|||||||
from jax import Array, numpy as jnp, jit, vmap
|
from jax import Array, numpy as jnp, jit, vmap
|
||||||
|
|
||||||
from .utils import I_INT
|
from .utils import I_INT
|
||||||
|
from .activations import act_name2func
|
||||||
|
from .aggregations import agg_name2func
|
||||||
|
|
||||||
|
|
||||||
def create_forward_function(config):
|
def create_forward_function(config):
|
||||||
"""
|
"""
|
||||||
meta method to create forward function
|
meta method to create forward function
|
||||||
"""
|
"""
|
||||||
|
config['activation_funcs'] = [act_name2func[name] for name in config['activation_option_names']]
|
||||||
|
config['aggregation_funcs'] = [agg_name2func[name] for name in config['aggregation_option_names']]
|
||||||
|
|
||||||
def act(idx, z):
|
def act(idx, z):
|
||||||
"""
|
"""
|
||||||
@@ -92,12 +96,11 @@ def create_forward_function(config):
|
|||||||
common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))
|
common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))
|
||||||
|
|
||||||
if config['forward_way'] == 'single':
|
if config['forward_way'] == 'single':
|
||||||
return jit(batch_forward)
|
return jit(forward)
|
||||||
|
# return jit(batch_forward)
|
||||||
|
|
||||||
elif config['forward_way'] == 'pop':
|
elif config['forward_way'] == 'pop':
|
||||||
return jit(pop_batch_forward)
|
return jit(pop_batch_forward)
|
||||||
|
|
||||||
elif config['forward_way'] == 'common':
|
elif config['forward_way'] == 'common':
|
||||||
return jit(common_forward)
|
return jit(common_forward)
|
||||||
|
|
||||||
return jit(forward)
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Some graph algorithms implemented in jax.
|
Some graph algorithms implemented in jax.
|
||||||
Only used in feed-forward networks.
|
Only used in feed-forward networks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -4,9 +4,6 @@ import configparser
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from algorithms.neat.genome.activations import act_name2func
|
|
||||||
from algorithms.neat.genome.aggregations import agg_name2func
|
|
||||||
|
|
||||||
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
|
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
|
||||||
jit_config_keys = [
|
jit_config_keys = [
|
||||||
"input_idx",
|
"input_idx",
|
||||||
@@ -108,13 +105,11 @@ class Configer:
|
|||||||
def refactor_activation(cls, config):
|
def refactor_activation(cls, config):
|
||||||
config['activation_default'] = 0
|
config['activation_default'] = 0
|
||||||
config['activation_options'] = np.arange(len(config['activation_option_names']))
|
config['activation_options'] = np.arange(len(config['activation_option_names']))
|
||||||
config['activation_funcs'] = [act_name2func[name] for name in config['activation_option_names']]
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def refactor_aggregation(cls, config):
|
def refactor_aggregation(cls, config):
|
||||||
config['aggregation_default'] = 0
|
config['aggregation_default'] = 0
|
||||||
config['aggregation_options'] = np.arange(len(config['aggregation_option_names']))
|
config['aggregation_options'] = np.arange(len(config['aggregation_option_names']))
|
||||||
config['aggregation_funcs'] = [agg_name2func[name] for name in config['aggregation_option_names']]
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_jit_config(cls, config):
|
def create_jit_config(cls, config):
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ random_seed = 0
|
|||||||
fitness_threshold = 3.99999
|
fitness_threshold = 3.99999
|
||||||
generation_limit = 1000
|
generation_limit = 1000
|
||||||
fitness_criterion = "max"
|
fitness_criterion = "max"
|
||||||
pop_size = 100000
|
pop_size = 10000
|
||||||
|
|
||||||
[genome]
|
[genome]
|
||||||
compatibility_disjoint = 1.0
|
compatibility_disjoint = 1.0
|
||||||
|
|||||||
2
evox_adaptor/__init__.py
Normal file
2
evox_adaptor/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
from .neat import NEAT
|
||||||
|
from .gym_no_distribution import Gym
|
||||||
83
evox_adaptor/gym_no_distribution.py
Normal file
83
evox_adaptor/gym_no_distribution.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from evox import Problem, State
|
||||||
|
|
||||||
|
|
||||||
|
class Gym(Problem):
    """EvoX problem that scores a population by running Gym episodes.

    One Gym environment is created per individual at construction time.
    ``evaluate`` resets every environment, repeatedly queries ``policy`` for
    a batch of actions, and accumulates each individual's rewards until all
    episodes have terminated or been truncated.
    """

    def __init__(
        self,
        pop_size: int,
        policy: Callable,
        env_name: str = "CartPole-v1",
        env_options: dict = None,
        batch_policy: bool = True,
    ):
        """Store the settings and build ``pop_size`` environments.

        Args:
            pop_size: number of individuals, and of environments created.
            policy: callable invoked as ``policy(pop, observations)`` that
                returns one action per individual.
            env_name: id passed to ``gym.make``.
            env_options: extra keyword arguments for ``gym.make``; ``None``
                is treated as an empty dict.
            batch_policy: must be truthy; anything else fails the assertion.
        """
        self.pop_size = pop_size
        self.env_name = env_name
        self.policy = policy
        self.env_options = env_options or {}
        self.batch_policy = batch_policy
        assert batch_policy, "Only batch policy is supported for now"

        # one independent environment per individual
        self.envs = []
        for _ in range(self.pop_size):
            self.envs.append(gym.make(env_name, **self.env_options))

        super().__init__()

    def setup(self, key):
        """Return the initial problem state, which only holds ``key``."""
        return State(key=key)

    def evaluate(self, state, pop):
        """Roll out every individual once and return ``(-fitnesses, state)``.

        The key is used as-is and handed back unchanged (it is not split),
        so every call draws the same seeds -- the original author notes this
        is a fixed seed kept for debugging.
        """
        key = state.key

        seed_array = jax.random.randint(
            key, (self.pop_size,), 0, jnp.iinfo(jnp.int32).max
        )
        # env.reset needs plain python ints, not array scalars
        seeds = seed_array.tolist()

        fitnesses = self.__rollout(seeds, pop)
        print("fitnesses info: ")
        print(f"max: {np.max(fitnesses)}, min: {np.min(fitnesses)}, mean: {np.mean(fitnesses)}, std: {np.std(fitnesses)}")

        # evox uses negative fitness for minimization
        return -fitnesses, State(key=key)

    def __rollout(self, seeds, pop):
        """Run one episode per environment and return the summed rewards.

        Args:
            seeds: one python int per environment, passed to ``env.reset``.
            pop: population object forwarded untouched to ``self.policy``.

        Returns:
            Array of length ``pop_size`` holding each individual's total reward.
        """
        reset_results = [env.reset(seed=seed) for env, seed in zip(self.envs, seeds)]
        observations, infos = zip(*reset_results)

        # rows of a (2, pop_size) array unpack into two per-individual vectors
        terminates, truncates = np.zeros((2, self.pop_size), dtype=bool)
        fitnesses, rewards = np.zeros((2, self.pop_size))

        while not (terminates | truncates).all():
            observations = np.asarray(observations)
            actions = jax.device_get(self.policy(pop, observations))

            per_env = zip(actions, terminates, truncates, self.envs)
            for idx, (action, done, cut, env) in enumerate(per_env):
                if done | cut:
                    # finished episodes are fed zeros and earn nothing more
                    obs = np.zeros(env.observation_space.shape)
                    gained = 0
                else:
                    obs, gained, done, cut, info = env.step(action)

                observations[idx] = obs
                rewards[idx] = gained
                terminates[idx] = done
                truncates[idx] = cut

            fitnesses += rewards

        return fitnesses
|
||||||
91
evox_adaptor/neat.py
Normal file
91
evox_adaptor/neat.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
import evox
|
||||||
|
from algorithms import neat
|
||||||
|
from configs import Configer
|
||||||
|
|
||||||
|
|
||||||
|
@evox.jit_class
class NEAT(evox.Algorithm):
    """Adapter exposing the project's NEAT implementation as an EvoX algorithm.

    The NEAT data returned by ``neat.initialize`` is kept on the instance and
    copied into an ``evox.State`` by ``setup``; ``ask`` and ``tell`` then work
    purely on that state.
    """

    def __init__(self, config):
        """Initialise NEAT from ``config`` and remember the resulting values."""
        self.config = config  # global config
        self.jit_config = Configer.create_jit_config(config)

        initial = neat.initialize(config)
        (
            self.randkey,
            self.pop_nodes,
            self.pop_cons,
            self.species_info,
            self.idx2species,
            self.center_nodes,
            self.center_cons,
            self.generation,
            self.next_node_key,
            self.next_species_key,
        ) = initial
        super().__init__()

    def setup(self, key):
        """Build the initial state from the stored values; ``key`` is unused here."""
        return evox.State(
            randkey=self.randkey,
            pop_nodes=self.pop_nodes,
            pop_cons=self.pop_cons,
            species_info=self.species_info,
            idx2species=self.idx2species,
            center_nodes=self.center_nodes,
            center_cons=self.center_cons,
            generation=self.generation,
            next_node_key=self.next_node_key,
            next_species_key=self.next_species_key,
            jit_config=self.jit_config,
        )

    def ask(self, state):
        """Return the population as one flat vector (nodes first, then connections)."""
        pop = jnp.concatenate([
            state.pop_nodes.flatten(),
            state.pop_cons.flatten(),
        ])
        return pop, state

    def tell(self, state, fitness):
        """Advance NEAT by one generation and return the updated state.

        evox is a minimization framework, so the incoming fitness is negated
        before it is handed to ``neat.tell``.
        """
        updated = neat.tell(
            -fitness,
            state.randkey,
            state.pop_nodes,
            state.pop_cons,
            state.species_info,
            state.idx2species,
            state.center_nodes,
            state.center_cons,
            state.generation,
            state.next_node_key,
            state.next_species_key,
            state.jit_config,
        )
        (
            randkey,
            pop_nodes,
            pop_cons,
            species_info,
            idx2species,
            center_nodes,
            center_cons,
            generation,
            next_node_key,
            next_species_key,
        ) = updated

        # jit_config is carried over unchanged from the previous state
        return evox.State(
            randkey=randkey,
            pop_nodes=pop_nodes,
            pop_cons=pop_cons,
            species_info=species_info,
            idx2species=idx2species,
            center_nodes=center_nodes,
            center_cons=center_cons,
            generation=generation,
            next_node_key=next_node_key,
            next_species_key=next_species_key,
            jit_config=state.jit_config,
        )
|
||||||
0
examples/evox_/__init__.py
Normal file
0
examples/evox_/__init__.py
Normal file
22
examples/evox_/acrobot.ini
Normal file
22
examples/evox_/acrobot.ini
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
[basic]
|
||||||
|
num_inputs = 6
|
||||||
|
num_outputs = 3
|
||||||
|
maximum_nodes = 50
|
||||||
|
maximum_connections = 50
|
||||||
|
maximum_species = 10
|
||||||
|
forward_way = "single"
|
||||||
|
random_seed = 42
|
||||||
|
|
||||||
|
[population]
|
||||||
|
pop_size = 100
|
||||||
|
|
||||||
|
[gene-activation]
|
||||||
|
activation_default = "sigmoid"
|
||||||
|
activation_option_names = ['sigmoid', 'tanh', 'sin', 'gauss', 'relu', 'identity', 'inv', 'log', 'exp', 'abs', 'hat', 'square']
|
||||||
|
activation_replace_rate = 0.1
|
||||||
|
|
||||||
|
[gene-aggregation]
|
||||||
|
aggregation_default = "sum"
|
||||||
|
aggregation_option_names = ['sum', 'product', 'max', 'min', 'maxabs', 'median', 'mean']
|
||||||
|
aggregation_replace_rate = 0.1
|
||||||
|
|
||||||
62
examples/evox_/acrobot.py
Normal file
62
examples/evox_/acrobot.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import evox
|
||||||
|
import jax
|
||||||
|
from jax import jit, vmap, numpy as jnp
|
||||||
|
|
||||||
|
from configs import Configer
|
||||||
|
from algorithms.neat import create_forward_function, topological_sort, unflatten_connections
|
||||||
|
from evox_adaptor import NEAT, Gym
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Evolve NEAT controllers for Gym's Acrobot-v1 through the EvoX pipeline.
    key = jax.random.PRNGKey(42)

    monitor = evox.monitors.StdSOMonitor()
    neat_config = Configer.load_config('acrobot.ini')
    origin_forward_func = create_forward_function(neat_config)


    def neat_transform(pop):
        """Split the flat population vector back into per-genome arrays.

        The first P*N*5 entries are reshaped to (P, N, 5) nodes and the rest
        to (P, C, 4) connections; connections are then unflattened and every
        genome is topologically sorted.
        """
        P = neat_config['pop_size']
        N = neat_config['maximum_nodes']
        C = neat_config['maximum_connections']

        pop_nodes = pop[:P * N * 5].reshape((P, N, 5))
        pop_cons = pop[P * N * 5:].reshape((P, C, 4))

        u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons)
        pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons)
        return pop_seqs, pop_nodes, u_pop_cons

    # special policy for acrobot: choose the action with the largest output
    def neat_forward(genome, x):
        res = origin_forward_func(x, *genome)
        out = jnp.argmax(res)  # {0, 1, 2}
        return out


    problem = Gym(
        policy=jit(vmap(neat_forward)),
        env_name="Acrobot-v1",
        # read from the config so it cannot drift from the P used in neat_transform
        pop_size=neat_config['pop_size'],
    )

    # create a pipeline
    pipeline = evox.pipelines.StdPipeline(
        algorithm=NEAT(neat_config),
        problem=problem,
        pop_transform=jit(neat_transform),
        fitness_transform=monitor.record_fit,
    )
    # init the pipeline
    state = pipeline.init(key)

    # run the pipeline for 30 steps
    for i in range(30):
        state = pipeline.step(state)
        print(i, monitor.get_min_fitness())

    # result noted by the original author: -62.0
    min_fitness = monitor.get_min_fitness()
    print(min_fitness)
|
||||||
22
examples/evox_/bipedalwalker.ini
Normal file
22
examples/evox_/bipedalwalker.ini
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
[basic]
|
||||||
|
num_inputs = 24
|
||||||
|
num_outputs = 4
|
||||||
|
maximum_nodes = 100
|
||||||
|
maximum_connections = 200
|
||||||
|
maximum_species = 10
|
||||||
|
forward_way = "single"
|
||||||
|
random_seed = 42
|
||||||
|
|
||||||
|
[population]
|
||||||
|
pop_size = 100
|
||||||
|
|
||||||
|
[gene-activation]
|
||||||
|
activation_default = "sigmoid"
|
||||||
|
activation_option_names = ['sigmoid', 'tanh', 'sin', 'gauss', 'relu', 'identity', 'inv', 'log', 'exp', 'abs', 'hat', 'square']
|
||||||
|
activation_replace_rate = 0.1
|
||||||
|
|
||||||
|
[gene-aggregation]
|
||||||
|
aggregation_default = "sum"
|
||||||
|
aggregation_option_names = ['sum', 'product', 'max', 'min', 'maxabs', 'median', 'mean']
|
||||||
|
aggregation_replace_rate = 0.1
|
||||||
|
|
||||||
62
examples/evox_/bipedalwalker.py
Normal file
62
examples/evox_/bipedalwalker.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import evox
|
||||||
|
import jax
|
||||||
|
from jax import jit, vmap, numpy as jnp
|
||||||
|
|
||||||
|
from configs import Configer
|
||||||
|
from algorithms.neat import create_forward_function, topological_sort, unflatten_connections
|
||||||
|
from evox_adaptor import NEAT, Gym
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Evolve NEAT controllers for Gym's BipedalWalker-v3 through the EvoX pipeline.
    key = jax.random.PRNGKey(42)

    monitor = evox.monitors.StdSOMonitor()
    neat_config = Configer.load_config('bipedalwalker.ini')
    origin_forward_func = create_forward_function(neat_config)


    def neat_transform(pop):
        """Split the flat population vector back into per-genome arrays.

        The first P*N*5 entries are reshaped to (P, N, 5) nodes and the rest
        to (P, C, 4) connections; connections are then unflattened and every
        genome is topologically sorted.
        """
        P = neat_config['pop_size']
        N = neat_config['maximum_nodes']
        C = neat_config['maximum_connections']

        pop_nodes = pop[:P * N * 5].reshape((P, N, 5))
        pop_cons = pop[P * N * 5:].reshape((P, C, 4))

        u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons)
        pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons)
        return pop_seqs, pop_nodes, u_pop_cons

    # special policy for bipedal walker: squash the network outputs with tanh
    def neat_forward(genome, x):
        res = origin_forward_func(x, *genome)
        out = jnp.tanh(res)  # (-1, 1)
        return out


    problem = Gym(
        policy=jit(vmap(neat_forward)),
        env_name="BipedalWalker-v3",
        # read from the config so it cannot drift from the P used in neat_transform
        pop_size=neat_config['pop_size'],
    )

    # create a pipeline
    pipeline = evox.pipelines.StdPipeline(
        algorithm=NEAT(neat_config),
        problem=problem,
        pop_transform=jit(neat_transform),
        fitness_transform=monitor.record_fit,
    )
    # init the pipeline
    state = pipeline.init(key)

    # run the pipeline for 30 steps
    for i in range(30):
        state = pipeline.step(state)
        print(i, monitor.get_min_fitness())

    # result noted by the original author: 98.91529684268514
    # NOTE(review): the mountain car example quotes the same figure -- possibly copied; confirm
    min_fitness = monitor.get_min_fitness()
    print(min_fitness)
|
||||||
11
examples/evox_/cartpole.ini
Normal file
11
examples/evox_/cartpole.ini
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
[basic]
|
||||||
|
num_inputs = 4
|
||||||
|
num_outputs = 1
|
||||||
|
maximum_nodes = 50
|
||||||
|
maximum_connections = 50
|
||||||
|
maximum_species = 10
|
||||||
|
forward_way = "single"
|
||||||
|
random_seed = 42
|
||||||
|
|
||||||
|
[population]
|
||||||
|
pop_size = 40
|
||||||
62
examples/evox_/cartpole.py
Normal file
62
examples/evox_/cartpole.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import evox
|
||||||
|
import jax
|
||||||
|
from jax import jit, vmap, numpy as jnp
|
||||||
|
|
||||||
|
from configs import Configer
|
||||||
|
from algorithms.neat import create_forward_function, topological_sort, unflatten_connections
|
||||||
|
from evox_adaptor import NEAT, Gym
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Evolve NEAT controllers for Gym's CartPole-v1 through the EvoX pipeline.
    key = jax.random.PRNGKey(42)

    monitor = evox.monitors.StdSOMonitor()
    neat_config = Configer.load_config('cartpole.ini')
    origin_forward_func = create_forward_function(neat_config)


    def neat_transform(pop):
        """Split the flat population vector back into per-genome arrays.

        The first P*N*5 entries are reshaped to (P, N, 5) nodes and the rest
        to (P, C, 4) connections; connections are then unflattened and every
        genome is topologically sorted.
        """
        P = neat_config['pop_size']
        N = neat_config['maximum_nodes']
        C = neat_config['maximum_connections']

        pop_nodes = pop[:P * N * 5].reshape((P, N, 5))
        pop_cons = pop[P * N * 5:].reshape((P, C, 4))

        u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons)
        pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons)
        return pop_seqs, pop_nodes, u_pop_cons

    # special policy for cartpole: threshold the first output into action 0 or 1
    def neat_forward(genome, x):
        res = origin_forward_func(x, *genome)[0]
        out = jnp.where(res > 0.5, 1, 0)
        return out


    problem = Gym(
        policy=jit(vmap(neat_forward)),
        env_name="CartPole-v1",
        # read from the config so it cannot drift from the P used in neat_transform
        pop_size=neat_config['pop_size'],
    )

    # create a pipeline
    pipeline = evox.pipelines.StdPipeline(
        algorithm=NEAT(neat_config),
        problem=problem,
        pop_transform=jit(neat_transform),
        fitness_transform=monitor.record_fit,
    )
    # init the pipeline
    state = pipeline.init(key)

    # run the pipeline for 10 steps
    for i in range(10):
        state = pipeline.step(state)
        print(monitor.get_min_fitness())

    # result noted by the original author: 500
    min_fitness = monitor.get_min_fitness()
    print(min_fitness)
|
||||||
22
examples/evox_/mountain_car.ini
Normal file
22
examples/evox_/mountain_car.ini
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
[basic]
|
||||||
|
num_inputs = 2
|
||||||
|
num_outputs = 1
|
||||||
|
maximum_nodes = 50
|
||||||
|
maximum_connections = 50
|
||||||
|
maximum_species = 10
|
||||||
|
forward_way = "single"
|
||||||
|
random_seed = 42
|
||||||
|
|
||||||
|
[population]
|
||||||
|
pop_size = 100
|
||||||
|
|
||||||
|
[gene-activation]
|
||||||
|
activation_default = "sigmoid"
|
||||||
|
activation_option_names = ['sigmoid', 'tanh', 'sin', 'gauss', 'relu', 'identity', 'inv', 'log', 'exp', 'abs', 'hat', 'square']
|
||||||
|
activation_replace_rate = 0.1
|
||||||
|
|
||||||
|
[gene-aggregation]
|
||||||
|
aggregation_default = "sum"
|
||||||
|
aggregation_option_names = ['sum', 'product', 'max', 'min', 'maxabs', 'median', 'mean']
|
||||||
|
aggregation_replace_rate = 0.1
|
||||||
|
|
||||||
62
examples/evox_/mountain_car.py
Normal file
62
examples/evox_/mountain_car.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import evox
|
||||||
|
import jax
|
||||||
|
from jax import jit, vmap, numpy as jnp
|
||||||
|
|
||||||
|
from configs import Configer
|
||||||
|
from algorithms.neat import create_forward_function, topological_sort, unflatten_connections
|
||||||
|
from evox_adaptor import NEAT, Gym
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Evolve NEAT controllers for Gym's MountainCarContinuous-v0 through the EvoX pipeline.
    key = jax.random.PRNGKey(42)

    monitor = evox.monitors.StdSOMonitor()
    neat_config = Configer.load_config('mountain_car.ini')
    origin_forward_func = create_forward_function(neat_config)


    def neat_transform(pop):
        """Split the flat population vector back into per-genome arrays.

        The first P*N*5 entries are reshaped to (P, N, 5) nodes and the rest
        to (P, C, 4) connections; connections are then unflattened and every
        genome is topologically sorted.
        """
        P = neat_config['pop_size']
        N = neat_config['maximum_nodes']
        C = neat_config['maximum_connections']

        pop_nodes = pop[:P * N * 5].reshape((P, N, 5))
        pop_cons = pop[P * N * 5:].reshape((P, C, 4))

        u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons)
        pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons)
        return pop_seqs, pop_nodes, u_pop_cons

    # special policy for mountain car: squash the network output with tanh
    def neat_forward(genome, x):
        res = origin_forward_func(x, *genome)
        out = jnp.tanh(res)  # (-1, 1)
        return out


    problem = Gym(
        policy=jit(vmap(neat_forward)),
        env_name="MountainCarContinuous-v0",
        # read from the config so it cannot drift from the P used in neat_transform
        pop_size=neat_config['pop_size'],
    )

    # create a pipeline
    pipeline = evox.pipelines.StdPipeline(
        algorithm=NEAT(neat_config),
        problem=problem,
        pop_transform=jit(neat_transform),
        fitness_transform=monitor.record_fit,
    )
    # init the pipeline
    state = pipeline.init(key)

    # run the pipeline for 30 steps
    for i in range(30):
        state = pipeline.step(state)
        print(i, monitor.get_min_fitness())

    # result noted by the original author: 98.91529684268514
    min_fitness = monitor.get_min_fitness()
    print(min_fitness)
|
||||||
@@ -12,7 +12,7 @@ random_seed = 42
|
|||||||
fitness_threshold = 8
|
fitness_threshold = 8
|
||||||
generation_limit = 1000
|
generation_limit = 1000
|
||||||
fitness_criterion = "max"
|
fitness_criterion = "max"
|
||||||
pop_size = 100000
|
pop_size = 10000
|
||||||
|
|
||||||
[genome]
|
[genome]
|
||||||
compatibility_disjoint = 1.0
|
compatibility_disjoint = 1.0
|
||||||
|
|||||||
49
pipeline.py
49
pipeline.py
@@ -27,28 +27,23 @@ class Pipeline:
|
|||||||
|
|
||||||
self.evaluate_time = 0
|
self.evaluate_time = 0
|
||||||
|
|
||||||
|
(
|
||||||
self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \
|
self.randkey,
|
||||||
self.center_cons, self.generation, self.next_node_key, self.next_species_key = neat.initialize(config)
|
self.pop_nodes,
|
||||||
|
self.pop_cons,
|
||||||
|
self.species_info,
|
||||||
|
self.idx2species,
|
||||||
|
self.center_nodes,
|
||||||
|
self.center_cons,
|
||||||
|
self.generation,
|
||||||
|
self.next_node_key,
|
||||||
|
self.next_species_key,
|
||||||
|
) = neat.initialize(config)
|
||||||
|
|
||||||
self.forward = neat.create_forward_function(config)
|
self.forward = neat.create_forward_function(config)
|
||||||
self.pop_unflatten_connections = jit(vmap(neat.unflatten_connections))
|
self.pop_unflatten_connections = jit(vmap(neat.unflatten_connections))
|
||||||
self.pop_topological_sort = jit(vmap(neat.topological_sort))
|
self.pop_topological_sort = jit(vmap(neat.topological_sort))
|
||||||
|
|
||||||
# self.tell_func = neat.tell.lower(np.zeros(config['pop_size'], dtype=np.float32),
|
|
||||||
# self.randkey,
|
|
||||||
# self.pop_nodes,
|
|
||||||
# self.pop_cons,
|
|
||||||
# self.species_info,
|
|
||||||
# self.idx2species,
|
|
||||||
# self.center_nodes,
|
|
||||||
# self.center_cons,
|
|
||||||
# self.generation,
|
|
||||||
# self.next_node_key,
|
|
||||||
# self.next_species_key,
|
|
||||||
# self.jit_config).compile()
|
|
||||||
|
|
||||||
def ask(self):
|
def ask(self):
|
||||||
"""
|
"""
|
||||||
Creates a function that receives a genome and returns a forward function.
|
Creates a function that receives a genome and returns a forward function.
|
||||||
@@ -77,9 +72,7 @@ class Pipeline:
|
|||||||
return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
||||||
|
|
||||||
def tell(self, fitness):
|
def tell(self, fitness):
|
||||||
|
(
|
||||||
self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \
|
|
||||||
self.center_cons, self.generation, self.next_node_key, self.next_species_key = neat.tell(fitness,
|
|
||||||
self.randkey,
|
self.randkey,
|
||||||
self.pop_nodes,
|
self.pop_nodes,
|
||||||
self.pop_cons,
|
self.pop_cons,
|
||||||
@@ -90,8 +83,20 @@ class Pipeline:
|
|||||||
self.generation,
|
self.generation,
|
||||||
self.next_node_key,
|
self.next_node_key,
|
||||||
self.next_species_key,
|
self.next_species_key,
|
||||||
self.jit_config)
|
) = neat.tell(
|
||||||
|
fitness,
|
||||||
|
self.randkey,
|
||||||
|
self.pop_nodes,
|
||||||
|
self.pop_cons,
|
||||||
|
self.species_info,
|
||||||
|
self.idx2species,
|
||||||
|
self.center_nodes,
|
||||||
|
self.center_cons,
|
||||||
|
self.generation,
|
||||||
|
self.next_node_key,
|
||||||
|
self.next_species_key,
|
||||||
|
self.jit_config
|
||||||
|
)
|
||||||
|
|
||||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||||
for _ in range(self.config['generation_limit']):
|
for _ in range(self.config['generation_limit']):
|
||||||
|
|||||||
Reference in New Issue
Block a user