Add action_policy for problem

This commit is contained in:
wls2002
2024-06-07 17:09:16 +08:00
parent 10ec1c2df9
commit 3d5b80c6fa
13 changed files with 2417 additions and 1191 deletions

View File

@@ -16,7 +16,7 @@ if __name__ == "__main__":
max_nodes=50,
max_conns=100,
node_gene=KANNode(),
conn_gene=BSplineConn(grid_cnt=10),
conn_gene=BSplineConn(grid_cnt=6),
output_transform=Act.sigmoid, # the activation function for output node
mutation=DefaultMutation(
node_add=0.1,

View File

@@ -5,6 +5,11 @@ from algorithm.neat import *
from problem.rl_env import GymNaxEnv
def action_policy(forward_func, obs):
    """Pick the index of the largest network output for an observation.

    Args:
        forward_func: Callable that maps an observation to the network's
            output vector.
        obs: Observation, passed unchanged to ``forward_func``.

    Returns:
        The result of ``jnp.argmax`` over the network output.
    """
    outputs = forward_func(obs)
    return jnp.argmax(outputs)
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
@@ -14,18 +19,15 @@ if __name__ == "__main__":
num_outputs=2,
max_nodes=50,
max_conns=100,
output_transform=lambda out: jnp.argmax(
out
), # the action of cartpole is {0, 1}
# output_transform=lambda out: jnp.argmax(
# out
# ), # the action of cartpole is {0, 1}
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name="CartPole-v1",
repeat_times=5
),
problem=GymNaxEnv(env_name="CartPole-v1", repeat_times=5, action_policy=action_policy),
generation_limit=10000,
fitness_target=500,
)

View File

@@ -1,46 +0,0 @@
import jax.numpy as jnp
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
from utils import Act, Agg
if __name__ == "__main__":
    # Build the configuration bottom-up, in the same order the nested
    # constructor calls would be evaluated.
    node_gene = DefaultNodeGene(
        activation_default=Act.sigmoid,
        activation_options=(Act.sigmoid, Act.relu, Act.tanh, Act.identity, Act.inv),
        aggregation_default=Agg.sum,
        aggregation_options=(Agg.sum, Agg.mean, Agg.max, Agg.product),
    )
    mutation = DefaultMutation(node_add=0.03, conn_add=0.03)
    genome = DefaultGenome(
        num_inputs=16,
        num_outputs=4,
        max_nodes=100,
        max_conns=1000,
        node_gene=node_gene,
        mutation=mutation,
    )
    species = DefaultSpecies(
        genome=genome,
        pop_size=10000,
        species_size=100,
        survival_threshold=0.01,
    )
    algorithm = NEAT(species=species)
    problem = Jumanji_2048(max_step=10000, repeat_times=5)

    pipeline = Pipeline(
        algorithm=algorithm,
        problem=problem,
        generation_limit=10000,
        fitness_target=13000,
    )

    # Initialize the pipeline state, then run until a stop condition is hit.
    state = pipeline.setup()
    state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,25 @@
import jax, jax.numpy as jnp
import jax.random
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
def random_policy(state, params, obs):
    """Return a fixed score vector over the four actions.

    Despite its name, this policy is deterministic: all three arguments are
    ignored and the same scores are returned on every call.

    Args:
        state: Unused; kept so the signature matches what the caller passes.
        params: Unused.
        obs: Unused.

    Returns:
        ``jnp.array([4, 4, 0, 1])``.
    """
    # NOTE(review): a sampled variant (normal noise keyed on the observation)
    # previously existed only as commented-out code, followed by an
    # unreachable ``return actions`` on an undefined name. Both were removed;
    # reintroduce real sampling here if a random baseline is wanted.
    return jnp.array([4, 4, 0, 1])
if __name__ == "__main__":
    problem = Jumanji_2048(
        max_step=10000,
        repeat_times=1000,
        guarantee_invalid_action=True,
    )
    state = problem.setup()

    def _evaluate(env_state, key):
        """Evaluate ``random_policy`` on the problem with no parameters."""
        return problem.evaluate(env_state, key, random_policy, None)

    jit_evaluate = jax.jit(_evaluate)

    randkey = jax.random.PRNGKey(0)
    reward = jit_evaluate(state, randkey)
    print(reward)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,119 @@
import jax, jax.numpy as jnp
from pipeline import Pipeline
from algorithm.neat import *
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
from utils import Act, Agg
def rot_li(li):
    """Return a copy of ``li`` rotated left by one position.

    The first element moves to the end; the input is not modified.

    Args:
        li: A sliceable sequence (list, tuple, ...). May be empty.

    Returns:
        A new sequence of the same type as ``li``. An empty input yields an
        empty result instead of raising ``IndexError``.
    """
    # Slicing (rather than indexing li[0]) keeps this safe for empty input
    # and preserves the sequence type.
    return li[1:] + li[:1]
