From 3cb5fbf581298f896c410fa71a8d9670fc9f3149 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 11 Jul 2024 15:08:02 +0800 Subject: [PATCH] update hyperneat and related examples --- examples/brax/walker.py | 60 ----------------- examples/brax/walker2d.py | 51 +++++++++++++++ examples/func_fit/xor_hyperneat.py | 54 +++++----------- tensorneat/algorithm/__init__.py | 1 + tensorneat/algorithm/hyperneat/hyperneat.py | 64 +++++++++---------- .../genome/operations/mutation/default.py | 6 +- tensorneat/pipeline.py | 2 +- 7 files changed, 102 insertions(+), 136 deletions(-) delete mode 100644 examples/brax/walker.py create mode 100644 examples/brax/walker2d.py diff --git a/examples/brax/walker.py b/examples/brax/walker.py deleted file mode 100644 index 38c3b0f..0000000 --- a/examples/brax/walker.py +++ /dev/null @@ -1,60 +0,0 @@ -from pipeline import Pipeline -from algorithm.neat import * - -from problem.rl_env import BraxEnv -from tensorneat.common import Act - -import jax, jax.numpy as jnp - - -def split_right_left(randkey, forward_func, obs): - right_obs_keys = jnp.array([2, 3, 4, 11, 12, 13]) - left_obs_keys = jnp.array([5, 6, 7, 14, 15, 16]) - right_action_keys = jnp.array([0, 1, 2]) - left_action_keys = jnp.array([3, 4, 5]) - - right_foot_obs = obs - left_foot_obs = obs - left_foot_obs = left_foot_obs.at[right_obs_keys].set(obs[left_obs_keys]) - left_foot_obs = left_foot_obs.at[left_obs_keys].set(obs[right_obs_keys]) - - right_action, left_action = jax.vmap(forward_func)(jnp.stack([right_foot_obs, left_foot_obs])) - # print(right_action.shape) - # print(left_action.shape) - - return jnp.concatenate([right_action, left_action]) - - -if __name__ == "__main__": - pipeline = Pipeline( - algorithm=NEAT( - species=DefaultSpecies( - genome=DefaultGenome( - num_inputs=17, - num_outputs=3, - max_nodes=50, - max_conns=100, - node_gene=DefaultNodeGene( - activation_options=(Act.tanh,), - activation_default=Act.tanh, - ), - output_transform=Act.tanh, - ), - pop_size=1000, - 
species_size=10, - ), - ), - problem=BraxEnv( - env_name="walker2d", - max_step=1000, - action_policy=split_right_left - ), - generation_limit=10000, - fitness_target=5000, - ) - - # initialize state - state = pipeline.setup() - # print(state) - # run until terminate - state, best = pipeline.auto_run(state) diff --git a/examples/brax/walker2d.py b/examples/brax/walker2d.py new file mode 100644 index 0000000..dc8680c --- /dev/null +++ b/examples/brax/walker2d.py @@ -0,0 +1,51 @@ +from tensorneat.pipeline import Pipeline +from tensorneat.algorithm.neat import NEAT +from tensorneat.genome import DefaultGenome, BiasNode + +from tensorneat.problem.rl_env import BraxEnv +from tensorneat.common import Act, Agg + +import jax, jax.numpy as jnp + + +def random_sample_policy(randkey, obs): + return jax.random.uniform(randkey, (6,)) + + +if __name__ == "__main__": + pipeline = Pipeline( + algorithm=NEAT( + pop_size=1000, + species_size=20, + survival_threshold=0.1, + compatibility_threshold=1.0, + genome=DefaultGenome( + max_nodes=100, + max_conns=200, + num_inputs=17, + num_outputs=6, + init_hidden_layers=(), + node_gene=BiasNode( + activation_options=Act.tanh, + aggregation_options=Agg.sum, + ), + output_transform=Act.standard_tanh, + ), + ), + problem=BraxEnv( + env_name="walker2d", + max_step=1000, + obs_normalization=True, + sample_episodes=1000, + sample_policy=random_sample_policy, + ), + seed=42, + generation_limit=100, + fitness_target=5000, + ) + + # initialize state + state = pipeline.setup() + # print(state) + # run until terminate + state, best = pipeline.auto_run(state) diff --git a/examples/func_fit/xor_hyperneat.py b/examples/func_fit/xor_hyperneat.py index 905e921..05c4a63 100644 --- a/examples/func_fit/xor_hyperneat.py +++ b/examples/func_fit/xor_hyperneat.py @@ -1,53 +1,33 @@ -from pipeline import Pipeline -from algorithm.neat import * -from algorithm.hyperneat import * +from tensorneat.pipeline import Pipeline +from tensorneat.algorithm.neat import NEAT 
+from tensorneat.algorithm.hyperneat import HyperNEAT, FullSubstrate +from tensorneat.genome import DefaultGenome from tensorneat.common import Act -from problem.func_fit import XOR3d +from tensorneat.problem.func_fit import XOR3d if __name__ == "__main__": pipeline = Pipeline( algorithm=HyperNEAT( substrate=FullSubstrate( - input_coors=[(-1, -1), (0.333, -1), (-0.333, -1), (1, -1)], # 3(XOR3d inputs) + 1(bias) - hidden_coors=[ - (-1, -0.5), (0.333, -0.5), (-0.333, -0.5), - (1, -0.5), - (-1, 0), - (0.333, 0), - (-0.333, 0), - (1, 0), - (-1, 0.5), - (0.333, 0.5), - (-0.333, 0.5), - (1, 0.5), - ], - output_coors=[ - (0, 1), # one output - ], + input_coors=((-1, -1), (-0.33, -1), (0.33, -1), (1, -1)), + hidden_coors=((-1, 0), (0, 0), (1, 0)), + output_coors=((0, 1),), ), neat=NEAT( - species=DefaultSpecies( - genome=DefaultGenome( - num_inputs=4, # [*coor1, *coor2] - num_outputs=1, # the weight of connection between two coor1 and coor2 - max_nodes=50, - max_conns=100, - node_gene=DefaultNodeGene( - activation_default=Act.tanh, - activation_options=(Act.tanh,), - ), - output_transform=Act.tanh, # the activation function for output node in NEAT - ), - pop_size=1000, - species_size=10, - compatibility_threshold=2, - survival_threshold=0.03, + pop_size=10000, + species_size=20, + survival_threshold=0.01, + genome=DefaultGenome( + num_inputs=4, # size of query coors + num_outputs=1, + init_hidden_layers=(), + output_transform=Act.standard_tanh, ), ), activation=Act.tanh, activate_time=10, - output_transform=Act.sigmoid, # the activation function for output node in HyperNEAT + output_transform=Act.standard_sigmoid, ), problem=XOR3d(), generation_limit=300, diff --git a/tensorneat/algorithm/__init__.py b/tensorneat/algorithm/__init__.py index deaf74c..8326df3 100644 --- a/tensorneat/algorithm/__init__.py +++ b/tensorneat/algorithm/__init__.py @@ -1,2 +1,3 @@ from .base import BaseAlgorithm from .neat import NEAT +from .hyperneat import HyperNEAT diff --git 
a/tensorneat/algorithm/hyperneat/hyperneat.py b/tensorneat/algorithm/hyperneat/hyperneat.py index b49d396..e58eb21 100644 --- a/tensorneat/algorithm/hyperneat/hyperneat.py +++ b/tensorneat/algorithm/hyperneat/hyperneat.py @@ -1,12 +1,12 @@ from typing import Callable -import jax, jax.numpy as jnp +import jax +from jax import vmap, numpy as jnp -from tensorneat.common import State, Act, Agg -from .. import BaseAlgorithm, NEAT -from ..neat.gene import BaseNodeGene, BaseConnGene -from ..neat.genome import RecurrentGenome from .substrate import * +from tensorneat.common import State, Act, Agg +from tensorneat.algorithm import BaseAlgorithm, NEAT +from tensorneat.genome import BaseNode, BaseConn, RecurrentGenome class HyperNEAT(BaseAlgorithm): @@ -14,64 +14,65 @@ class HyperNEAT(BaseAlgorithm): self, substrate: BaseSubstrate, neat: NEAT, - below_threshold: float = 0.3, + weight_threshold: float = 0.3, max_weight: float = 5.0, - aggregation=Agg.sum, - activation=Act.sigmoid, + aggregation: Callable = Agg.sum, + activation: Callable = Act.sigmoid, activate_time: int = 10, - output_transform: Callable = Act.sigmoid, + output_transform: Callable = Act.standard_sigmoid, ): assert ( substrate.query_coors.shape[1] == neat.num_inputs - ), "Substrate input size should be equal to NEAT input size" - + ), "Query coors of Substrate should be equal to NEAT input size" + self.substrate = substrate self.neat = neat - self.below_threshold = below_threshold + self.weight_threshold = weight_threshold self.max_weight = max_weight self.hyper_genome = RecurrentGenome( num_inputs=substrate.num_inputs, num_outputs=substrate.num_outputs, max_nodes=substrate.nodes_cnt, max_conns=substrate.conns_cnt, - node_gene=HyperNodeGene(aggregation, activation), - conn_gene=HyperNEATConnGene(), + node_gene=HyperNEATNode(aggregation, activation), + conn_gene=HyperNEATConn(), activate_time=activate_time, output_transform=output_transform, ) + self.pop_size = neat.pop_size def setup(self, state=State()): 
state = self.neat.setup(state) state = self.substrate.setup(state) return self.hyper_genome.setup(state) - def ask(self, state: State): + def ask(self, state): return self.neat.ask(state) - def tell(self, state: State, fitness): + def tell(self, state, fitness): state = self.neat.tell(state, fitness) return state def transform(self, state, individual): transformed = self.neat.transform(state, individual) - query_res = jax.vmap(self.neat.forward, in_axes=(None, None, 0))( + query_res = vmap(self.neat.forward, in_axes=(None, None, 0))( state, transformed, self.substrate.query_coors ) - # mute the connection with weight below threshold + # mute the connection with weight below weight_threshold query_res = jnp.where( - (-self.below_threshold < query_res) & (query_res < self.below_threshold), + (-self.weight_threshold < query_res) & (query_res < self.weight_threshold), 0.0, query_res, ) # make query res in range [-max_weight, max_weight] query_res = jnp.where( - query_res > 0, query_res - self.below_threshold, query_res + query_res > 0, query_res - self.weight_threshold, query_res ) query_res = jnp.where( - query_res < 0, query_res + self.below_threshold, query_res + query_res < 0, query_res + self.weight_threshold, query_res ) - query_res = query_res / (1 - self.below_threshold) * self.max_weight + query_res = query_res / (1 - self.weight_threshold) * self.max_weight h_nodes, h_conns = self.substrate.make_nodes( query_res @@ -79,11 +80,11 @@ class HyperNEAT(BaseAlgorithm): return self.hyper_genome.transform(state, h_nodes, h_conns) - def forward(self, state, inputs, transformed): + def forward(self, state, transformed, inputs): # add bias inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])]) - res = self.hyper_genome.forward(state, inputs_with_bias, transformed) + res = self.hyper_genome.forward(state, transformed, inputs_with_bias) return res @property @@ -94,18 +95,11 @@ class HyperNEAT(BaseAlgorithm): def num_outputs(self): return self.substrate.num_outputs -
@property - def pop_size(self): - return self.neat.pop_size - - def member_count(self, state: State): - return self.neat.member_count(state) - - def generation(self, state: State): - return self.neat.generation(state) + def show_details(self, state, fitness): + return self.neat.show_details(state, fitness) -class HyperNodeGene(BaseNodeGene): +class HyperNEATNode(BaseNode): def __init__( self, aggregation=Agg.sum, @@ -123,7 +117,7 @@ class HyperNodeGene(BaseNodeGene): ) -class HyperNEATConnGene(BaseConnGene): +class HyperNEATConn(BaseConn): custom_attrs = ["weight"] def forward(self, state, attrs, inputs): diff --git a/tensorneat/genome/operations/mutation/default.py b/tensorneat/genome/operations/mutation/default.py index b369584..7a0f2b5 100644 --- a/tensorneat/genome/operations/mutation/default.py +++ b/tensorneat/genome/operations/mutation/default.py @@ -23,10 +23,10 @@ from ...utils import ( class DefaultMutation(BaseMutation): def __init__( self, - conn_add: float = 0.1, - conn_delete: float = 0, + conn_add: float = 0.2, + conn_delete: float = 0.2, node_add: float = 0.1, - node_delete: float = 0, + node_delete: float = 0.1, ): self.conn_add = conn_add self.conn_delete = conn_delete diff --git a/tensorneat/pipeline.py b/tensorneat/pipeline.py index 3de5e9f..7f13b62 100644 --- a/tensorneat/pipeline.py +++ b/tensorneat/pipeline.py @@ -34,7 +34,7 @@ class Pipeline(StatefulBaseClass): self.generation_limit = generation_limit self.pop_size = self.algorithm.pop_size - # print(self.problem.input_shape, self.problem.output_shape) + np.random.seed(self.seed) # TODO: make each algorithm's input_num and output_num assert (