diff --git a/examples/gymnax/cartpole_hyperneat.py b/examples/gymnax/cartpole_hyperneat.py new file mode 100644 index 0000000..792c380 --- /dev/null +++ b/examples/gymnax/cartpole_hyperneat.py @@ -0,0 +1,54 @@ +import jax.numpy as jnp + +from config import * +from pipeline_jitable_env import Pipeline +from algorithm import NEAT +from algorithm.neat.gene import NormalGene, NormalGeneConfig +from algorithm.hyperneat import HyperNEAT, NormalSubstrateConfig, NormalSubstrate +from problem.rl_env import GymNaxConfig, GymNaxEnv + + +def example_conf(): + return Config( + basic=BasicConfig( + seed=42, + fitness_target=500, + pop_size=10000 + ), + neat=NeatConfig( + inputs=4, + outputs=1, + ), + gene=NormalGeneConfig( + activation_default=Act.tanh, + activation_options=(Act.tanh,), + ), + hyperneat=HyperNeatConfig( + activation=Act.sigmoid, + inputs=4, + outputs=2 + ), + substrate=NormalSubstrateConfig( + input_coors=((-1, -1), (-0.5, -1), (0, -1), (0.5, -1), (1, -1)), + hidden_coors=( + # (-1, -0.5), (-0.5, -0.5), (0, -0.5), (0.5, -0.5), + (1, 0), (-1, 0), (-0.5, 0), (0, 0), (0.5, 0), (1, 0), + # (1, 0.5), (-1, 0.5), (-0.5, 0.5), (0, 0.5), (0.5, 0.5), (1, 0.5), + ), + output_coors=((-1, 1), (1, 1)), + ), + problem=GymNaxConfig( + env_name='CartPole-v1', + output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1} + ) + ) + + +if __name__ == '__main__': + conf = example_conf() + + algorithm = HyperNEAT(conf, NormalGene, NormalSubstrate) + pipeline = Pipeline(conf, algorithm, GymNaxEnv) + state = pipeline.setup() + pipeline.pre_compile(state) + state, best = pipeline.auto_run(state)