add action_policy for problem;

This commit is contained in:
wls2002
2024-06-07 17:09:16 +08:00
parent 10ec1c2df9
commit 3d5b80c6fa
13 changed files with 2417 additions and 1191 deletions

View File

@@ -16,7 +16,7 @@ if __name__ == "__main__":
max_nodes=50, max_nodes=50,
max_conns=100, max_conns=100,
node_gene=KANNode(), node_gene=KANNode(),
conn_gene=BSplineConn(grid_cnt=10), conn_gene=BSplineConn(grid_cnt=6),
output_transform=Act.sigmoid, # the activation function for output node output_transform=Act.sigmoid, # the activation function for output node
mutation=DefaultMutation( mutation=DefaultMutation(
node_add=0.1, node_add=0.1,

View File

@@ -5,6 +5,11 @@ from algorithm.neat import *
from problem.rl_env import GymNaxEnv from problem.rl_env import GymNaxEnv
def action_policy(forward_func, obs):
    """Choose the discrete action whose score is largest.

    Args:
        forward_func: Callable that maps an observation to an array of
            per-action scores.
        obs: Observation for the current environment step.

    Returns:
        The index of the maximum score, as produced by ``jnp.argmax``.
    """
    scores = forward_func(obs)
    return jnp.argmax(scores)
if __name__ == "__main__": if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
algorithm=NEAT( algorithm=NEAT(
@@ -14,18 +19,15 @@ if __name__ == "__main__":
num_outputs=2, num_outputs=2,
max_nodes=50, max_nodes=50,
max_conns=100, max_conns=100,
output_transform=lambda out: jnp.argmax( # output_transform=lambda out: jnp.argmax(
out # out
), # the action of cartpole is {0, 1} # ), # the action of cartpole is {0, 1}
), ),
pop_size=10000, pop_size=10000,
species_size=10, species_size=10,
), ),
), ),
problem=GymNaxEnv( problem=GymNaxEnv(env_name="CartPole-v1", repeat_times=5, action_policy=action_policy),
env_name="CartPole-v1",
repeat_times=5
),
generation_limit=10000, generation_limit=10000,
fitness_target=500, fitness_target=500,
) )

View File

@@ -1,46 +0,0 @@
import jax.numpy as jnp
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
from utils import Act, Agg
if __name__ == "__main__":
    # Node genes may choose among several activation and aggregation functions.
    node_gene_cfg = DefaultNodeGene(
        activation_default=Act.sigmoid,
        activation_options=(Act.sigmoid, Act.relu, Act.tanh, Act.identity, Act.inv),
        aggregation_default=Agg.sum,
        aggregation_options=(Agg.sum, Agg.mean, Agg.max, Agg.product),
    )
    mutation_cfg = DefaultMutation(node_add=0.03, conn_add=0.03)

    # 16 inputs (flattened 4x4 board) and 4 outputs (one per move).
    genome_cfg = DefaultGenome(
        num_inputs=16,
        num_outputs=4,
        max_nodes=100,
        max_conns=1000,
        node_gene=node_gene_cfg,
        mutation=mutation_cfg,
    )
    species_cfg = DefaultSpecies(
        genome=genome_cfg,
        pop_size=10000,
        species_size=100,
        survival_threshold=0.01,
    )

    pipeline = Pipeline(
        algorithm=NEAT(species=species_cfg),
        problem=Jumanji_2048(max_step=10000, repeat_times=5),
        generation_limit=10000,
        fitness_target=13000,
    )

    # Build the initial state, then evolve until a stop condition is hit.
    state = pipeline.setup()
    state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,25 @@
import jax, jax.numpy as jnp
import jax.random
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
def random_policy(state, params, obs):
    """Return a fixed score vector, ignoring all arguments.

    Despite its name this policy is deterministic: it always yields the same
    four scores. It exists as a stand-in with the ``(state, params, obs)``
    call signature so the environment can be evaluated without a trained
    network.

    Args:
        state: Unused.
        params: Unused.
        obs: Unused.

    Returns:
        ``jnp.array([4, 4, 0, 1])``.
    """
    return jnp.array([4, 4, 0, 1])
