diff --git a/tensorneat/algorithm/hyperneat/hyperneat.py b/tensorneat/algorithm/hyperneat/hyperneat.py index aa574ef..241fba1 100644 --- a/tensorneat/algorithm/hyperneat/hyperneat.py +++ b/tensorneat/algorithm/hyperneat/hyperneat.py @@ -40,19 +40,22 @@ class HyperNEAT(BaseAlgorithm): output_transform=output_transform, ) - def setup(self, randkey): - return State(neat_state=self.neat.setup(randkey)) + def setup(self, state=State()): + state = self.neat.setup(state) + state = self.substrate.setup(state) + return self.hyper_genome.setup(state) def ask(self, state: State): - return self.neat.ask(state.neat_state) + return self.neat.ask(state) def tell(self, state: State, fitness): - return state.update(neat_state=self.neat.tell(state.neat_state, fitness)) + state = self.neat.tell(state, fitness) + return state - def transform(self, individual): - transformed = self.neat.transform(individual) - query_res = jax.vmap(self.neat.forward, in_axes=(0, None))( - self.substrate.query_coors, transformed + def transform(self, state, individual): + transformed = self.neat.transform(state, individual) + query_res = jax.vmap(self.neat.forward, in_axes=(None, 0, None))( + state, self.substrate.query_coors, transformed ) # mute the connection with weight below threshold @@ -74,12 +77,12 @@ class HyperNEAT(BaseAlgorithm): h_nodes, h_conns = self.substrate.make_nodes( query_res ), self.substrate.make_conn(query_res) - return self.hyper_genome.transform(h_nodes, h_conns) + return self.hyper_genome.transform(state, h_nodes, h_conns) - def forward(self, inputs, transformed): + def forward(self, state, inputs, transformed): # add bias inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])]) - return self.hyper_genome.forward(inputs_with_bias, transformed) + return self.hyper_genome.forward(state, inputs_with_bias, transformed) @property def num_inputs(self): @@ -94,10 +97,10 @@ class HyperNEAT(BaseAlgorithm): return self.neat.pop_size def member_count(self, state: State): - return self.neat.member_count(state.neat_state) + return self.neat.member_count(state) def generation(self, state: State): - return self.neat.generation(state.neat_state) + return self.neat.generation(state) class HyperNodeGene(BaseNodeGene): @@ -110,7 +113,7 @@ class HyperNodeGene(BaseNodeGene): self.activation = activation self.aggregation = aggregation - def forward(self, attrs, inputs, is_output_node=False): + def forward(self, state, attrs, inputs, is_output_node=False): return jax.lax.cond( is_output_node, lambda: self.aggregation(inputs), # output node does not need activation @@ -121,6 +124,6 @@ class HyperNodeGene(BaseNodeGene): class HyperNEATConnGene(BaseConnGene): custom_attrs = ["weight"] - def forward(self, attrs, inputs): + def forward(self, state, attrs, inputs): weight = attrs[0] return inputs * weight diff --git a/tensorneat/algorithm/hyperneat/substrate/base.py b/tensorneat/algorithm/hyperneat/substrate/base.py index 8a60756..6172c8b 100644 --- a/tensorneat/algorithm/hyperneat/substrate/base.py +++ b/tensorneat/algorithm/hyperneat/substrate/base.py @@ -1,4 +1,10 @@ +from utils import State + + class BaseSubstrate: + def setup(self, state=State()): + return state + def make_nodes(self, query_res): raise NotImplementedError diff --git a/tensorneat/examples/brax/ant.py b/tensorneat/examples/brax/ant.py index faff804..3edddd0 100644 --- a/tensorneat/examples/brax/ant.py +++ b/tensorneat/examples/brax/ant.py @@ -11,8 +11,8 @@ if __name__ == "__main__": genome=DefaultGenome( num_inputs=27, num_outputs=8, - max_nodes=50, - max_conns=100, + max_nodes=100, + max_conns=200, node_gene=DefaultNodeGene( activation_options=(Act.tanh,), activation_default=Act.tanh, @@ -21,6 +21,8 @@ if __name__ == "__main__": ), pop_size=1000, species_size=10, + compatibility_threshold=3.5, + survival_threshold=0.01, ), ), problem=BraxEnv( diff --git a/tensorneat/examples/brax/half_cheetah.gif b/tensorneat/examples/brax/half_cheetah.gif deleted file mode 100644 index 3d2cab7..0000000 Binary files a/tensorneat/examples/brax/half_cheetah.gif and /dev/null differ diff --git a/tensorneat/examples/brax/half_cheetah.py b/tensorneat/examples/brax/half_cheetah.py index 5fa1ca6..bcc515a 100644 --- a/tensorneat/examples/brax/half_cheetah.py +++ b/tensorneat/examples/brax/half_cheetah.py @@ -17,6 +17,7 @@ if __name__ == "__main__": activation_options=(Act.tanh,), activation_default=Act.tanh, ), + output_transform=Act.tanh ), pop_size=1000, species_size=10, diff --git a/tensorneat/examples/brax/reacher.py b/tensorneat/examples/brax/reacher.py index c9a27aa..41d57c2 100644 --- a/tensorneat/examples/brax/reacher.py +++ b/tensorneat/examples/brax/reacher.py @@ -17,6 +17,7 @@ if __name__ == "__main__": activation_options=(Act.tanh,), activation_default=Act.tanh, ), + output_transform=Act.tanh, ), pop_size=100, species_size=10, diff --git a/tensorneat/examples/brax/walker.py b/tensorneat/examples/brax/walker.py index d128d21..1da6b31 100644 --- a/tensorneat/examples/brax/walker.py +++ b/tensorneat/examples/brax/walker.py @@ -17,6 +17,7 @@ if __name__ == "__main__": activation_options=(Act.tanh,), activation_default=Act.tanh, ), + output_transform=Act.tanh ), pop_size=10000, species_size=10, diff --git a/tensorneat/examples/func_fit/xor3d_hyperneat.py b/tensorneat/examples/func_fit/xor3d_hyperneat.py index 933d4aa..084a75d 100644 --- a/tensorneat/examples/func_fit/xor3d_hyperneat.py +++ b/tensorneat/examples/func_fit/xor3d_hyperneat.py @@ -9,11 +9,9 @@ if __name__ == "__main__": pipeline = Pipeline( algorithm=HyperNEAT( substrate=FullSubstrate( - input_coors=[(-1, -1), (0.333, -1), (-0.333, -1), (1, -1)], + input_coors=[(-1, -1), (0.333, -1), (-0.333, -1), (1, -1)], # 3(XOR3d inputs) + 1(bias) hidden_coors=[ - (-1, -0.5), - (0.333, -0.5), - (-0.333, -0.5), + (-1, -0.5), (0.333, -0.5), (-0.333, -0.5), (1, -0.5), (-1, 0), (0.333, 0), @@ -25,14 +23,14 @@ if __name__ == "__main__": (1, 0.5), ], output_coors=[ - (0, 1), + (0, 1), # one output ], ), neat=NEAT( species=DefaultSpecies( genome=DefaultGenome( - num_inputs=4, # [-1, -1, -1, 0] - num_outputs=1, + num_inputs=4, # [*coor1, *coor2] + num_outputs=1, # the weight of connection between two coor1 and coor2 max_nodes=50, max_conns=100, node_gene=DefaultNodeGene( diff --git a/tensorneat/examples/gymnax/cartpole_hyperneat.py b/tensorneat/examples/gymnax/cartpole_hyperneat.py index 200cec5..4302d5f 100644 --- a/tensorneat/examples/gymnax/cartpole_hyperneat.py +++ b/tensorneat/examples/gymnax/cartpole_hyperneat.py @@ -1,53 +1,74 @@ -import jax.numpy as jnp +import jax -from config import * from pipeline import Pipeline -from algorithm import NEAT -from algorithm.neat.gene import NormalGene, NormalGeneConfig -from algorithm.hyperneat import HyperNEAT, NormalSubstrateConfig, NormalSubstrate -from problem.rl_env import GymNaxConfig, GymNaxEnv - - -def example_conf(): - return Config( - basic=BasicConfig(seed=42, fitness_target=500, pop_size=10000), - neat=NeatConfig( - inputs=4, - outputs=1, - ), - gene=NormalGeneConfig( - activation_default=Act.tanh, - activation_options=(Act.tanh,), - ), - hyperneat=HyperNeatConfig(activation=Act.sigmoid, inputs=4, outputs=2), - substrate=NormalSubstrateConfig( - input_coors=((-1, -1), (-0.5, -1), (0, -1), (0.5, -1), (1, -1)), - hidden_coors=( - # (-1, -0.5), (-0.5, -0.5), (0, -0.5), (0.5, -0.5), - (1, 0), - (-1, 0), - (-0.5, 0), - (0, 0), - (0.5, 0), - (1, 0), - # (1, 0.5), (-1, 0.5), (-0.5, 0.5), (0, 0.5), (0.5, 0.5), (1, 0.5), - ), - output_coors=((-1, 1), (1, 1)), - ), - problem=GymNaxConfig( - env_name="CartPole-v1", - output_transform=lambda out: jnp.argmax( - out - ), # the action of cartpole is {0, 1} - ), - ) +from algorithm.neat import * +from algorithm.hyperneat import * +from utils import Act +from problem.rl_env import GymNaxEnv if __name__ == "__main__": - conf = example_conf() + pipeline = Pipeline( + algorithm=HyperNEAT( + substrate=FullSubstrate( + input_coors=[ + (-1, -1), + (-0.5, -1), + (0, -1), + (0.5, -1), + (1, -1), + ], # 4(problem inputs) + 1(bias) + hidden_coors=[ + (-1, -0.5), + (0.333, -0.5), + (-0.333, -0.5), + (1, -0.5), + (-1, 0), + (0.333, 0), + (-0.333, 0), + (1, 0), + (-1, 0.5), + (0.333, 0.5), + (-0.333, 0.5), + (1, 0.5), + ], + output_coors=[ + (-1, 1), + (1, 1), # one output + ], + ), + neat=NEAT( + species=DefaultSpecies( + genome=DefaultGenome( + num_inputs=4, # [*coor1, *coor2] + num_outputs=1, # the weight of connection between two coor1 and coor2 + max_nodes=50, + max_conns=100, + node_gene=DefaultNodeGene( + activation_default=Act.tanh, + activation_options=(Act.tanh,), + ), + output_transform=Act.tanh, # the activation function for output node in NEAT + ), + pop_size=10000, + species_size=10, + compatibility_threshold=3.5, + survival_threshold=0.03, + ), + ), + activation=Act.tanh, # the activation function for output node in HyperNEAT + activate_time=10, + output_transform=jax.numpy.argmax, # action of cartpole is in {0, 1} + ), + problem=GymNaxEnv( + env_name="CartPole-v1", + ), + generation_limit=300, + fitness_target=500, + ) - algorithm = HyperNEAT(conf, NormalGene, NormalSubstrate) - pipeline = Pipeline(conf, algorithm, GymNaxEnv) + # initialize state state = pipeline.setup() - pipeline.pre_compile(state) + # print(state) + # run until terminate state, best = pipeline.auto_run(state) diff --git a/tensorneat/examples/gymnax/mountain_car.py b/tensorneat/examples/gymnax/mountain_car.py index 38c4eb0..d1ea062 100644 --- a/tensorneat/examples/gymnax/mountain_car.py +++ b/tensorneat/examples/gymnax/mountain_car.py @@ -26,7 +26,7 @@ if __name__ == "__main__": env_name="MountainCar-v0", ), generation_limit=10000, - fitness_target=0, + fitness_target=-86, ) # initialize state diff --git a/tensorneat/examples/gymnax/mountain_car_continuous.py b/tensorneat/examples/gymnax/mountain_car_continuous.py index 1edd4f1..946d364 100644 --- a/tensorneat/examples/gymnax/mountain_car_continuous.py +++ b/tensorneat/examples/gymnax/mountain_car_continuous.py @@ -17,6 +17,7 @@ if __name__ == "__main__": activation_options=(Act.tanh,), activation_default=Act.tanh, ), + output_transform=Act.tanh ), pop_size=10000, species_size=10, @@ -26,7 +27,7 @@ if __name__ == "__main__": env_name="MountainCarContinuous-v0", ), generation_limit=10000, - fitness_target=500, + fitness_target=99, ) # initialize state diff --git a/tensorneat/examples/gymnax/pendulum.py b/tensorneat/examples/gymnax/pendulum.py index 9542c37..6f8d26e 100644 --- a/tensorneat/examples/gymnax/pendulum.py +++ b/tensorneat/examples/gymnax/pendulum.py @@ -17,7 +17,7 @@ if __name__ == "__main__": activation_options=(Act.tanh,), activation_default=Act.tanh, ), - output_transform=lambda out: out + output_transform=lambda out: Act.tanh(out) * 2, # the action of pendulum is [-2, 2] ), pop_size=10000, @@ -28,7 +28,7 @@ if __name__ == "__main__": env_name="Pendulum-v1", ), generation_limit=10000, - fitness_target=0, + fitness_target=-10, ) # initialize state diff --git a/tensorneat/examples/gymnax/reacher.py b/tensorneat/examples/gymnax/reacher.py index d6b5345..489357c 100644 --- a/tensorneat/examples/gymnax/reacher.py +++ b/tensorneat/examples/gymnax/reacher.py @@ -23,7 +23,7 @@ if __name__ == "__main__": env_name="Reacher-misc", ), generation_limit=10000, - fitness_target=500, + fitness_target=90, ) # initialize state diff --git a/tensorneat/problem/rl_env/brax_env.py b/tensorneat/problem/rl_env/brax_env.py index 7a65bfe..dcac0b4 100644 --- a/tensorneat/problem/rl_env/brax_env.py +++ b/tensorneat/problem/rl_env/brax_env.py @@ -5,8 +5,8 @@ from .rl_jit import RLEnv class BraxEnv(RLEnv): - def __init__(self, env_name: str = "ant", backend: str = "generalized"): - super().__init__() + def __init__(self, max_step=1000, env_name: str = "ant", backend: str = "generalized"): + super().__init__(max_step) self.env = envs.create(env_name=env_name, backend=backend) def env_step(self, randkey, env_state, action): diff --git a/tensorneat/problem/rl_env/gymnax_env.py b/tensorneat/problem/rl_env/gymnax_env.py index 95af8fa..e32814c 100644 --- a/tensorneat/problem/rl_env/gymnax_env.py +++ b/tensorneat/problem/rl_env/gymnax_env.py @@ -4,8 +4,8 @@ from .rl_jit import RLEnv class GymNaxEnv(RLEnv): - def __init__(self, env_name): - super().__init__() + def __init__(self, env_name, max_step=1000): + super().__init__(max_step) assert env_name in gymnax.registered_envs, f"Env {env_name} not registered" self.env, self.env_params = gymnax.make(env_name)