Modify genome for the official release
This commit is contained in:
39
examples/brax/ant.py
Normal file
39
examples/brax/ant.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import BraxEnv
|
||||
from tensorneat.common import Act
|
||||
|
||||
if __name__ == "__main__":
    # NEAT run on the Brax "ant" environment; tanh is the only activation
    # option and is also applied to the genome outputs.
    pipeline = Pipeline(
        algorithm=NEAT(
            species=DefaultSpecies(
                genome=DefaultGenome(
                    num_inputs=27,
                    num_outputs=8,
                    max_nodes=100,
                    max_conns=200,
                    node_gene=DefaultNodeGene(
                        activation_options=(Act.tanh,),
                        activation_default=Act.tanh,
                    ),
                    output_transform=Act.tanh,
                ),
                pop_size=1000,
                species_size=10,
                compatibility_threshold=3.5,
                survival_threshold=0.01,
            ),
        ),
        problem=BraxEnv(
            env_name="ant",
        ),
        generation_limit=10000,
        fitness_target=5000,
    )

    # initialize state
    state = pipeline.setup()
    # run until terminate
    state, best = pipeline.auto_run(state)
|
||||
48
examples/brax/half_cheetah.py
Normal file
48
examples/brax/half_cheetah.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import jax
|
||||
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import BraxEnv
|
||||
from tensorneat.common import Act
|
||||
|
||||
|
||||
def sample_policy(randkey, obs):
    """Return a random action of 6 values drawn uniformly from [-1, 1).

    Args:
        randkey: JAX PRNG key used for sampling.
        obs: Accepted for signature compatibility; not used.

    Returns:
        A JAX array of shape (6,).
    """
    return jax.random.uniform(randkey, (6,), minval=-1, maxval=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # NEAT run on the Brax "halfcheetah" environment. The environment is
    # configured with observation normalization and is handed sample_policy
    # together with sample_episodes=1000.
    pipeline = Pipeline(
        algorithm=NEAT(
            species=DefaultSpecies(
                genome=DefaultGenome(
                    num_inputs=17,
                    num_outputs=6,
                    max_nodes=50,
                    max_conns=100,
                    node_gene=DefaultNodeGene(
                        activation_options=(Act.tanh,),
                        activation_default=Act.tanh,
                    ),
                    output_transform=Act.tanh,
                ),
                pop_size=1000,
                species_size=10,
            ),
        ),
        problem=BraxEnv(
            env_name="halfcheetah",
            max_step=1000,
            obs_normalization=True,
            sample_episodes=1000,
            sample_policy=sample_policy,
        ),
        generation_limit=10000,
        fitness_target=5000,
    )

    # initialize state
    state = pipeline.setup()
    # run until terminate
    state, best = pipeline.auto_run(state)
|
||||
37
examples/brax/reacher.py
Normal file
37
examples/brax/reacher.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import BraxEnv
|
||||
from tensorneat.common import Act
|
||||
|
||||
if __name__ == "__main__":
    # NEAT run on the Brax "reacher" environment with a population of 100.
    pipeline = Pipeline(
        algorithm=NEAT(
            species=DefaultSpecies(
                genome=DefaultGenome(
                    num_inputs=11,
                    num_outputs=2,
                    max_nodes=50,
                    max_conns=100,
                    node_gene=DefaultNodeGene(
                        activation_options=(Act.tanh,),
                        activation_default=Act.tanh,
                    ),
                    output_transform=Act.tanh,
                ),
                pop_size=100,
                species_size=10,
            ),
        ),
        problem=BraxEnv(
            env_name="reacher",
        ),
        generation_limit=10000,
        fitness_target=5000,
    )

    # initialize state
    state = pipeline.setup()
    # run until terminate
    state, best = pipeline.auto_run(state)
|
||||
19
examples/brax/show_test.py
Normal file
19
examples/brax/show_test.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import jax
|
||||
from problem.rl_env import BraxEnv
|
||||
|
||||
|
||||
def random_policy(randkey, forward_func, obs):
    """Return a random action of 6 values drawn uniformly from [-1, 1).

    Args:
        randkey: JAX PRNG key used for sampling.
        forward_func: Accepted for signature compatibility; not used.
        obs: Accepted for signature compatibility; not used.

    Returns:
        A JAX array of shape (6,).
    """
    return jax.random.uniform(randkey, (6,), minval=-1, maxval=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Build a walker2d environment driven by random_policy and hand it to
    # problem.show with a save path.
    problem = BraxEnv(env_name="walker2d", max_step=1000, action_policy=random_policy)
    state = problem.setup()
    randkey = jax.random.key(0)
    problem.show(
        state,
        randkey,
        # Placeholder act_func that echoes the observation; random_policy
        # never calls the forward function it is given.
        act_func=lambda state, params, obs: obs,
        params=None,
        save_path="walker2d_random_policy",
    )
|
||||
60
examples/brax/walker.py
Normal file
60
examples/brax/walker.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import BraxEnv
|
||||
from tensorneat.common import Act
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
|
||||
def split_right_left(randkey, forward_func, obs):
    """Run one network twice, once per leg, and join the two action halves.

    The first call sees ``obs`` unchanged. The second call sees ``obs`` with
    the entries at indices [2, 3, 4, 11, 12, 13] and [5, 6, 7, 14, 15, 16]
    swapped. Both calls are batched through ``jax.vmap``.

    Args:
        randkey: Accepted for signature compatibility; not used.
        forward_func: Callable mapping one observation vector to one action
            vector (3 outputs in the accompanying walker configuration).
        obs: 1-D observation array; must be indexable up to index 16.

    Returns:
        The two action vectors concatenated: unswapped result first, then
        the result for the swapped observation.
    """
    right_obs_keys = jnp.array([2, 3, 4, 11, 12, 13])
    left_obs_keys = jnp.array([5, 6, 7, 14, 15, 16])

    right_foot_obs = obs
    # Both .set() calls read from the original obs, so this is a true swap.
    left_foot_obs = (
        obs.at[right_obs_keys]
        .set(obs[left_obs_keys])
        .at[left_obs_keys]
        .set(obs[right_obs_keys])
    )

    right_action, left_action = jax.vmap(forward_func)(
        jnp.stack([right_foot_obs, left_foot_obs])
    )
    return jnp.concatenate([right_action, left_action])
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # NEAT run on the Brax "walker2d" environment. The genome has only 3
    # outputs; split_right_left evaluates it twice and concatenates the
    # results into the action passed to the environment.
    pipeline = Pipeline(
        algorithm=NEAT(
            species=DefaultSpecies(
                genome=DefaultGenome(
                    num_inputs=17,
                    num_outputs=3,
                    max_nodes=50,
                    max_conns=100,
                    node_gene=DefaultNodeGene(
                        activation_options=(Act.tanh,),
                        activation_default=Act.tanh,
                    ),
                    output_transform=Act.tanh,
                ),
                pop_size=1000,
                species_size=10,
            ),
        ),
        problem=BraxEnv(
            env_name="walker2d",
            max_step=1000,
            action_policy=split_right_left,
        ),
        generation_limit=10000,
        fitness_target=5000,
    )

    # initialize state
    state = pipeline.setup()
    # run until terminate
    state, best = pipeline.auto_run(state)
|
||||
50
examples/func_fit/xor.py
Normal file
50
examples/func_fit/xor.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.func_fit import XOR3d
|
||||
from tensorneat.common import ACT_ALL, AGG_ALL, Act, Agg
|
||||
|
||||
if __name__ == "__main__":
    # NEAT on the 3-input XOR problem, starting from a DenseInitialize genome
    # and mutating over all activation and aggregation functions.
    pipeline = Pipeline(
        algorithm=NEAT(
            species=DefaultSpecies(
                genome=DenseInitialize(
                    num_inputs=3,
                    num_outputs=1,
                    max_nodes=50,
                    max_conns=100,
                    node_gene=DefaultNodeGene(
                        activation_default=Act.tanh,
                        # alternative: activation_options=(Act.tanh,),
                        activation_options=ACT_ALL,
                        aggregation_default=Agg.sum,
                        # alternative: aggregation_options=(Agg.sum,),
                        aggregation_options=AGG_ALL,
                    ),
                    output_transform=Act.standard_sigmoid,  # the activation function for output node
                    mutation=DefaultMutation(
                        node_add=0.1,
                        conn_add=0.1,
                        node_delete=0,
                        conn_delete=0,
                    ),
                ),
                pop_size=10000,
                species_size=20,
                compatibility_threshold=2,
                survival_threshold=0.01,  # magic (original author's note; no rationale given)
            ),
        ),
        problem=XOR3d(),
        generation_limit=10000,
        fitness_target=-1e-3,
    )

    # initialize state
    state = pipeline.setup()
    # run until terminate
    state, best = pipeline.auto_run(state)
    # show result
    pipeline.show(state, best)
    pipeline.save(state=state)
|
||||
63
examples/func_fit/xor3d_hyperneat.py
Normal file
63
examples/func_fit/xor3d_hyperneat.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
from algorithm.hyperneat import *
|
||||
from tensorneat.common import Act
|
||||
|
||||
from problem.func_fit import XOR3d
|
||||
|
||||
if __name__ == "__main__":
    # HyperNEAT on the 3-input XOR problem: an inner NEAT population evolves
    # 4-input / 1-output genomes that are queried with coordinate pairs of
    # the substrate below.
    pipeline = Pipeline(
        algorithm=HyperNEAT(
            substrate=FullSubstrate(
                input_coors=[(-1, -1), (0.333, -1), (-0.333, -1), (1, -1)],  # 3(XOR3d inputs) + 1(bias)
                hidden_coors=[
                    (-1, -0.5),
                    (0.333, -0.5),
                    (-0.333, -0.5),
                    (1, -0.5),
                    (-1, 0),
                    (0.333, 0),
                    (-0.333, 0),
                    (1, 0),
                    (-1, 0.5),
                    (0.333, 0.5),
                    (-0.333, 0.5),
                    (1, 0.5),
                ],
                output_coors=[
                    (0, 1),  # one output
                ],
            ),
            neat=NEAT(
                species=DefaultSpecies(
                    genome=DefaultGenome(
                        num_inputs=4,  # [*coor1, *coor2]
                        num_outputs=1,  # the weight of connection between two coor1 and coor2
                        max_nodes=50,
                        max_conns=100,
                        node_gene=DefaultNodeGene(
                            activation_default=Act.tanh,
                            activation_options=(Act.tanh,),
                        ),
                        output_transform=Act.tanh,  # the activation function for output node in NEAT
                    ),
                    pop_size=1000,
                    species_size=10,
                    compatibility_threshold=2,
                    survival_threshold=0.03,
                ),
            ),
            activation=Act.tanh,
            activate_time=10,
            output_transform=Act.sigmoid,  # the activation function for output node in HyperNEAT
        ),
        problem=XOR3d(),
        generation_limit=300,
        fitness_target=-1e-6,
    )

    # initialize state
    state = pipeline.setup()
    # run until terminate
    state, best = pipeline.auto_run(state)
    # show result
    pipeline.show(state, best)
|
||||
47
examples/func_fit/xor_recurrent.py
Normal file
47
examples/func_fit/xor_recurrent.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
|
||||
|
||||
from problem.func_fit import XOR3d
|
||||
from utils.activation import ACT_ALL, Act
|
||||
|
||||
if __name__ == "__main__":
    # NEAT on the 3-input XOR problem using a RecurrentGenome
    # (activate_time=5) with response-free node genes; the seed is fixed.
    pipeline = Pipeline(
        seed=0,
        algorithm=NEAT(
            species=DefaultSpecies(
                genome=RecurrentGenome(
                    num_inputs=3,
                    num_outputs=1,
                    max_nodes=50,
                    max_conns=100,
                    activate_time=5,
                    node_gene=NodeGeneWithoutResponse(
                        activation_options=ACT_ALL, activation_replace_rate=0.2
                    ),
                    output_transform=Act.sigmoid,
                    mutation=DefaultMutation(
                        node_add=0.05,
                        conn_add=0.2,
                        node_delete=0,
                        conn_delete=0,
                    ),
                ),
                pop_size=10000,
                species_size=10,
                compatibility_threshold=3.5,
                survival_threshold=0.03,
            ),
        ),
        problem=XOR3d(),
        generation_limit=10000,
        fitness_target=-1e-8,
    )

    # initialize state
    state = pipeline.setup()
    # run until terminate
    state, best = pipeline.auto_run(state)
    # show result
    pipeline.show(state, best)
|
||||
36
examples/gymnax/arcbot.py
Normal file
36
examples/gymnax/arcbot.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import GymNaxEnv
|
||||
|
||||
if __name__ == "__main__":
    # NEAT run on the gymnax "Acrobot-v1" environment.
    pipeline = Pipeline(
        algorithm=NEAT(
            species=DefaultSpecies(
                genome=DefaultGenome(
                    num_inputs=6,
                    num_outputs=3,
                    max_nodes=50,
                    max_conns=100,
                    # the action of acrobot is {0, 1, 2}: pick the index of
                    # the largest of the three outputs
                    output_transform=lambda out: jnp.argmax(out),
                ),
                pop_size=10000,
                species_size=10,
            ),
        ),
        problem=GymNaxEnv(
            env_name="Acrobot-v1",
        ),
        generation_limit=10000,
        fitness_target=-62,
    )

    # initialize state
    state = pipeline.setup()
    # run until terminate
    state, best = pipeline.auto_run(state)
|
||||
41
examples/gymnax/cartpole.py
Normal file
41
examples/gymnax/cartpole.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import GymNaxEnv
|
||||
|
||||
|
||||
def action_policy(randkey, forward_func, obs):
    """Return the index of the largest network output for ``obs``.

    Args:
        randkey: Accepted for signature compatibility; not used.
        forward_func: Callable mapping an observation to the network outputs.
        obs: Observation passed straight to ``forward_func``.

    Returns:
        A scalar integer JAX array: ``argmax`` over ``forward_func(obs)``.
    """
    return jnp.argmax(forward_func(obs))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # NEAT run on the gymnax "CartPole-v1" environment.
    pipeline = Pipeline(
        algorithm=NEAT(
            species=DefaultSpecies(
                genome=DefaultGenome(
                    num_inputs=4,
                    num_outputs=2,
                    max_nodes=50,
                    max_conns=100,
                    # No output_transform here: the action of cartpole is
                    # {0, 1}, and the argmax over the two outputs is taken
                    # in action_policy instead.
                ),
                pop_size=10000,
                species_size=10,
            ),
        ),
        problem=GymNaxEnv(
            env_name="CartPole-v1", repeat_times=5, action_policy=action_policy
        ),
        generation_limit=10000,
        fitness_target=500,
    )

    # initialize state
    state = pipeline.setup()
    # run until terminate
    state, best = pipeline.auto_run(state)
|
||||
74
examples/gymnax/cartpole_hyperneat.py
Normal file
74
examples/gymnax/cartpole_hyperneat.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import jax
|
||||
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
from algorithm.hyperneat import *
|
||||
from tensorneat.common import Act
|
||||
|
||||
from problem.rl_env import GymNaxEnv
|
||||
|
||||
if __name__ == "__main__":
    # HyperNEAT on the gymnax "CartPole-v1" environment: an inner NEAT
    # population evolves 4-input / 1-output genomes that are queried with
    # coordinate pairs of the substrate below.
    pipeline = Pipeline(
        algorithm=HyperNEAT(
            substrate=FullSubstrate(
                input_coors=[
                    (-1, -1),
                    (-0.5, -1),
                    (0, -1),
                    (0.5, -1),
                    (1, -1),
                ],  # 4(problem inputs) + 1(bias)
                hidden_coors=[
                    (-1, -0.5),
                    (0.333, -0.5),
                    (-0.333, -0.5),
                    (1, -0.5),
                    (-1, 0),
                    (0.333, 0),
                    (-0.333, 0),
                    (1, 0),
                    (-1, 0.5),
                    (0.333, 0.5),
                    (-0.333, 0.5),
                    (1, 0.5),
                ],
                output_coors=[
                    (-1, 1),
                    (1, 1),
                ],  # two outputs; argmax over them selects the action
            ),
            neat=NEAT(
                species=DefaultSpecies(
                    genome=DefaultGenome(
                        num_inputs=4,  # [*coor1, *coor2]
                        num_outputs=1,  # the weight of connection between two coor1 and coor2
                        max_nodes=50,
                        max_conns=100,
                        node_gene=DefaultNodeGene(
                            activation_default=Act.tanh,
                            activation_options=(Act.tanh,),
                        ),
                        output_transform=Act.tanh,  # the activation function for output node in NEAT
                    ),
                    pop_size=10000,
                    species_size=10,
                    compatibility_threshold=3.5,
                    survival_threshold=0.03,
                ),
            ),
            # NOTE(review): presumably the activation of the non-output
            # substrate nodes (outputs go through output_transform below) —
            # confirm against the HyperNEAT implementation.
            activation=Act.tanh,
            activate_time=10,
            output_transform=jax.numpy.argmax,  # action of cartpole is in {0, 1}
        ),
        problem=GymNaxEnv(
            env_name="CartPole-v1",
        ),
        generation_limit=300,
        fitness_target=500,
    )

    # initialize state
    state = pipeline.setup()
    # run until terminate
    state, best = pipeline.auto_run(state)
|
||||
36
examples/gymnax/mountain_car.py
Normal file
36
examples/gymnax/mountain_car.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import GymNaxEnv
|
||||
|
||||
if __name__ == "__main__":
    # NEAT run on the gymnax "MountainCar-v0" environment.
    pipeline = Pipeline(
        algorithm=NEAT(
            species=DefaultSpecies(
                genome=DefaultGenome(
                    num_inputs=2,
                    num_outputs=3,
                    max_nodes=50,
                    max_conns=100,
                    # the action of mountain car is {0, 1, 2}: pick the index
                    # of the largest of the three outputs
                    output_transform=lambda out: jnp.argmax(out),
                ),
                pop_size=10000,
                species_size=10,
            ),
        ),
        problem=GymNaxEnv(
            env_name="MountainCar-v0",
        ),
        generation_limit=10000,
        fitness_target=-86,
    )

    # initialize state
    state = pipeline.setup()
    # run until terminate
    state, best = pipeline.auto_run(state)
|
||||
37
examples/gymnax/mountain_car_continuous.py
Normal file
37
examples/gymnax/mountain_car_continuous.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import GymNaxEnv
|
||||
from tensorneat.common import Act
|
||||
|
||||
if __name__ == "__main__":
    # NEAT run on the gymnax "MountainCarContinuous-v0" environment with a
    # single tanh-squashed output.
    pipeline = Pipeline(
        algorithm=NEAT(
            species=DefaultSpecies(
                genome=DefaultGenome(
                    num_inputs=2,
                    num_outputs=1,
                    max_nodes=50,
                    max_conns=100,
                    node_gene=DefaultNodeGene(
                        activation_options=(Act.tanh,),
                        activation_default=Act.tanh,
                    ),
                    output_transform=Act.tanh,
                ),
                pop_size=10000,
                species_size=10,
            ),
        ),
        problem=GymNaxEnv(
            env_name="MountainCarContinuous-v0",
        ),
        generation_limit=10000,
        fitness_target=99,
    )

    # initialize state
    state = pipeline.setup()
    # run until terminate
    state, best = pipeline.auto_run(state)
|
||||
38
examples/gymnax/pendulum.py
Normal file
38
examples/gymnax/pendulum.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import GymNaxEnv
|
||||
from tensorneat.common import Act
|
||||
|
||||
if __name__ == "__main__":
    # NEAT run on the gymnax "Pendulum-v1" environment.
    pipeline = Pipeline(
        algorithm=NEAT(
            species=DefaultSpecies(
                genome=DefaultGenome(
                    num_inputs=3,
                    num_outputs=1,
                    max_nodes=50,
                    max_conns=100,
                    node_gene=DefaultNodeGene(
                        activation_options=(Act.tanh,),
                        activation_default=Act.tanh,
                    ),
                    # the action of pendulum is [-2, 2], so the tanh output
                    # is scaled by 2
                    output_transform=lambda out: Act.tanh(out) * 2,
                ),
                pop_size=10000,
                species_size=10,
            ),
        ),
        problem=GymNaxEnv(
            env_name="Pendulum-v1",
        ),
        generation_limit=10000,
        fitness_target=-10,
    )

    # initialize state
    state = pipeline.setup()
    # run until terminate
    state, best = pipeline.auto_run(state)
|
||||
33
examples/gymnax/reacher.py
Normal file
33
examples/gymnax/reacher.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import GymNaxEnv
|
||||
|
||||
if __name__ == "__main__":
    # NEAT run on the gymnax "Reacher-misc" environment with default node
    # genes and no explicit output transform.
    pipeline = Pipeline(
        algorithm=NEAT(
            species=DefaultSpecies(
                genome=DefaultGenome(
                    num_inputs=8,
                    num_outputs=2,
                    max_nodes=50,
                    max_conns=100,
                ),
                pop_size=10000,
                species_size=10,
            ),
        ),
        problem=GymNaxEnv(
            env_name="Reacher-misc",
        ),
        generation_limit=10000,
        fitness_target=90,
    )

    # initialize state
    state = pipeline.setup()
    # run until terminate
    state, best = pipeline.auto_run(state)
|
||||
488
examples/interpret_visualize/genome_sympy.ipynb
Normal file
488
examples/interpret_visualize/genome_sympy.ipynb
Normal file
File diff suppressed because one or more lines are too long
37
examples/interpret_visualize/genome_sympy.py
Normal file
37
examples/interpret_visualize/genome_sympy.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from algorithm.neat import *
|
||||
from algorithm.neat.genome.dense import DenseInitialize
|
||||
from utils.graph import topological_sort_python
|
||||
from tensorneat.common import *
|
||||
|
||||
if __name__ == "__main__":
    # Initialize one DenseInitialize genome, convert it to a network dict,
    # and print the symbolic (sympy) form of its output plus one evaluation.
    genome = DenseInitialize(
        num_inputs=3,
        num_outputs=1,
        max_nodes=50,
        max_conns=500,
    )

    state = genome.setup()

    randkey = jax.random.PRNGKey(42)
    nodes, conns = genome.initialize(state, randkey)

    network = genome.network_dict(state, nodes, conns)

    # Not used further in this script.
    input_idx, output_idx = genome.get_input_idx(), genome.get_output_idx()

    res = genome.sympy_func(
        state,
        network,
        sympy_input_transform=lambda x: 999999999 * x,
        sympy_output_transform=SympyStandardSigmoid,
    )
    (
        symbols,
        args_symbols,
        input_symbols,
        nodes_exprs,
        output_exprs,
        forward_func,
    ) = res

    print(symbols)
    # Substitute the concrete argument values into the first output expression.
    print(output_exprs[0].subs(args_symbols))

    inputs = jnp.zeros(3)
    print(forward_func(inputs))
|
||||
2455
examples/interpret_visualize/graph.svg
Normal file
2455
examples/interpret_visualize/graph.svg
Normal file
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 90 KiB |
191
examples/interpret_visualize/network.json
Normal file
191
examples/interpret_visualize/network.json
Normal file
@@ -0,0 +1,191 @@
|
||||
{
|
||||
"nodes": {
|
||||
"0": {
|
||||
"bias": 0.13710324466228485,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"1": {
|
||||
"bias": -1.4202250242233276,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"2": {
|
||||
"bias": -0.4653860926628113,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"3": {
|
||||
"bias": 0.5835710167884827,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"4": {
|
||||
"bias": 2.187405824661255,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"5": {
|
||||
"bias": 0.24963024258613586,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"6": {
|
||||
"bias": -0.966821551322937,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"7": {
|
||||
"bias": 0.4452081620693207,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"8": {
|
||||
"bias": -0.07293166220188141,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"9": {
|
||||
"bias": -0.1625899225473404,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"10": {
|
||||
"bias": -0.8576332330703735,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"11": {
|
||||
"bias": -0.18487468361854553,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"12": {
|
||||
"bias": 1.4335486888885498,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"13": {
|
||||
"bias": -0.8690621256828308,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"14": {
|
||||
"bias": -0.23014676570892334,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"15": {
|
||||
"bias": 0.7880322337150574,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"16": {
|
||||
"bias": -0.22258250415325165,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"17": {
|
||||
"bias": 0.2773352861404419,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"18": {
|
||||
"bias": -0.40279051661491394,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"19": {
|
||||
"bias": 1.092000961303711,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"20": {
|
||||
"bias": -0.4063087999820709,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"21": {
|
||||
"bias": 0.3895529806613922,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"22": {
|
||||
"bias": -0.18007506430149078,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"23": {
|
||||
"bias": -0.8112533092498779,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"24": {
|
||||
"bias": 0.2946726381778717,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"25": {
|
||||
"bias": -1.118497371673584,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"26": {
|
||||
"bias": 1.3674490451812744,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"27": {
|
||||
"bias": -1.6514816284179688,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"28": {
|
||||
"bias": 0.9440701603889465,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"29": {
|
||||
"bias": 1.564852237701416,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
},
|
||||
"30": {
|
||||
"bias": -0.5568665266036987,
|
||||
"res": 1.0,
|
||||
"agg": "sum",
|
||||
"act": "sigmoid"
|
||||
}
|
||||
},
|
||||
"conns": {
|
||||
|
||||
2455
examples/interpret_visualize/network.svg
Normal file
2455
examples/interpret_visualize/network.svg
Normal file
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 89 KiB |
103
examples/interpret_visualize/visualize_genome.ipynb
Normal file
103
examples/interpret_visualize/visualize_genome.ipynb
Normal file
File diff suppressed because one or more lines are too long
13
examples/interpret_visualize/visualize_genome.py
Normal file
13
examples/interpret_visualize/visualize_genome.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Create an empty directed graph
G = nx.DiGraph()

# Add edges
G.add_edge('A', 'B')
G.add_edge('A', 'C')
G.add_edge('B', 'C')
G.add_edge('C', 'D')

# Draw the directed graph (no drawing code is visible after this point)
|
||||
25
examples/jumanji/2048_random_policy.py
Normal file
25
examples/jumanji/2048_random_policy.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import jax, jax.numpy as jnp
|
||||
import jax.random
|
||||
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
|
||||
|
||||
def random_policy(state, params, obs):
    """Return 4 normally distributed action scores seeded by the observation.

    The PRNG key is derived from ``obs.sum()``, so equal observation sums
    always produce the same scores.

    Args:
        state: Accepted for signature compatibility; not used.
        params: Accepted for signature compatibility; not used.
        obs: Observation array. NOTE(review): ``jax.random.key`` expects an
            integer seed, so this assumes an integer-typed ``obs`` — confirm
            against the environment.

    Returns:
        A JAX array of shape (4,) drawn from a standard normal distribution.
    """
    key = jax.random.key(obs.sum())
    actions = jax.random.normal(key, (4,))
    return actions
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Evaluate random_policy on the Jumanji 2048 problem under jit and print
    # the resulting reward.
    problem = Jumanji_2048(
        max_step=10000, repeat_times=1000, guarantee_invalid_action=False
    )
    state = problem.setup()
    jit_evaluate = jax.jit(
        lambda state, randkey: problem.evaluate(state, randkey, random_policy, None)
    )
    randkey = jax.random.PRNGKey(0)
    reward = jit_evaluate(state, randkey)
    print(reward)
|
||||
1874
examples/jumanji/2048_test.ipynb
Normal file
1874
examples/jumanji/2048_test.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
120
examples/jumanji/train_2048.py
Normal file
120
examples/jumanji/train_2048.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
|
||||
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
|
||||
from tensorneat.common import Act, Agg
|
||||
|
||||
|
||||
def rot_li(li):
    """Return a copy of ``li`` rotated left by one position.

    The first element moves to the end; the input is not modified. An empty
    sequence is returned as an empty sequence instead of raising
    ``IndexError``, and slicing keeps the input's own sequence type.

    Args:
        li: A list (or other sliceable, concatenable sequence).

    Returns:
        A new sequence ``li[1:] + li[:1]``.
    """
    return li[1:] + li[:1]
|
||||
|
||||
|
||||
def rot_boards(board):
    """Return the four successive 90-degree rotations of ``board``, stacked.

    Entry ``k`` of the result is ``board`` after ``k + 1`` applications of
    ``jnp.rot90``; the last entry therefore equals the input. ``board`` must
    be square so that the scan carry keeps a constant shape.

    Args:
        board: 2-D square JAX array.

    Returns:
        Array of shape ``(4, *board.shape)``.
    """

    def rot(a, _):
        a = jnp.rot90(a)
        return a, a  # carry, y

    # scan stacks the per-step outputs along a new leading axis
    _, boards = jax.lax.scan(rot, board, jnp.arange(4, dtype=jnp.int32))
    return boards
|
||||
|
||||
|
||||
# Base direction orderings for the plain board and for the left-right
# flipped board.
direction = ["up", "right", "down", "left"]
lr_flip_direction = ["up", "left", "down", "right"]

# Entry k of each table is the base ordering rotated left (k + 1) times, so
# the last entry equals the base ordering again.
directions = []
lr_flip_directions = []
for _ in range(4):
    direction = rot_li(direction)
    lr_flip_direction = rot_li(lr_flip_direction)
    directions.append(direction.copy())
    lr_flip_directions.append(lr_flip_direction.copy())

# 8 label rows: 4 for the rotated boards, then 4 for the rotated flipped boards.
full_directions = directions + lr_flip_directions
|
||||
|
||||
|
||||
def action_policy(forward_func, obs):
    """Score the four moves by summing network outputs over 8 board views.

    The observation is reshaped to a 4x4 board. The board and its left-right
    flip are each rotated four times (``rot_boards``), the network is run on
    all 8 flattened views, and output ``j`` of view ``i`` is added to the
    total of the direction named by ``full_directions[i][j]``.

    Args:
        forward_func: Callable mapping a flattened 16-element board to 4 scores.
        obs: Observation with 16 elements.

    Returns:
        JAX array of 4 summed scores ordered up, right, down, left.
    """
    board = obs.reshape(4, 4)
    lr_flip_board = jnp.fliplr(board)

    boards = rot_boards(board)
    lr_flip_boards = rot_boards(lr_flip_board)
    # Stack rotated boards first, then rotated flipped boards, matching the
    # row order of full_directions.
    full_boards = jnp.concatenate([boards, lr_flip_boards], axis=0)
    scores = jax.vmap(forward_func)(full_boards.reshape(8, -1))

    total_score = {"up": 0, "right": 0, "down": 0, "left": 0}
    for i, dire in enumerate(full_directions):
        for j, name in enumerate(dire):
            total_score[name] += scores[i, j]

    return jnp.array(
        [
            total_score["up"],
            total_score["right"],
            total_score["down"],
            total_score["left"],
        ]
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # NEAT run on the Jumanji 2048 problem using the symmetry-aware
    # action_policy; save_path="2048.npz" is passed to the pipeline.
    pipeline = Pipeline(
        algorithm=NEAT(
            species=DefaultSpecies(
                genome=DefaultGenome(
                    num_inputs=16,
                    num_outputs=4,
                    max_nodes=100,
                    max_conns=1000,
                    node_gene=NodeGeneWithoutResponse(
                        activation_default=Act.sigmoid,
                        activation_options=(
                            Act.sigmoid,
                            Act.relu,
                            Act.tanh,
                            Act.identity,
                        ),
                        aggregation_default=Agg.sum,
                        aggregation_options=(Agg.sum,),
                        activation_replace_rate=0.02,
                        aggregation_replace_rate=0.02,
                        bias_mutate_rate=0.03,
                        bias_init_std=0.5,
                        bias_mutate_power=0.02,
                        bias_replace_rate=0.01,
                    ),
                    conn_gene=DefaultConnGene(
                        weight_mutate_rate=0.015,
                        weight_replace_rate=0.03,
                        weight_mutate_power=0.05,
                    ),
                    mutation=DefaultMutation(node_add=0.001, conn_add=0.002),
                ),
                pop_size=1000,
                species_size=5,
                survival_threshold=0.01,
                max_stagnation=7,
                genome_elitism=3,
                compatibility_threshold=1.2,
            ),
        ),
        problem=Jumanji_2048(
            max_step=1000,
            repeat_times=50,
            # alternative: guarantee_invalid_action=True,
            guarantee_invalid_action=False,
            action_policy=action_policy,
        ),
        generation_limit=10000,
        fitness_target=13000,
        save_path="2048.npz",
    )

    # initialize state
    state = pipeline.setup()
    # run until terminate
    state, best = pipeline.auto_run(state)
|
||||
10
examples/tmp.py
Normal file
10
examples/tmp.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from tensorneat.algorithm import NEAT
|
||||
from tensorneat.algorithm.neat import DefaultGenome
|
||||
|
||||
# Scratch script: build one DefaultGenome with a single 1-node hidden layer,
# initialize it, and print its textual representation.
key = jax.random.key(0)
genome = DefaultGenome(num_inputs=5, num_outputs=3, init_hidden_layers=(1,))
state = genome.setup()
nodes, conns = genome.initialize(state, key)
print(genome.repr(state, nodes, conns))
|
||||
6
examples/with_evox/ray_test.py
Normal file
6
examples/with_evox/ray_test.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import ray
|
||||
|
||||
# Start Ray requesting 2 GPUs, then print the resources it reports as available.
ray.init(num_gpus=2)

available_resources = ray.available_resources()
print("Available resources:", available_resources)
|
||||
Reference in New Issue
Block a user