if __name__ == "__main__":
    # Evaluate the fixed policy over many episodes as a smoke test of the env.
    problem = Jumanji_2048(
        max_step=10000,
        repeat_times=1000,
        guarantee_invalid_action=True,
    )
    state = problem.setup()

    def _evaluate_fixed_policy(state, randkey):
        # The policy ignores params, so None is passed in their place.
        return problem.evaluate(state, randkey, random_policy, None)

    jit_evaluate = jax.jit(_evaluate_fixed_policy)

    randkey = jax.random.PRNGKey(0)
    reward = jit_evaluate(state, randkey)
    print(reward)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,119 @@
import jax, jax.numpy as jnp
from pipeline import Pipeline
from algorithm.neat import *
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
from utils import Act, Agg
def rot_li(li):
    """Rotate a sequence left by one position, returning a new sequence.

    The first element moves to the end; the input is not modified. Slicing
    (rather than indexing ``li[0]``) keeps empty input valid and lets the
    function work on any sliceable, concatenable sequence such as a tuple.

    Args:
        li: Sequence to rotate.

    Returns:
        A new sequence of the same type with the first element moved last.
    """
    return li[1:] + li[:1]
def rot_boards(board):
    """Stack four successive 90-degree rotations of ``board``.

    The rotations are applied cumulatively with ``jnp.rot90``, so the result
    holds the board rotated one, two, three and four times; the last entry
    equals the input.

    Args:
        board: Square 2-D array.

    Returns:
        Array of shape ``(4, *board.shape)`` with the rotated boards.
    """

    def _rotate_once(current, _unused):
        rotated = jnp.rot90(current)
        # scan wants (next carry, value emitted for this step)
        return rotated, rotated

    _last, stacked = jax.lax.scan(_rotate_once, board, xs=None, length=4)
    return stacked
# Output-slot labels for each transformed board. Entry k gives the labels
# after k + 1 left-rotations of the base list, lining up with the k-th board
# returned by rot_boards (which has been rotated k + 1 times).
direction = ["up", "right", "down", "left"]
lr_flip_direction = ["up", "left", "down", "right"]

directions, lr_flip_directions = [], []
for _ in range(4):
    direction, lr_flip_direction = rot_li(direction), rot_li(lr_flip_direction)
    directions.append(list(direction))
    lr_flip_directions.append(list(lr_flip_direction))

# Plain rotations first, then the rotations of the left-right mirrored board.
full_directions = [*directions, *lr_flip_directions]
def action_policy(forward_func, obs):
    """Score the four moves by summing network outputs over board symmetries.

    The observation is viewed as a 4x4 board. The network is run on its four
    rotations and on the four rotations of its left-right mirror image. Each
    output slot of each run is credited to the move named by the matching
    entry of ``full_directions``, and the credits are summed per move.

    Args:
        forward_func: Callable mapping a flat 16-element observation to four
            scores.
        obs: Observation reshapeable to ``(4, 4)``.

    Returns:
        Array of four summed scores ordered ``up, right, down, left``.
    """
    grid = obs.reshape(4, 4)
    mirrored = jnp.fliplr(grid)

    # Shape (8, 4, 4): four rotations, then four mirrored rotations.
    views = jnp.concatenate([rot_boards(grid), rot_boards(mirrored)], axis=0)
    scores = jax.vmap(forward_func)(views.reshape(8, -1))

    move_names = ("up", "right", "down", "left")
    totals = dict.fromkeys(move_names, 0)
    for view_idx in range(8):
        for slot_idx, move in enumerate(full_directions[view_idx]):
            totals[move] += scores[view_idx, slot_idx]

    return jnp.array([totals[move] for move in move_names])
