update some examples
This commit is contained in:
@@ -1,39 +0,0 @@
|
|||||||
from pipeline import Pipeline
|
|
||||||
from algorithm.neat import *
|
|
||||||
|
|
||||||
from problem.rl_env import BraxEnv
|
|
||||||
from tensorneat.common import Act
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pipeline = Pipeline(
|
|
||||||
algorithm=NEAT(
|
|
||||||
species=DefaultSpecies(
|
|
||||||
genome=DefaultGenome(
|
|
||||||
num_inputs=27,
|
|
||||||
num_outputs=8,
|
|
||||||
max_nodes=100,
|
|
||||||
max_conns=200,
|
|
||||||
node_gene=DefaultNodeGene(
|
|
||||||
activation_options=(Act.tanh,),
|
|
||||||
activation_default=Act.tanh,
|
|
||||||
),
|
|
||||||
output_transform=Act.tanh,
|
|
||||||
),
|
|
||||||
pop_size=1000,
|
|
||||||
species_size=10,
|
|
||||||
compatibility_threshold=3.5,
|
|
||||||
survival_threshold=0.01,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
problem=BraxEnv(
|
|
||||||
env_name="ant",
|
|
||||||
),
|
|
||||||
generation_limit=10000,
|
|
||||||
fitness_target=5000,
|
|
||||||
)
|
|
||||||
|
|
||||||
# initialize state
|
|
||||||
state = pipeline.setup()
|
|
||||||
# print(state)
|
|
||||||
# run until terminate
|
|
||||||
state, best = pipeline.auto_run(state)
|
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
import jax
|
|
||||||
|
|
||||||
from pipeline import Pipeline
|
|
||||||
from algorithm.neat import *
|
|
||||||
|
|
||||||
from problem.rl_env import BraxEnv
|
|
||||||
from tensorneat.common import Act
|
|
||||||
|
|
||||||
|
|
||||||
def sample_policy(randkey, obs):
|
|
||||||
return jax.random.uniform(randkey, (6,), minval=-1, maxval=1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pipeline = Pipeline(
|
|
||||||
algorithm=NEAT(
|
|
||||||
species=DefaultSpecies(
|
|
||||||
genome=DefaultGenome(
|
|
||||||
num_inputs=17,
|
|
||||||
num_outputs=6,
|
|
||||||
max_nodes=50,
|
|
||||||
max_conns=100,
|
|
||||||
node_gene=DefaultNodeGene(
|
|
||||||
activation_options=(Act.tanh,),
|
|
||||||
activation_default=Act.tanh,
|
|
||||||
),
|
|
||||||
output_transform=Act.tanh,
|
|
||||||
),
|
|
||||||
pop_size=1000,
|
|
||||||
species_size=10,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
problem=BraxEnv(
|
|
||||||
env_name="halfcheetah",
|
|
||||||
max_step=1000,
|
|
||||||
obs_normalization=True,
|
|
||||||
sample_episodes=1000,
|
|
||||||
sample_policy=sample_policy,
|
|
||||||
),
|
|
||||||
generation_limit=10000,
|
|
||||||
fitness_target=5000,
|
|
||||||
)
|
|
||||||
|
|
||||||
# initialize state
|
|
||||||
state = pipeline.setup()
|
|
||||||
# print(state)
|
|
||||||
# run until terminate
|
|
||||||
state, best = pipeline.auto_run(state)
|
|
||||||
51
examples/brax/halfcheetah.py
Normal file
51
examples/brax/halfcheetah.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
from tensorneat.pipeline import Pipeline
|
||||||
|
from tensorneat.algorithm.neat import NEAT
|
||||||
|
from tensorneat.genome import DefaultGenome, BiasNode
|
||||||
|
|
||||||
|
from tensorneat.problem.rl import BraxEnv
|
||||||
|
from tensorneat.common import Act, Agg
|
||||||
|
|
||||||
|
import jax
|
||||||
|
|
||||||
|
|
||||||
|
def random_sample_policy(randkey, obs):
|
||||||
|
return jax.random.uniform(randkey, (6,), minval=-1.0, maxval=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pipeline = Pipeline(
|
||||||
|
algorithm=NEAT(
|
||||||
|
pop_size=1000,
|
||||||
|
species_size=20,
|
||||||
|
survival_threshold=0.1,
|
||||||
|
compatibility_threshold=1.0,
|
||||||
|
genome=DefaultGenome(
|
||||||
|
max_nodes=100,
|
||||||
|
max_conns=200,
|
||||||
|
num_inputs=17,
|
||||||
|
num_outputs=6,
|
||||||
|
init_hidden_layers=(),
|
||||||
|
node_gene=BiasNode(
|
||||||
|
activation_options=Act.tanh,
|
||||||
|
aggregation_options=Agg.sum,
|
||||||
|
),
|
||||||
|
output_transform=Act.standard_tanh,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
problem=BraxEnv(
|
||||||
|
env_name="halfcheetah",
|
||||||
|
max_step=1000,
|
||||||
|
obs_normalization=True,
|
||||||
|
sample_episodes=1000,
|
||||||
|
sample_policy=random_sample_policy,
|
||||||
|
),
|
||||||
|
seed=42,
|
||||||
|
generation_limit=100,
|
||||||
|
fitness_target=8000,
|
||||||
|
)
|
||||||
|
|
||||||
|
# initialize state
|
||||||
|
state = pipeline.setup()
|
||||||
|
# print(state)
|
||||||
|
# run until terminate
|
||||||
|
state, best = pipeline.auto_run(state)
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
from pipeline import Pipeline
|
|
||||||
from algorithm.neat import *
|
|
||||||
|
|
||||||
from problem.rl_env import BraxEnv
|
|
||||||
from tensorneat.common import Act
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pipeline = Pipeline(
|
|
||||||
algorithm=NEAT(
|
|
||||||
species=DefaultSpecies(
|
|
||||||
genome=DefaultGenome(
|
|
||||||
num_inputs=11,
|
|
||||||
num_outputs=2,
|
|
||||||
max_nodes=50,
|
|
||||||
max_conns=100,
|
|
||||||
node_gene=DefaultNodeGene(
|
|
||||||
activation_options=(Act.tanh,),
|
|
||||||
activation_default=Act.tanh,
|
|
||||||
),
|
|
||||||
output_transform=Act.tanh,
|
|
||||||
),
|
|
||||||
pop_size=100,
|
|
||||||
species_size=10,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
problem=BraxEnv(
|
|
||||||
env_name="reacher",
|
|
||||||
),
|
|
||||||
generation_limit=10000,
|
|
||||||
fitness_target=5000,
|
|
||||||
)
|
|
||||||
|
|
||||||
# initialize state
|
|
||||||
state = pipeline.setup()
|
|
||||||
# print(state)
|
|
||||||
# run until terminate
|
|
||||||
state, best = pipeline.auto_run(state)
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
import jax
|
|
||||||
from problem.rl_env import BraxEnv
|
|
||||||
|
|
||||||
|
|
||||||
def random_policy(randkey, forward_func, obs):
|
|
||||||
return jax.random.uniform(randkey, (6,), minval=-1, maxval=1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
problem = BraxEnv(env_name="walker2d", max_step=1000, action_policy=random_policy)
|
|
||||||
state = problem.setup()
|
|
||||||
randkey = jax.random.key(0)
|
|
||||||
problem.show(
|
|
||||||
state,
|
|
||||||
randkey,
|
|
||||||
act_func=lambda state, params, obs: obs,
|
|
||||||
params=None,
|
|
||||||
save_path="walker2d_random_policy",
|
|
||||||
)
|
|
||||||
@@ -9,7 +9,7 @@ import jax, jax.numpy as jnp
|
|||||||
|
|
||||||
|
|
||||||
def random_sample_policy(randkey, obs):
|
def random_sample_policy(randkey, obs):
|
||||||
return jax.random.uniform(randkey, (6,))
|
return jax.random.uniform(randkey, (6,), minval=-1.0, maxval=1.0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,36 +1,45 @@
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from pipeline import Pipeline
|
from tensorneat.pipeline import Pipeline
|
||||||
from algorithm.neat import *
|
from tensorneat.algorithm.neat import NEAT
|
||||||
|
from tensorneat.genome import DefaultGenome, BiasNode
|
||||||
|
|
||||||
|
from tensorneat.problem.rl import GymNaxEnv
|
||||||
|
from tensorneat.common import Act, Agg
|
||||||
|
|
||||||
|
|
||||||
from problem.rl_env import GymNaxEnv
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# the network has 3 outputs, the max one will be the action
|
||||||
|
# as the action of acrobot is {0, 1, 2}
|
||||||
|
|
||||||
pipeline = Pipeline(
|
pipeline = Pipeline(
|
||||||
algorithm=NEAT(
|
algorithm=NEAT(
|
||||||
species=DefaultSpecies(
|
pop_size=1000,
|
||||||
genome=DefaultGenome(
|
species_size=20,
|
||||||
num_inputs=6,
|
survival_threshold=0.1,
|
||||||
num_outputs=3,
|
compatibility_threshold=1.0,
|
||||||
max_nodes=50,
|
genome=DefaultGenome(
|
||||||
max_conns=100,
|
num_inputs=6,
|
||||||
output_transform=lambda out: jnp.argmax(
|
num_outputs=3,
|
||||||
out
|
init_hidden_layers=(),
|
||||||
), # the action of acrobot is {0, 1, 2}
|
node_gene=BiasNode(
|
||||||
|
activation_options=Act.tanh,
|
||||||
|
aggregation_options=Agg.sum,
|
||||||
),
|
),
|
||||||
pop_size=10000,
|
output_transform=jnp.argmax,
|
||||||
species_size=10,
|
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
problem=GymNaxEnv(
|
problem=GymNaxEnv(
|
||||||
env_name="Acrobot-v1",
|
env_name="Acrobot-v1",
|
||||||
),
|
),
|
||||||
generation_limit=10000,
|
seed=42,
|
||||||
fitness_target=-62,
|
generation_limit=100,
|
||||||
|
fitness_target=-60,
|
||||||
)
|
)
|
||||||
|
|
||||||
# initialize state
|
# initialize state
|
||||||
state = pipeline.setup()
|
state = pipeline.setup()
|
||||||
# print(state)
|
|
||||||
# run until terminate
|
# run until terminate
|
||||||
state, best = pipeline.auto_run(state)
|
state, best = pipeline.auto_run(state)
|
||||||
|
|||||||
@@ -1,41 +1,46 @@
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from pipeline import Pipeline
|
from tensorneat.pipeline import Pipeline
|
||||||
from algorithm.neat import *
|
from tensorneat.algorithm.neat import NEAT
|
||||||
|
from tensorneat.genome import DefaultGenome, BiasNode
|
||||||
|
|
||||||
from problem.rl_env import GymNaxEnv
|
from tensorneat.problem.rl import GymNaxEnv
|
||||||
|
from tensorneat.common import Act, Agg
|
||||||
|
|
||||||
|
|
||||||
def action_policy(randkey, forward_func, obs):
|
|
||||||
return jnp.argmax(forward_func(obs))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# the network has 2 outputs, the max one will be the action
|
||||||
|
# as the action of cartpole is {0, 1}
|
||||||
|
|
||||||
pipeline = Pipeline(
|
pipeline = Pipeline(
|
||||||
algorithm=NEAT(
|
algorithm=NEAT(
|
||||||
species=DefaultSpecies(
|
pop_size=1000,
|
||||||
genome=DefaultGenome(
|
species_size=20,
|
||||||
num_inputs=4,
|
survival_threshold=0.1,
|
||||||
num_outputs=2,
|
compatibility_threshold=1.0,
|
||||||
max_nodes=50,
|
genome=DefaultGenome(
|
||||||
max_conns=100,
|
num_inputs=4,
|
||||||
# output_transform=lambda out: jnp.argmax(
|
num_outputs=2,
|
||||||
# out
|
init_hidden_layers=(),
|
||||||
# ), # the action of cartpole is {0, 1}
|
node_gene=BiasNode(
|
||||||
|
activation_options=Act.tanh,
|
||||||
|
aggregation_options=Agg.sum,
|
||||||
),
|
),
|
||||||
pop_size=10000,
|
output_transform=jnp.argmax,
|
||||||
species_size=10,
|
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
problem=GymNaxEnv(
|
problem=GymNaxEnv(
|
||||||
env_name="CartPole-v1", repeat_times=5, action_policy=action_policy
|
env_name="CartPole-v1",
|
||||||
|
repeat_times=5,
|
||||||
),
|
),
|
||||||
generation_limit=10000,
|
seed=42,
|
||||||
|
generation_limit=100,
|
||||||
fitness_target=500,
|
fitness_target=500,
|
||||||
)
|
)
|
||||||
|
|
||||||
# initialize state
|
# initialize state
|
||||||
state = pipeline.setup()
|
state = pipeline.setup()
|
||||||
# print(state)
|
|
||||||
# run until terminate
|
# run until terminate
|
||||||
state, best = pipeline.auto_run(state)
|
state, best = pipeline.auto_run(state)
|
||||||
|
|||||||
@@ -1,70 +1,45 @@
|
|||||||
import jax
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from pipeline import Pipeline
|
from tensorneat.pipeline import Pipeline
|
||||||
from algorithm.neat import *
|
from tensorneat.algorithm.neat import NEAT
|
||||||
from algorithm.hyperneat import *
|
from tensorneat.algorithm.hyperneat import HyperNEAT, FullSubstrate
|
||||||
|
from tensorneat.genome import DefaultGenome
|
||||||
from tensorneat.common import Act
|
from tensorneat.common import Act
|
||||||
|
|
||||||
from problem.rl_env import GymNaxEnv
|
from tensorneat.problem import GymNaxEnv
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
# the num of input_coors is 5
|
||||||
|
# 4 is for cartpole inputs, 1 is for bias
|
||||||
pipeline = Pipeline(
|
pipeline = Pipeline(
|
||||||
algorithm=HyperNEAT(
|
algorithm=HyperNEAT(
|
||||||
substrate=FullSubstrate(
|
substrate=FullSubstrate(
|
||||||
input_coors=[
|
input_coors=((-1, -1), (-0.5, -1), (0, -1), (0.5, -1), (1, -1)),
|
||||||
(-1, -1),
|
hidden_coors=((-1, 0), (0, 0), (1, 0)),
|
||||||
(-0.5, -1),
|
output_coors=((-1, 1), (1, 1)),
|
||||||
(0, -1),
|
|
||||||
(0.5, -1),
|
|
||||||
(1, -1),
|
|
||||||
], # 4(problem inputs) + 1(bias)
|
|
||||||
hidden_coors=[
|
|
||||||
(-1, -0.5),
|
|
||||||
(0.333, -0.5),
|
|
||||||
(-0.333, -0.5),
|
|
||||||
(1, -0.5),
|
|
||||||
(-1, 0),
|
|
||||||
(0.333, 0),
|
|
||||||
(-0.333, 0),
|
|
||||||
(1, 0),
|
|
||||||
(-1, 0.5),
|
|
||||||
(0.333, 0.5),
|
|
||||||
(-0.333, 0.5),
|
|
||||||
(1, 0.5),
|
|
||||||
],
|
|
||||||
output_coors=[
|
|
||||||
(-1, 1),
|
|
||||||
(1, 1), # one output
|
|
||||||
],
|
|
||||||
),
|
),
|
||||||
neat=NEAT(
|
neat=NEAT(
|
||||||
species=DefaultSpecies(
|
pop_size=10000,
|
||||||
genome=DefaultGenome(
|
species_size=20,
|
||||||
num_inputs=4, # [*coor1, *coor2]
|
survival_threshold=0.01,
|
||||||
num_outputs=1, # the weight of connection between two coor1 and coor2
|
genome=DefaultGenome(
|
||||||
max_nodes=50,
|
num_inputs=4, # size of query coors
|
||||||
max_conns=100,
|
num_outputs=1,
|
||||||
node_gene=DefaultNodeGene(
|
init_hidden_layers=(),
|
||||||
activation_default=Act.tanh,
|
output_transform=Act.standard_tanh,
|
||||||
activation_options=(Act.tanh,),
|
|
||||||
),
|
|
||||||
output_transform=Act.tanh, # the activation function for output node in NEAT
|
|
||||||
),
|
|
||||||
pop_size=10000,
|
|
||||||
species_size=10,
|
|
||||||
compatibility_threshold=3.5,
|
|
||||||
survival_threshold=0.03,
|
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
activation=Act.tanh, # the activation function for output node in HyperNEAT
|
activation=Act.tanh,
|
||||||
activate_time=10,
|
activate_time=10,
|
||||||
output_transform=jax.numpy.argmax, # action of cartpole is in {0, 1}
|
output_transform=jnp.argmax,
|
||||||
),
|
),
|
||||||
problem=GymNaxEnv(
|
problem=GymNaxEnv(
|
||||||
env_name="CartPole-v1",
|
env_name="CartPole-v1",
|
||||||
|
repeat_times=5,
|
||||||
),
|
),
|
||||||
generation_limit=300,
|
generation_limit=300,
|
||||||
fitness_target=500,
|
fitness_target=-1e-6,
|
||||||
)
|
)
|
||||||
|
|
||||||
# initialize state
|
# initialize state
|
||||||
|
|||||||
@@ -1,36 +0,0 @@
|
|||||||
import jax.numpy as jnp
|
|
||||||
|
|
||||||
from pipeline import Pipeline
|
|
||||||
from algorithm.neat import *
|
|
||||||
|
|
||||||
from problem.rl_env import GymNaxEnv
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pipeline = Pipeline(
|
|
||||||
algorithm=NEAT(
|
|
||||||
species=DefaultSpecies(
|
|
||||||
genome=DefaultGenome(
|
|
||||||
num_inputs=2,
|
|
||||||
num_outputs=3,
|
|
||||||
max_nodes=50,
|
|
||||||
max_conns=100,
|
|
||||||
output_transform=lambda out: jnp.argmax(
|
|
||||||
out
|
|
||||||
), # the action of mountain car is {0, 1, 2}
|
|
||||||
),
|
|
||||||
pop_size=10000,
|
|
||||||
species_size=10,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
problem=GymNaxEnv(
|
|
||||||
env_name="MountainCar-v0",
|
|
||||||
),
|
|
||||||
generation_limit=10000,
|
|
||||||
fitness_target=-86,
|
|
||||||
)
|
|
||||||
|
|
||||||
# initialize state
|
|
||||||
state = pipeline.setup()
|
|
||||||
# print(state)
|
|
||||||
# run until terminate
|
|
||||||
state, best = pipeline.auto_run(state)
|
|
||||||
@@ -1,37 +1,43 @@
|
|||||||
from pipeline import Pipeline
|
import jax.numpy as jnp
|
||||||
from algorithm.neat import *
|
|
||||||
|
from tensorneat.pipeline import Pipeline
|
||||||
|
from tensorneat.algorithm.neat import NEAT
|
||||||
|
from tensorneat.genome import DefaultGenome, BiasNode
|
||||||
|
|
||||||
|
from tensorneat.problem.rl import GymNaxEnv
|
||||||
|
from tensorneat.common import Act, Agg
|
||||||
|
|
||||||
|
|
||||||
from problem.rl_env import GymNaxEnv
|
|
||||||
from tensorneat.common import Act
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pipeline = Pipeline(
|
pipeline = Pipeline(
|
||||||
algorithm=NEAT(
|
algorithm=NEAT(
|
||||||
species=DefaultSpecies(
|
pop_size=1000,
|
||||||
genome=DefaultGenome(
|
species_size=20,
|
||||||
num_inputs=2,
|
survival_threshold=0.1,
|
||||||
num_outputs=1,
|
compatibility_threshold=1.0,
|
||||||
max_nodes=50,
|
genome=DefaultGenome(
|
||||||
max_conns=100,
|
num_inputs=2,
|
||||||
node_gene=DefaultNodeGene(
|
num_outputs=1,
|
||||||
activation_options=(Act.tanh,),
|
init_hidden_layers=(),
|
||||||
activation_default=Act.tanh,
|
node_gene=BiasNode(
|
||||||
),
|
activation_options=Act.tanh,
|
||||||
output_transform=Act.tanh
|
aggregation_options=Agg.sum,
|
||||||
),
|
),
|
||||||
pop_size=10000,
|
output_transform=Act.standard_tanh,
|
||||||
species_size=10,
|
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
problem=GymNaxEnv(
|
problem=GymNaxEnv(
|
||||||
env_name="MountainCarContinuous-v0",
|
env_name="MountainCarContinuous-v0",
|
||||||
|
repeat_times=5,
|
||||||
),
|
),
|
||||||
generation_limit=10000,
|
seed=42,
|
||||||
|
generation_limit=100,
|
||||||
fitness_target=99,
|
fitness_target=99,
|
||||||
)
|
)
|
||||||
|
|
||||||
# initialize state
|
# initialize state
|
||||||
state = pipeline.setup()
|
state = pipeline.setup()
|
||||||
# print(state)
|
|
||||||
# run until terminate
|
# run until terminate
|
||||||
state, best = pipeline.auto_run(state)
|
state, best = pipeline.auto_run(state)
|
||||||
|
|||||||
@@ -1,38 +0,0 @@
|
|||||||
from pipeline import Pipeline
|
|
||||||
from algorithm.neat import *
|
|
||||||
|
|
||||||
from problem.rl_env import GymNaxEnv
|
|
||||||
from tensorneat.common import Act
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pipeline = Pipeline(
|
|
||||||
algorithm=NEAT(
|
|
||||||
species=DefaultSpecies(
|
|
||||||
genome=DefaultGenome(
|
|
||||||
num_inputs=3,
|
|
||||||
num_outputs=1,
|
|
||||||
max_nodes=50,
|
|
||||||
max_conns=100,
|
|
||||||
node_gene=DefaultNodeGene(
|
|
||||||
activation_options=(Act.tanh,),
|
|
||||||
activation_default=Act.tanh,
|
|
||||||
),
|
|
||||||
output_transform=lambda out: Act.tanh(out)
|
|
||||||
* 2, # the action of pendulum is [-2, 2]
|
|
||||||
),
|
|
||||||
pop_size=10000,
|
|
||||||
species_size=10,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
problem=GymNaxEnv(
|
|
||||||
env_name="Pendulum-v1",
|
|
||||||
),
|
|
||||||
generation_limit=10000,
|
|
||||||
fitness_target=-10,
|
|
||||||
)
|
|
||||||
|
|
||||||
# initialize state
|
|
||||||
state = pipeline.setup()
|
|
||||||
# print(state)
|
|
||||||
# run until terminate
|
|
||||||
state, best = pipeline.auto_run(state)
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
import jax.numpy as jnp
|
|
||||||
|
|
||||||
from pipeline import Pipeline
|
|
||||||
from algorithm.neat import *
|
|
||||||
|
|
||||||
from problem.rl_env import GymNaxEnv
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pipeline = Pipeline(
|
|
||||||
algorithm=NEAT(
|
|
||||||
species=DefaultSpecies(
|
|
||||||
genome=DefaultGenome(
|
|
||||||
num_inputs=8,
|
|
||||||
num_outputs=2,
|
|
||||||
max_nodes=50,
|
|
||||||
max_conns=100,
|
|
||||||
),
|
|
||||||
pop_size=10000,
|
|
||||||
species_size=10,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
problem=GymNaxEnv(
|
|
||||||
env_name="Reacher-misc",
|
|
||||||
),
|
|
||||||
generation_limit=10000,
|
|
||||||
fitness_target=90,
|
|
||||||
)
|
|
||||||
|
|
||||||
# initialize state
|
|
||||||
state = pipeline.setup()
|
|
||||||
# print(state)
|
|
||||||
# run until terminate
|
|
||||||
state, best = pipeline.auto_run(state)
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
import jax, jax.numpy as jnp
|
|
||||||
import jax.random
|
|
||||||
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
|
|
||||||
|
|
||||||
def random_policy(state, params, obs):
|
|
||||||
key = jax.random.key(obs.sum())
|
|
||||||
actions = jax.random.normal(key, (4,))
|
|
||||||
# actions = actions.at[2:].set(-9999)
|
|
||||||
# return jnp.array([4, 4, 0, 1])
|
|
||||||
# return jnp.array([1, 2, 3, 4])
|
|
||||||
# return actions
|
|
||||||
return actions
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
problem = Jumanji_2048(
|
|
||||||
max_step=10000, repeat_times=1000, guarantee_invalid_action=False
|
|
||||||
)
|
|
||||||
state = problem.setup()
|
|
||||||
jit_evaluate = jax.jit(
|
|
||||||
lambda state, randkey: problem.evaluate(state, randkey, random_policy, None)
|
|
||||||
)
|
|
||||||
randkey = jax.random.PRNGKey(0)
|
|
||||||
reward = jit_evaluate(state, randkey)
|
|
||||||
print(reward)
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,120 +0,0 @@
|
|||||||
import jax, jax.numpy as jnp
|
|
||||||
|
|
||||||
from pipeline import Pipeline
|
|
||||||
from algorithm.neat import *
|
|
||||||
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
|
|
||||||
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
|
|
||||||
from tensorneat.common import Act, Agg
|
|
||||||
|
|
||||||
|
|
||||||
def rot_li(li):
|
|
||||||
return li[1:] + [li[0]]
|
|
||||||
|
|
||||||
|
|
||||||
def rot_boards(board):
|
|
||||||
def rot(a, _):
|
|
||||||
a = jnp.rot90(a)
|
|
||||||
return a, a # carry, y
|
|
||||||
|
|
||||||
# carry, np.stack(ys)
|
|
||||||
_, boards = jax.lax.scan(rot, board, jnp.arange(4, dtype=jnp.int32))
|
|
||||||
return boards
|
|
||||||
|
|
||||||
|
|
||||||
direction = ["up", "right", "down", "left"]
|
|
||||||
lr_flip_direction = ["up", "left", "down", "right"]
|
|
||||||
|
|
||||||
directions = []
|
|
||||||
lr_flip_directions = []
|
|
||||||
for _ in range(4):
|
|
||||||
direction = rot_li(direction)
|
|
||||||
lr_flip_direction = rot_li(lr_flip_direction)
|
|
||||||
directions.append(direction.copy())
|
|
||||||
lr_flip_directions.append(lr_flip_direction.copy())
|
|
||||||
|
|
||||||
full_directions = directions + lr_flip_directions
|
|
||||||
|
|
||||||
|
|
||||||
def action_policy(forward_func, obs):
|
|
||||||
board = obs.reshape(4, 4)
|
|
||||||
lr_flip_board = jnp.fliplr(board)
|
|
||||||
|
|
||||||
boards = rot_boards(board)
|
|
||||||
lr_flip_boards = rot_boards(lr_flip_board)
|
|
||||||
# stack
|
|
||||||
full_boards = jnp.concatenate([boards, lr_flip_boards], axis=0)
|
|
||||||
scores = jax.vmap(forward_func)(full_boards.reshape(8, -1))
|
|
||||||
total_score = {"up": 0, "right": 0, "down": 0, "left": 0}
|
|
||||||
for i in range(8):
|
|
||||||
dire = full_directions[i]
|
|
||||||
for j in range(4):
|
|
||||||
total_score[dire[j]] += scores[i, j]
|
|
||||||
|
|
||||||
return jnp.array(
|
|
||||||
[
|
|
||||||
total_score["up"],
|
|
||||||
total_score["right"],
|
|
||||||
total_score["down"],
|
|
||||||
total_score["left"],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pipeline = Pipeline(
|
|
||||||
algorithm=NEAT(
|
|
||||||
species=DefaultSpecies(
|
|
||||||
genome=DefaultGenome(
|
|
||||||
num_inputs=16,
|
|
||||||
num_outputs=4,
|
|
||||||
max_nodes=100,
|
|
||||||
max_conns=1000,
|
|
||||||
node_gene=NodeGeneWithoutResponse(
|
|
||||||
activation_default=Act.sigmoid,
|
|
||||||
activation_options=(
|
|
||||||
Act.sigmoid,
|
|
||||||
Act.relu,
|
|
||||||
Act.tanh,
|
|
||||||
Act.identity,
|
|
||||||
),
|
|
||||||
aggregation_default=Agg.sum,
|
|
||||||
aggregation_options=(Agg.sum, ),
|
|
||||||
activation_replace_rate=0.02,
|
|
||||||
aggregation_replace_rate=0.02,
|
|
||||||
bias_mutate_rate=0.03,
|
|
||||||
bias_init_std=0.5,
|
|
||||||
bias_mutate_power=0.02,
|
|
||||||
bias_replace_rate=0.01,
|
|
||||||
),
|
|
||||||
conn_gene=DefaultConnGene(
|
|
||||||
weight_mutate_rate=0.015,
|
|
||||||
weight_replace_rate=0.03,
|
|
||||||
weight_mutate_power=0.05,
|
|
||||||
),
|
|
||||||
mutation=DefaultMutation(node_add=0.001, conn_add=0.002),
|
|
||||||
),
|
|
||||||
pop_size=1000,
|
|
||||||
species_size=5,
|
|
||||||
survival_threshold=0.01,
|
|
||||||
max_stagnation=7,
|
|
||||||
genome_elitism=3,
|
|
||||||
compatibility_threshold=1.2,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
problem=Jumanji_2048(
|
|
||||||
max_step=1000,
|
|
||||||
repeat_times=50,
|
|
||||||
# guarantee_invalid_action=True,
|
|
||||||
guarantee_invalid_action=False,
|
|
||||||
action_policy=action_policy,
|
|
||||||
),
|
|
||||||
generation_limit=10000,
|
|
||||||
fitness_target=13000,
|
|
||||||
save_path="2048.npz",
|
|
||||||
)
|
|
||||||
|
|
||||||
# initialize state
|
|
||||||
state = pipeline.setup()
|
|
||||||
# print(state)
|
|
||||||
# run until terminate
|
|
||||||
state, best = pipeline.auto_run(state)
|
|
||||||
Reference in New Issue
Block a user