add action_policy for problem;

This commit is contained in:
wls2002
2024-06-07 17:09:16 +08:00
parent 10ec1c2df9
commit 3d5b80c6fa
13 changed files with 2417 additions and 1191 deletions

View File

@@ -5,6 +5,11 @@ from algorithm.neat import *
from problem.rl_env import GymNaxEnv
def action_policy(forward_func, obs):
return jnp.argmax(forward_func(obs))
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
@@ -14,18 +19,15 @@ if __name__ == "__main__":
num_outputs=2,
max_nodes=50,
max_conns=100,
output_transform=lambda out: jnp.argmax(
out
), # the action of cartpole is {0, 1}
# output_transform=lambda out: jnp.argmax(
# out
# ), # the action of cartpole is {0, 1}
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name="CartPole-v1",
repeat_times=5
),
problem=GymNaxEnv(env_name="CartPole-v1", repeat_times=5, action_policy=action_policy),
generation_limit=10000,
fitness_target=500,
)