if __name__ == "__main__":
    # Node genes: four candidate activations, sum-only aggregation.
    node_gene_cfg = NodeGeneWithoutResponse(
        activation_default=Act.sigmoid,
        activation_options=(Act.sigmoid, Act.relu, Act.tanh, Act.identity),
        aggregation_default=Agg.sum,
        aggregation_options=(Agg.sum,),
        activation_replace_rate=0.02,
        aggregation_replace_rate=0.02,
        bias_mutate_rate=0.03,
        bias_init_std=0.5,
        bias_mutate_power=0.2,
        bias_replace_rate=0.01,
    )
    conn_gene_cfg = DefaultConnGene(
        weight_mutate_rate=0.015,
        weight_replace_rate=0.003,
        weight_mutate_power=0.5,
    )
    mutation_cfg = DefaultMutation(node_add=0.001, conn_add=0.002)

    # 16 inputs (flattened 4x4 board) and 4 outputs (one per move).
    genome_cfg = DefaultGenome(
        num_inputs=16,
        num_outputs=4,
        max_nodes=100,
        max_conns=1000,
        node_gene=node_gene_cfg,
        conn_gene=conn_gene_cfg,
        mutation=mutation_cfg,
    )
    species_cfg = DefaultSpecies(
        genome=genome_cfg,
        pop_size=1000,
        species_size=5,
        survival_threshold=0.1,
        max_stagnation=7,
        genome_elitism=3,
        compatibility_threshold=1.2,
    )
    algorithm_cfg = NEAT(species=species_cfg)

    # The environment calls action_policy to turn network outputs into scores.
    problem_cfg = Jumanji_2048(
        max_step=10000,
        repeat_times=10,
        guarantee_invalid_action=True,
        action_policy=action_policy,
    )

    pipeline = Pipeline(
        algorithm=algorithm_cfg,
        problem=problem_cfg,
        generation_limit=1000,
        fitness_target=13000,
        save_path="2048.npz",
    )

    # Build the initial state, then evolve until a stop condition is hit.
    state = pipeline.setup()
    state, best = pipeline.auto_run(state)

View File

@@ -19,6 +19,7 @@ class Pipeline:
generation_limit: int = 1000, generation_limit: int = 1000,
pre_update: bool = False, pre_update: bool = False,
update_batch_size: int = 10000, update_batch_size: int = 10000,
save_path=None,
): ):
assert problem.jitable, "Currently, problem must be jitable" assert problem.jitable, "Currently, problem must be jitable"
@@ -55,6 +56,7 @@ class Pipeline:
assert not problem.record_episode, "record_episode must be False" assert not problem.record_episode, "record_episode must be False"
elif isinstance(problem, FuncFit): elif isinstance(problem, FuncFit):
assert not problem.return_data, "return_data must be False" assert not problem.return_data, "return_data must be False"
self.save_path = save_path
def setup(self, state=State()): def setup(self, state=State()):
print("initializing") print("initializing")
@@ -181,6 +183,17 @@ class Pipeline:
self.best_fitness = fitnesses[max_idx] self.best_fitness = fitnesses[max_idx]
self.best_genome = pop[0][max_idx], pop[1][max_idx] self.best_genome = pop[0][max_idx], pop[1][max_idx]
# save best if save path is not None
if self.save_path is not None:
best_genome = jax.device_get(self.best_genome)
with open(self.save_path, "wb") as f:
np.savez(
f,
nodes=best_genome[0],
conns=best_genome[1],
fitness=self.best_fitness,
)
member_count = jax.device_get(self.algorithm.member_count(state)) member_count = jax.device_get(self.algorithm.member_count(state))
species_sizes = [int(i) for i in member_count if i > 0] species_sizes = [int(i) for i in member_count if i > 0]

View File

@@ -5,8 +5,10 @@ from .rl_jit import RLEnv
class BraxEnv(RLEnv): class BraxEnv(RLEnv):
def __init__(self, max_step=1000, repeat_times=1, record_episode=False, env_name: str = "ant", backend: str = "generalized"): def __init__(
super().__init__(max_step, repeat_times, record_episode) self, env_name: str = "ant", backend: str = "generalized", *args, **kwargs
):
super().__init__(*args, **kwargs)
self.env = envs.create(env_name=env_name, backend=backend) self.env = envs.create(env_name=env_name, backend=backend)
def env_step(self, randkey, env_state, action): def env_step(self, randkey, env_state, action):

