# File: tensorneat-mend/examples/brax/walker2d.py
# Last modified: 2024-07-11 19:34:12 +08:00 (52 lines, 1.4 KiB, Python)

from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import DefaultGenome, BiasNode
from tensorneat.problem.rl import BraxEnv
from tensorneat.common import Act, Agg
import jax, jax.numpy as jnp
def random_sample_policy(
    randkey,
    obs,
    *,
    action_dim: int = 6,
    minval: float = -1.0,
    maxval: float = 1.0,
):
    """Sample a uniformly random action, ignoring the observation.

    This function is handed to ``BraxEnv`` as ``sample_policy`` together with
    ``obs_normalization=True`` and ``sample_episodes``. Presumably BraxEnv rolls
    it out to gather observations for normalization statistics (confirm
    against BraxEnv).

    Args:
        randkey: JAX PRNG key used for sampling.
        obs: Current observation. It is unused and kept only so the signature
            matches the ``(randkey, obs)`` call convention of ``sample_policy``.
        action_dim: Length of the sampled action vector. The default of 6
            matches ``num_outputs`` of the genome configured in this script.
        minval: Inclusive lower bound of every action component.
        maxval: Exclusive upper bound of every action component.

    Returns:
        A float array of shape ``(action_dim,)`` drawn from U[minval, maxval).
    """
    # Sample over the full control range. jax.random.uniform's default range
    # is [0, 1), which would make every action component non-negative and
    # leave half of the control space unexplored during sampling.
    # NOTE(review): assumes walker2d's action range is [-1, 1] (Brax actuator
    # ctrlrange); confirm if the environment changes.
    return jax.random.uniform(
        randkey, (action_dim,), minval=minval, maxval=maxval
    )
if __name__ == "__main__":
    # Node gene for the network: tanh activation and sum aggregation.
    node_gene = BiasNode(
        activation_options=Act.tanh,
        aggregation_options=Agg.sum,
    )

    # Genome sized for walker2d: 17 inputs and 6 outputs, with room for up to
    # 100 nodes / 200 connections. Outputs go through Act.standard_tanh.
    genome = DefaultGenome(
        max_nodes=100,
        max_conns=200,
        num_inputs=17,
        num_outputs=6,
        init_hidden_layers=(),
        node_gene=node_gene,
        output_transform=Act.standard_tanh,
    )

    # NEAT search settings.
    algorithm = NEAT(
        pop_size=1000,
        species_size=20,
        survival_threshold=0.1,
        compatibility_threshold=1.0,
        genome=genome,
    )

    # Brax walker2d task with observation normalization; the random policy
    # is supplied for the sampling episodes.
    problem = BraxEnv(
        env_name="walker2d",
        max_step=1000,
        obs_normalization=True,
        sample_episodes=1000,
        sample_policy=random_sample_policy,
    )

    pipeline = Pipeline(
        algorithm=algorithm,
        problem=problem,
        seed=42,
        generation_limit=100,
        fitness_target=5000,
    )

    # Build the initial state, then let auto_run drive evolution. It
    # presumably stops at generation_limit or fitness_target (confirm).
    state = pipeline.setup()
    state, best = pipeline.auto_run(state)