def rot_boards(board):
    """Stack four successive ``jnp.rot90`` rotations of ``board``.

    Args:
        board: Square 2-D array (the scan carry must keep its shape).

    Returns:
        An array with a leading axis of length 4. Entry ``k`` is ``board``
        after ``k + 1`` applications of ``jnp.rot90``, so the last entry
        equals the input.
    """

    def _step(current, _unused):
        turned = jnp.rot90(current)
        # scan expects (new carry, per-step output); both are the rotation.
        return turned, turned

    _final, stacked = jax.lax.scan(_step, board, xs=None, length=4)
    return stacked
# Output-index -> move-name tables for the 8 board views used by
# action_policy: entry k of `directions` names, for the board after k + 1
# rotations, which original-frame move each of the 4 outputs stands for;
# `lr_flip_directions` does the same for the left-right mirrored board.
direction = ["up", "right", "down", "left"]
lr_flip_direction = ["up", "left", "down", "right"]

directions, lr_flip_directions = [], []
for _ in range(4):
    direction, lr_flip_direction = rot_li(direction), rot_li(lr_flip_direction)
    directions.append(list(direction))
    lr_flip_directions.append(list(lr_flip_direction))

# Same order as the boards are concatenated: rotations first, then mirrors.
full_directions = [*directions, *lr_flip_directions]
def action_policy(forward_func, obs):
    """Score the four moves by summing network outputs over 8 board views.

    The observation is reshaped to a 4x4 board. The network is evaluated on
    the four rotations of that board and the four rotations of its
    left-right mirror. Each view's four outputs are mapped back to
    original-frame move names through ``full_directions`` and accumulated
    per move.

    Args:
        forward_func: Callable mapping a flat 16-element board to 4 scores.
        obs: Observation with 16 elements.

    Returns:
        A length-4 array of summed scores ordered up, right, down, left.
    """
    grid = obs.reshape(4, 4)
    mirrored = jnp.fliplr(grid)

    # Rotated views first, mirrored views second (matches full_directions).
    views = jnp.concatenate([rot_boards(grid), rot_boards(mirrored)], axis=0)
    scores = jax.vmap(forward_func)(views.reshape(8, -1))

    totals = dict.fromkeys(("up", "right", "down", "left"), 0)
    for view_idx, move_names in enumerate(full_directions):
        for out_idx, move in enumerate(move_names):
            totals[move] += scores[view_idx, out_idx]

    ordered = [totals["up"], totals["right"], totals["down"], totals["left"]]
    return jnp.array(ordered)
if __name__ == "__main__":
    # Build the configuration bottom-up, in the same order the nested
    # constructor calls would be evaluated.
    node_gene = NodeGeneWithoutResponse(
        activation_default=Act.sigmoid,
        activation_options=(Act.sigmoid, Act.relu, Act.tanh, Act.identity),
        aggregation_default=Agg.sum,
        aggregation_options=(Agg.sum,),
        activation_replace_rate=0.02,
        aggregation_replace_rate=0.02,
        bias_mutate_rate=0.03,
        bias_init_std=0.5,
        bias_mutate_power=0.2,
        bias_replace_rate=0.01,
    )
    conn_gene = DefaultConnGene(
        weight_mutate_rate=0.015,
        weight_replace_rate=0.003,
        weight_mutate_power=0.5,
    )
    mutation = DefaultMutation(node_add=0.001, conn_add=0.002)
    genome = DefaultGenome(
        num_inputs=16,
        num_outputs=4,
        max_nodes=100,
        max_conns=1000,
        node_gene=node_gene,
        conn_gene=conn_gene,
        mutation=mutation,
    )
    species = DefaultSpecies(
        genome=genome,
        pop_size=1000,
        species_size=5,
        survival_threshold=0.1,
        max_stagnation=7,
        genome_elitism=3,
        compatibility_threshold=1.2,
    )
    algorithm = NEAT(species=species)

    # The problem receives the symmetry-summing action_policy defined above.
    problem = Jumanji_2048(
        max_step=10000,
        repeat_times=10,
        guarantee_invalid_action=True,
        action_policy=action_policy,
    )

    pipeline = Pipeline(
        algorithm=algorithm,
        problem=problem,
        generation_limit=1000,
        fitness_target=13000,
        save_path="2048.npz",
    )

    # Initialize the pipeline state, then run until a stop condition is hit.
    state = pipeline.setup()
    state, best = pipeline.auto_run(state)