View File

@@ -4,8 +4,8 @@ from .rl_jit import RLEnv
class GymNaxEnv(RLEnv): class GymNaxEnv(RLEnv):
def __init__(self, env_name, max_step=1000, repeat_times=1, record_episode=False): def __init__(self, env_name, *args, **kwargs):
super().__init__(max_step, repeat_times, record_episode) super().__init__(*args, **kwargs)
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered" assert env_name in gymnax.registered_envs, f"Env {env_name} not registered"
self.env, self.env_params = gymnax.make(env_name) self.env, self.env_params = gymnax.make(env_name)

View File

@@ -7,14 +7,21 @@ from ..rl_jit import RLEnv
class Jumanji_2048(RLEnv): class Jumanji_2048(RLEnv):
def __init__( def __init__(
self, max_step=1000, repeat_times=1, record_episode=False, guarantee_invalid_action=True self, guarantee_invalid_action=True, *args, **kwargs
): ):
super().__init__(max_step, repeat_times, record_episode) super().__init__(*args, **kwargs)
self.guarantee_invalid_action = guarantee_invalid_action self.guarantee_invalid_action = guarantee_invalid_action
self.env = jumanji.make("Game2048-v1") self.env = jumanji.make("Game2048-v1")
def env_step(self, randkey, env_state, action): def env_step(self, randkey, env_state, action):
action_mask = env_state["action_mask"] action_mask = env_state["action_mask"]
###################################################################
action = jnp.concatenate([action, jnp.full((4 - action.shape[0], ), -99999)])
action = (action - 1) / 15
###################################################################
if self.guarantee_invalid_action: if self.guarantee_invalid_action:
score_with_mask = jnp.where(action_mask, action, -jnp.inf) score_with_mask = jnp.where(action_mask, action, -jnp.inf)
action = jnp.argmax(score_with_mask) action = jnp.argmax(score_with_mask)

View File

