From 5fa5e81c72066818c3db43e053dbbd374cd5529b Mon Sep 17 00:00:00 2001 From: wls2002 Date: Wed, 16 Apr 2025 10:06:23 +0800 Subject: [PATCH] add example for recurrent network --- examples/brax/hopper_recurrent.py | 40 +++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 examples/brax/hopper_recurrent.py diff --git a/examples/brax/hopper_recurrent.py b/examples/brax/hopper_recurrent.py new file mode 100644 index 0000000..1d88d90 --- /dev/null +++ b/examples/brax/hopper_recurrent.py @@ -0,0 +1,40 @@ +from tensorneat.pipeline import Pipeline +from tensorneat.algorithm.neat import NEAT +from tensorneat.genome import RecurrentGenome, BiasNode + +from tensorneat.problem.rl import BraxEnv +from tensorneat.common import ACT, AGG + +if __name__ == "__main__": + pipeline = Pipeline( + algorithm=NEAT( + pop_size=1000, + species_size=20, + survival_threshold=0.1, + compatibility_threshold=1.0, + genome=RecurrentGenome( + num_inputs=11, + num_outputs=3, + init_hidden_layers=(), + node_gene=BiasNode( + activation_options=ACT.tanh, + aggregation_options=AGG.sum, + ), + output_transform=ACT.tanh, + activate_time=20, + ), + ), + problem=BraxEnv( + env_name="hopper", + max_step=1000, + ), + seed=42, + generation_limit=100, + fitness_target=5000, + ) + + # initialize state + state = pipeline.setup() + # print(state) + # run until terminate + state, best = pipeline.auto_run(state)