@@ -11,11 +11,18 @@ from .. import BaseProblem
class RLEnv(BaseProblem): class RLEnv(BaseProblem):
jitable = True jitable = True
def __init__(self, max_step=1000, repeat_times=1, record_episode=False): def __init__(
self,
max_step=1000,
repeat_times=1,
record_episode=False,
action_policy: Callable = None,
):
super().__init__() super().__init__()
self.max_step = max_step self.max_step = max_step
self.record_episode = record_episode self.record_episode = record_episode
self.repeat_times = repeat_times self.repeat_times = repeat_times
self.action_policy = action_policy
def evaluate(self, state: State, randkey, act_func: Callable, params): def evaluate(self, state: State, randkey, act_func: Callable, params):
keys = jax.random.split(randkey, self.repeat_times) keys = jax.random.split(randkey, self.repeat_times)
@@ -63,6 +70,10 @@ class RLEnv(BaseProblem):
def body_func(carry): def body_func(carry):
obs, env_state, rng, done, tr, count, epis = carry # tr -> total reward obs, env_state, rng, done, tr, count, epis = carry # tr -> total reward
if self.action_policy is not None:
forward_func = lambda obs: act_func(state, params, obs)
action = self.action_policy(forward_func, obs)
else:
action = act_func(state, params, obs) action = act_func(state, params, obs)
next_obs, next_env_state, reward, done, _ = self.step( next_obs, next_env_state, reward, done, _ = self.step(
rng, env_state, action rng, env_state, action

File diff suppressed because one or more lines are too long

221
tensorneat/tmp.ipynb Normal file
View File

@@ -0,0 +1,221 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 22,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2024-06-06T11:55:39.434327400Z",
"start_time": "2024-06-06T11:55:39.361327400Z"
}
},
"outputs": [
{
"data": {
"text/plain": "Array([[[2, 4],\n [1, 3]],\n\n [[4, 3],\n [2, 1]],\n\n [[3, 1],\n [4, 2]],\n\n [[1, 2],\n [3, 4]],\n\n [[2, 4],\n [1, 3]],\n\n [[4, 3],\n [2, 1]],\n\n [[3, 1],\n [4, 2]],\n\n [[1, 2],\n [3, 4]]], dtype=int32)"
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import jax, jax.numpy as jnp\n",
"a = jnp.array([\n",
" [1, 2],\n",
" [3, 4]\n",
"])\n",
"def rot_boards(board):\n",
" def rot(a, _):\n",
" a = jnp.rot90(a)\n",
" return a, a # carry, y\n",
" \n",
" _, boards = jax.lax.scan(rot, board, jnp.arange(4, dtype=jnp.int32))\n",
" return boards\n",
"a1 = rot_boards(a)\n",
"a2 = rot_boards(a)\n",
"\n",
"a = jnp.concatenate([a1, a2], axis=0)\n",
"a"
]
},
{
"cell_type": "code",
"execution_count": 21,
"outputs": [
{
"data": {
"text/plain": "Array([[2, 4, 1, 3],\n [4, 3, 2, 1],\n [3, 1, 4, 2],\n [1, 2, 3, 4],\n [2, 4, 1, 3],\n [4, 3, 2, 1],\n [3, 1, 4, 2],\n [1, 2, 3, 4]], dtype=int32)"
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = a.reshape(8, -1)\n",
"a"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:55:31.121054800Z",
"start_time": "2024-06-06T11:55:31.075517200Z"
}
},
"id": "639cdecea840351d"
},
{
"cell_type": "code",
"execution_count": 13,
"outputs": [],
"source": [
"action = [\"up\", \"right\", \"down\", \"left\"]\n",
"lr_flip_action = [\"up\", \"left\", \"down\", \"right\"]\n",
"def action_rot90(li):\n",
" first = li[0]\n",
" return li[1:] + [first]\n",
"\n",
"a = a\n",
"rl_flip_a = jnp.fliplr(a)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:22:36.417287600Z",
"start_time": "2024-06-06T11:22:36.414285500Z"
}
},
"id": "92b75cd0e870a28c"
},
{
"cell_type": "code",
"execution_count": 14,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[1 2]\n",
" [3 4]] ['up', 'right', 'down', 'left']\n",
"[[2 1]\n",
" [4 3]] ['up', 'left', 'down', 'right']\n",
"[[2 4]\n",
" [1 3]] ['right', 'down', 'left', 'up']\n",
"[[1 3]\n",
" [2 4]] ['left', 'down', 'right', 'up']\n",
"[[4 3]\n",
" [2 1]] ['down', 'left', 'up', 'right']\n",
"[[3 4]\n",
" [1 2]] ['down', 'right', 'up', 'left']\n",
"[[3 1]\n",
" [4 2]] ['left', 'up', 'right', 'down']\n",
"[[4 2]\n",
" [3 1]] ['right', 'up', 'left', 'down']\n"
]
}
],
"source": [
"for i in range(4):\n",
" print(a, action)\n",
" print(rl_flip_a, lr_flip_action)\n",
" a = jnp.rot90(a)\n",
" rl_flip_a = jnp.rot90(rl_flip_a)\n",
" action = action_rot90(action)\n",
" lr_flip_action = action_rot90(lr_flip_action)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:22:36.919614600Z",
"start_time": "2024-06-06T11:22:36.860704600Z"
}
},
"id": "55e802e0dbcc9c7f"
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"data": {
"text/plain": "Array([[4, 3],\n [2, 1]], dtype=int32)"
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jnp.rot90(a, k=2)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:12:48.186719Z",
"start_time": "2024-06-06T11:12:48.151161900Z"
}
},
"id": "16f8de3cadaa257a"
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [
{
"data": {
"text/plain": "Array([[2, 1],\n [4, 3]], dtype=int32)"
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# flip left-right\n",
"jnp.fliplr(a)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-06T11:14:28.668195300Z",
"start_time": "2024-06-06T11:14:28.631570500Z"
}
},
"id": "1fffa4e597ab5732"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "ca53c916dcff12ae"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}