From ee8ec842025be0bc2bff2eee3674f0e34c4528a5 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 10 Jul 2024 11:24:11 +0800 Subject: [PATCH] odify genome for the official release --- README.md | 2 +- {tensorneat/examples => examples}/brax/ant.py | 2 +- .../brax/half_cheetah.py | 2 +- .../examples => examples}/brax/reacher.py | 2 +- .../examples => examples}/brax/show_test.py | 0 .../examples => examples}/brax/walker.py | 2 +- .../examples => examples}/func_fit/xor.py | 2 +- .../func_fit/xor3d_hyperneat.py | 2 +- .../func_fit/xor_recurrent.py | 0 .../examples => examples}/gymnax/arcbot.py | 0 .../examples => examples}/gymnax/cartpole.py | 0 .../gymnax/cartpole_hyperneat.py | 2 +- .../gymnax/mountain_car.py | 0 .../gymnax/mountain_car_continuous.py | 2 +- .../examples => examples}/gymnax/pendulum.py | 2 +- .../examples => examples}/gymnax/reacher.py | 0 .../interpret_visualize/genome_sympy.ipynb | 2 +- .../interpret_visualize/genome_sympy.py | 2 +- .../interpret_visualize/graph.svg | 0 .../interpret_visualize/network.json | 0 .../interpret_visualize/network.svg | 0 .../visualize_genome.ipynb | 2 +- .../interpret_visualize/visualize_genome.py | 0 .../jumanji/2048_random_policy.py | 0 .../jumanji/2048_test.ipynb | 2 +- .../jumanji/train_2048.py | 2 +- examples/tmp.py | 10 + .../with_evox/ray_test.py | 0 tensorneat/algorithm/base.py | 2 +- tensorneat/algorithm/hyperneat/hyperneat.py | 2 +- .../algorithm/hyperneat/substrate/base.py | 2 +- tensorneat/algorithm/neat/__init__.py | 1 - .../algorithm/neat/ga/crossover/base.py | 6 - tensorneat/algorithm/neat/ga/mutation/base.py | 6 - tensorneat/algorithm/neat/gene/base.py | 2 +- .../algorithm/neat/gene/conn/default.py | 2 +- .../algorithm/neat/gene/node/default.py | 2 +- .../gene/node/default_without_response.py | 2 +- .../algorithm/neat/gene/node/kan_node.py | 2 +- .../algorithm/neat/gene/node/min_max_node.py | 2 +- .../algorithm/neat/gene/node/normalized.py | 2 +- tensorneat/algorithm/neat/genome/__init__.py | 3 +- 
tensorneat/algorithm/neat/genome/base.py | 300 ++++++++---------- tensorneat/algorithm/neat/genome/default.py | 233 +++++++------- tensorneat/algorithm/neat/genome/dense.py | 56 ---- tensorneat/algorithm/neat/genome/hidden.py | 70 ---- .../{ga => genome/operations}/__init__.py | 1 + .../operations}/crossover/__init__.py | 0 .../neat/genome/operations/crossover/base.py | 12 + .../operations}/crossover/default.py | 27 +- .../genome/operations/distance/__init__.py | 2 + .../neat/genome/operations/distance/base.py | 15 + .../genome/operations/distance/default.py | 105 ++++++ .../operations}/mutation/__init__.py | 0 .../neat/genome/operations/mutation/base.py | 12 + .../operations}/mutation/default.py | 21 +- tensorneat/algorithm/neat/genome/recurrent.py | 20 +- tensorneat/algorithm/neat/genome/utils.py | 109 +++++++ tensorneat/algorithm/neat/neat.py | 2 +- tensorneat/algorithm/neat/species/base.py | 2 +- tensorneat/algorithm/neat/species/default.py | 10 +- tensorneat/{utils => common}/__init__.py | 2 +- .../{utils => common}/activation/__init__.py | 0 .../{utils => common}/activation/act_jnp.py | 0 .../{utils => common}/activation/act_sympy.py | 0 .../{utils => common}/aggregation/__init__.py | 0 .../{utils => common}/aggregation/agg_jnp.py | 0 .../aggregation/agg_sympy.py | 0 tensorneat/{utils => common}/graph.py | 0 tensorneat/{utils => common}/state.py | 0 .../{utils => common}/stateful_class.py | 0 tensorneat/{utils => common}/tools.py | 102 ------ tensorneat/pipeline.py | 2 +- tensorneat/problem/base.py | 2 +- tensorneat/problem/func_fit/func_fit.py | 2 +- .../problem/rl_env/jumanji/jumanji_2048.py | 2 +- tensorneat/problem/rl_env/rl_jit.py | 2 +- tensorneat/test/crossover_mutation.py | 2 +- tensorneat/test/nan_fitness.py | 2 +- tensorneat/test/test_kan.ipynb | 2 +- tensorneat/test/test_nan_fitness.py | 2 +- tensorneat/test/test_record_episode.ipynb | 2 +- tensorneat/test/test_update_by_batch.ipynb | 2 +- 83 files changed, 588 insertions(+), 611 deletions(-) 
rename {tensorneat/examples => examples}/brax/ant.py (96%) rename {tensorneat/examples => examples}/brax/half_cheetah.py (97%) rename {tensorneat/examples => examples}/brax/reacher.py (96%) rename {tensorneat/examples => examples}/brax/show_test.py (100%) rename {tensorneat/examples => examples}/brax/walker.py (98%) rename {tensorneat/examples => examples}/func_fit/xor.py (96%) rename {tensorneat/examples => examples}/func_fit/xor3d_hyperneat.py (98%) rename {tensorneat/examples => examples}/func_fit/xor_recurrent.py (100%) rename {tensorneat/examples => examples}/gymnax/arcbot.py (100%) rename {tensorneat/examples => examples}/gymnax/cartpole.py (100%) rename {tensorneat/examples => examples}/gymnax/cartpole_hyperneat.py (98%) rename {tensorneat/examples => examples}/gymnax/mountain_car.py (100%) rename {tensorneat/examples => examples}/gymnax/mountain_car_continuous.py (96%) rename {tensorneat/examples => examples}/gymnax/pendulum.py (96%) rename {tensorneat/examples => examples}/gymnax/reacher.py (100%) rename {tensorneat/examples => examples}/interpret_visualize/genome_sympy.ipynb (99%) rename {tensorneat/examples => examples}/interpret_visualize/genome_sympy.py (96%) rename {tensorneat/examples => examples}/interpret_visualize/graph.svg (100%) rename {tensorneat/examples => examples}/interpret_visualize/network.json (100%) rename {tensorneat/examples => examples}/interpret_visualize/network.svg (100%) rename {tensorneat/examples => examples}/interpret_visualize/visualize_genome.ipynb (99%) rename {tensorneat/examples => examples}/interpret_visualize/visualize_genome.py (100%) rename {tensorneat/examples => examples}/jumanji/2048_random_policy.py (100%) rename {tensorneat/examples => examples}/jumanji/2048_test.ipynb (99%) rename {tensorneat/examples => examples}/jumanji/train_2048.py (98%) create mode 100644 examples/tmp.py rename {tensorneat/examples => examples}/with_evox/ray_test.py (100%) delete mode 100644 tensorneat/algorithm/neat/ga/crossover/base.py 
delete mode 100644 tensorneat/algorithm/neat/ga/mutation/base.py delete mode 100644 tensorneat/algorithm/neat/genome/dense.py delete mode 100644 tensorneat/algorithm/neat/genome/hidden.py rename tensorneat/algorithm/neat/{ga => genome/operations}/__init__.py (67%) rename tensorneat/algorithm/neat/{ga => genome/operations}/crossover/__init__.py (100%) create mode 100644 tensorneat/algorithm/neat/genome/operations/crossover/base.py rename tensorneat/algorithm/neat/{ga => genome/operations}/crossover/default.py (76%) create mode 100644 tensorneat/algorithm/neat/genome/operations/distance/__init__.py create mode 100644 tensorneat/algorithm/neat/genome/operations/distance/base.py create mode 100644 tensorneat/algorithm/neat/genome/operations/distance/default.py rename tensorneat/algorithm/neat/{ga => genome/operations}/mutation/__init__.py (100%) create mode 100644 tensorneat/algorithm/neat/genome/operations/mutation/base.py rename tensorneat/algorithm/neat/{ga => genome/operations}/mutation/default.py (94%) create mode 100644 tensorneat/algorithm/neat/genome/utils.py rename tensorneat/{utils => common}/__init__.py (95%) rename tensorneat/{utils => common}/activation/__init__.py (100%) rename tensorneat/{utils => common}/activation/act_jnp.py (100%) rename tensorneat/{utils => common}/activation/act_sympy.py (100%) rename tensorneat/{utils => common}/aggregation/__init__.py (100%) rename tensorneat/{utils => common}/aggregation/agg_jnp.py (100%) rename tensorneat/{utils => common}/aggregation/agg_sympy.py (100%) rename tensorneat/{utils => common}/graph.py (100%) rename tensorneat/{utils => common}/state.py (100%) rename tensorneat/{utils => common}/stateful_class.py (100%) rename tensorneat/{utils => common}/tools.py (54%) diff --git a/README.md b/README.md index 3e527e2..07ee236 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,7 @@ from pipeline import Pipeline from algorithm.neat import * from problem.rl_env import BraxEnv -from utils import Act +from 
tensorneat.utils import Act if __name__ == '__main__': pipeline = Pipeline( diff --git a/tensorneat/examples/brax/ant.py b/examples/brax/ant.py similarity index 96% rename from tensorneat/examples/brax/ant.py rename to examples/brax/ant.py index 3edddd0..9b45e8c 100644 --- a/tensorneat/examples/brax/ant.py +++ b/examples/brax/ant.py @@ -2,7 +2,7 @@ from pipeline import Pipeline from algorithm.neat import * from problem.rl_env import BraxEnv -from utils import Act +from tensorneat.common import Act if __name__ == "__main__": pipeline = Pipeline( diff --git a/tensorneat/examples/brax/half_cheetah.py b/examples/brax/half_cheetah.py similarity index 97% rename from tensorneat/examples/brax/half_cheetah.py rename to examples/brax/half_cheetah.py index c23a941..330ed97 100644 --- a/tensorneat/examples/brax/half_cheetah.py +++ b/examples/brax/half_cheetah.py @@ -4,7 +4,7 @@ from pipeline import Pipeline from algorithm.neat import * from problem.rl_env import BraxEnv -from utils import Act +from tensorneat.common import Act def sample_policy(randkey, obs): diff --git a/tensorneat/examples/brax/reacher.py b/examples/brax/reacher.py similarity index 96% rename from tensorneat/examples/brax/reacher.py rename to examples/brax/reacher.py index 41d57c2..bb331e5 100644 --- a/tensorneat/examples/brax/reacher.py +++ b/examples/brax/reacher.py @@ -2,7 +2,7 @@ from pipeline import Pipeline from algorithm.neat import * from problem.rl_env import BraxEnv -from utils import Act +from tensorneat.common import Act if __name__ == "__main__": pipeline = Pipeline( diff --git a/tensorneat/examples/brax/show_test.py b/examples/brax/show_test.py similarity index 100% rename from tensorneat/examples/brax/show_test.py rename to examples/brax/show_test.py diff --git a/tensorneat/examples/brax/walker.py b/examples/brax/walker.py similarity index 98% rename from tensorneat/examples/brax/walker.py rename to examples/brax/walker.py index 3699593..38c3b0f 100644 --- a/tensorneat/examples/brax/walker.py 
+++ b/examples/brax/walker.py @@ -2,7 +2,7 @@ from pipeline import Pipeline from algorithm.neat import * from problem.rl_env import BraxEnv -from utils import Act +from tensorneat.common import Act import jax, jax.numpy as jnp diff --git a/tensorneat/examples/func_fit/xor.py b/examples/func_fit/xor.py similarity index 96% rename from tensorneat/examples/func_fit/xor.py rename to examples/func_fit/xor.py index e93326f..1ab61b8 100644 --- a/tensorneat/examples/func_fit/xor.py +++ b/examples/func_fit/xor.py @@ -2,7 +2,7 @@ from pipeline import Pipeline from algorithm.neat import * from problem.func_fit import XOR3d -from utils import ACT_ALL, AGG_ALL, Act, Agg +from tensorneat.common import ACT_ALL, AGG_ALL, Act, Agg if __name__ == "__main__": pipeline = Pipeline( diff --git a/tensorneat/examples/func_fit/xor3d_hyperneat.py b/examples/func_fit/xor3d_hyperneat.py similarity index 98% rename from tensorneat/examples/func_fit/xor3d_hyperneat.py rename to examples/func_fit/xor3d_hyperneat.py index 500e716..905e921 100644 --- a/tensorneat/examples/func_fit/xor3d_hyperneat.py +++ b/examples/func_fit/xor3d_hyperneat.py @@ -1,7 +1,7 @@ from pipeline import Pipeline from algorithm.neat import * from algorithm.hyperneat import * -from utils import Act +from tensorneat.common import Act from problem.func_fit import XOR3d diff --git a/tensorneat/examples/func_fit/xor_recurrent.py b/examples/func_fit/xor_recurrent.py similarity index 100% rename from tensorneat/examples/func_fit/xor_recurrent.py rename to examples/func_fit/xor_recurrent.py diff --git a/tensorneat/examples/gymnax/arcbot.py b/examples/gymnax/arcbot.py similarity index 100% rename from tensorneat/examples/gymnax/arcbot.py rename to examples/gymnax/arcbot.py diff --git a/tensorneat/examples/gymnax/cartpole.py b/examples/gymnax/cartpole.py similarity index 100% rename from tensorneat/examples/gymnax/cartpole.py rename to examples/gymnax/cartpole.py diff --git a/tensorneat/examples/gymnax/cartpole_hyperneat.py 
b/examples/gymnax/cartpole_hyperneat.py similarity index 98% rename from tensorneat/examples/gymnax/cartpole_hyperneat.py rename to examples/gymnax/cartpole_hyperneat.py index 4302d5f..30c92e2 100644 --- a/tensorneat/examples/gymnax/cartpole_hyperneat.py +++ b/examples/gymnax/cartpole_hyperneat.py @@ -3,7 +3,7 @@ import jax from pipeline import Pipeline from algorithm.neat import * from algorithm.hyperneat import * -from utils import Act +from tensorneat.common import Act from problem.rl_env import GymNaxEnv diff --git a/tensorneat/examples/gymnax/mountain_car.py b/examples/gymnax/mountain_car.py similarity index 100% rename from tensorneat/examples/gymnax/mountain_car.py rename to examples/gymnax/mountain_car.py diff --git a/tensorneat/examples/gymnax/mountain_car_continuous.py b/examples/gymnax/mountain_car_continuous.py similarity index 96% rename from tensorneat/examples/gymnax/mountain_car_continuous.py rename to examples/gymnax/mountain_car_continuous.py index 946d364..5420123 100644 --- a/tensorneat/examples/gymnax/mountain_car_continuous.py +++ b/examples/gymnax/mountain_car_continuous.py @@ -2,7 +2,7 @@ from pipeline import Pipeline from algorithm.neat import * from problem.rl_env import GymNaxEnv -from utils import Act +from tensorneat.common import Act if __name__ == "__main__": pipeline = Pipeline( diff --git a/tensorneat/examples/gymnax/pendulum.py b/examples/gymnax/pendulum.py similarity index 96% rename from tensorneat/examples/gymnax/pendulum.py rename to examples/gymnax/pendulum.py index 6f8d26e..d370394 100644 --- a/tensorneat/examples/gymnax/pendulum.py +++ b/examples/gymnax/pendulum.py @@ -2,7 +2,7 @@ from pipeline import Pipeline from algorithm.neat import * from problem.rl_env import GymNaxEnv -from utils import Act +from tensorneat.common import Act if __name__ == "__main__": pipeline = Pipeline( diff --git a/tensorneat/examples/gymnax/reacher.py b/examples/gymnax/reacher.py similarity index 100% rename from 
tensorneat/examples/gymnax/reacher.py rename to examples/gymnax/reacher.py diff --git a/tensorneat/examples/interpret_visualize/genome_sympy.ipynb b/examples/interpret_visualize/genome_sympy.ipynb similarity index 99% rename from tensorneat/examples/interpret_visualize/genome_sympy.ipynb rename to examples/interpret_visualize/genome_sympy.ipynb index 1c93aa0..2a7db4a 100644 --- a/tensorneat/examples/interpret_visualize/genome_sympy.ipynb +++ b/examples/interpret_visualize/genome_sympy.ipynb @@ -11,7 +11,7 @@ "from algorithm.neat.genome.advance import AdvanceInitialize\n", "from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n", "from utils.graph import topological_sort_python\n", - "from utils import Act, Agg\n", + "from tensorneat.utils import Act, Agg\n", "\n", "import numpy as np" ], diff --git a/tensorneat/examples/interpret_visualize/genome_sympy.py b/examples/interpret_visualize/genome_sympy.py similarity index 96% rename from tensorneat/examples/interpret_visualize/genome_sympy.py rename to examples/interpret_visualize/genome_sympy.py index 0b8363c..73d0c7f 100644 --- a/tensorneat/examples/interpret_visualize/genome_sympy.py +++ b/examples/interpret_visualize/genome_sympy.py @@ -3,7 +3,7 @@ import jax, jax.numpy as jnp from algorithm.neat import * from algorithm.neat.genome.dense import DenseInitialize from utils.graph import topological_sort_python -from utils import * +from tensorneat.common import * if __name__ == "__main__": genome = DenseInitialize( diff --git a/tensorneat/examples/interpret_visualize/graph.svg b/examples/interpret_visualize/graph.svg similarity index 100% rename from tensorneat/examples/interpret_visualize/graph.svg rename to examples/interpret_visualize/graph.svg diff --git a/tensorneat/examples/interpret_visualize/network.json b/examples/interpret_visualize/network.json similarity index 100% rename from tensorneat/examples/interpret_visualize/network.json rename to 
examples/interpret_visualize/network.json diff --git a/tensorneat/examples/interpret_visualize/network.svg b/examples/interpret_visualize/network.svg similarity index 100% rename from tensorneat/examples/interpret_visualize/network.svg rename to examples/interpret_visualize/network.svg diff --git a/tensorneat/examples/interpret_visualize/visualize_genome.ipynb b/examples/interpret_visualize/visualize_genome.ipynb similarity index 99% rename from tensorneat/examples/interpret_visualize/visualize_genome.ipynb rename to examples/interpret_visualize/visualize_genome.ipynb index e9aaa33..0e9573c 100644 --- a/tensorneat/examples/interpret_visualize/visualize_genome.ipynb +++ b/examples/interpret_visualize/visualize_genome.ipynb @@ -19,7 +19,7 @@ "from algorithm.neat.genome.advance import AdvanceInitialize\n", "from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n", "from utils.graph import topological_sort_python\n", - "from utils import Act, Agg\n", + "from tensorneat.utils import Act, Agg\n", "\n", "genome = AdvanceInitialize(\n", " num_inputs=16,\n", diff --git a/tensorneat/examples/interpret_visualize/visualize_genome.py b/examples/interpret_visualize/visualize_genome.py similarity index 100% rename from tensorneat/examples/interpret_visualize/visualize_genome.py rename to examples/interpret_visualize/visualize_genome.py diff --git a/tensorneat/examples/jumanji/2048_random_policy.py b/examples/jumanji/2048_random_policy.py similarity index 100% rename from tensorneat/examples/jumanji/2048_random_policy.py rename to examples/jumanji/2048_random_policy.py diff --git a/tensorneat/examples/jumanji/2048_test.ipynb b/examples/jumanji/2048_test.ipynb similarity index 99% rename from tensorneat/examples/jumanji/2048_test.ipynb rename to examples/jumanji/2048_test.ipynb index aec52d5..d6b22a3 100644 --- a/tensorneat/examples/jumanji/2048_test.ipynb +++ b/examples/jumanji/2048_test.ipynb @@ -29,7 +29,7 @@ "from 
algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n", "\n", "from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048\n", - "from utils import Act, Agg\n", + "from tensorneat.utils import Act, Agg\n", "\n", "pipeline = Pipeline(\n", " algorithm=NEAT(\n", diff --git a/tensorneat/examples/jumanji/train_2048.py b/examples/jumanji/train_2048.py similarity index 98% rename from tensorneat/examples/jumanji/train_2048.py rename to examples/jumanji/train_2048.py index cca287d..e2efeb4 100644 --- a/tensorneat/examples/jumanji/train_2048.py +++ b/examples/jumanji/train_2048.py @@ -4,7 +4,7 @@ from pipeline import Pipeline from algorithm.neat import * from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048 -from utils import Act, Agg +from tensorneat.common import Act, Agg def rot_li(li): diff --git a/examples/tmp.py b/examples/tmp.py new file mode 100644 index 0000000..7033483 --- /dev/null +++ b/examples/tmp.py @@ -0,0 +1,10 @@ +import jax, jax.numpy as jnp + +from tensorneat.algorithm import NEAT +from tensorneat.algorithm.neat import DefaultGenome + +key = jax.random.key(0) +genome = DefaultGenome(num_inputs=5, num_outputs=3, init_hidden_layers=(1, )) +state = genome.setup() +nodes, conns = genome.initialize(state, key) +print(genome.repr(state, nodes, conns)) diff --git a/tensorneat/examples/with_evox/ray_test.py b/examples/with_evox/ray_test.py similarity index 100% rename from tensorneat/examples/with_evox/ray_test.py rename to examples/with_evox/ray_test.py diff --git a/tensorneat/algorithm/base.py b/tensorneat/algorithm/base.py index 7493be7..557493f 100644 --- a/tensorneat/algorithm/base.py +++ b/tensorneat/algorithm/base.py @@ -1,4 +1,4 @@ -from utils import State, StatefulBaseClass +from tensorneat.common import State, StatefulBaseClass class BaseAlgorithm(StatefulBaseClass): diff --git a/tensorneat/algorithm/hyperneat/hyperneat.py 
b/tensorneat/algorithm/hyperneat/hyperneat.py index dce9cab..b49d396 100644 --- a/tensorneat/algorithm/hyperneat/hyperneat.py +++ b/tensorneat/algorithm/hyperneat/hyperneat.py @@ -2,7 +2,7 @@ from typing import Callable import jax, jax.numpy as jnp -from utils import State, Act, Agg +from tensorneat.common import State, Act, Agg from .. import BaseAlgorithm, NEAT from ..neat.gene import BaseNodeGene, BaseConnGene from ..neat.genome import RecurrentGenome diff --git a/tensorneat/algorithm/hyperneat/substrate/base.py b/tensorneat/algorithm/hyperneat/substrate/base.py index 4a00925..4f2a074 100644 --- a/tensorneat/algorithm/hyperneat/substrate/base.py +++ b/tensorneat/algorithm/hyperneat/substrate/base.py @@ -1,4 +1,4 @@ -from utils import StatefulBaseClass +from tensorneat.common import StatefulBaseClass class BaseSubstrate(StatefulBaseClass): diff --git a/tensorneat/algorithm/neat/__init__.py b/tensorneat/algorithm/neat/__init__.py index 1338b5b..06f14d2 100644 --- a/tensorneat/algorithm/neat/__init__.py +++ b/tensorneat/algorithm/neat/__init__.py @@ -1,4 +1,3 @@ -from .ga import * from .gene import * from .genome import * from .species import * diff --git a/tensorneat/algorithm/neat/ga/crossover/base.py b/tensorneat/algorithm/neat/ga/crossover/base.py deleted file mode 100644 index d34921a..0000000 --- a/tensorneat/algorithm/neat/ga/crossover/base.py +++ /dev/null @@ -1,6 +0,0 @@ -from utils import StatefulBaseClass - - -class BaseCrossover(StatefulBaseClass): - def __call__(self, state, randkey, genome, nodes1, nodes2, conns1, conns2): - raise NotImplementedError diff --git a/tensorneat/algorithm/neat/ga/mutation/base.py b/tensorneat/algorithm/neat/ga/mutation/base.py deleted file mode 100644 index fa5cf73..0000000 --- a/tensorneat/algorithm/neat/ga/mutation/base.py +++ /dev/null @@ -1,6 +0,0 @@ -from utils import StatefulBaseClass - - -class BaseMutation(StatefulBaseClass): - def __call__(self, state, randkey, genome, nodes, conns, new_node_key): - raise 
NotImplementedError diff --git a/tensorneat/algorithm/neat/gene/base.py b/tensorneat/algorithm/neat/gene/base.py index c1d89a5..0625c88 100644 --- a/tensorneat/algorithm/neat/gene/base.py +++ b/tensorneat/algorithm/neat/gene/base.py @@ -1,5 +1,5 @@ import jax, jax.numpy as jnp -from utils import State, StatefulBaseClass, hash_array +from tensorneat.common import State, StatefulBaseClass, hash_array class BaseGene(StatefulBaseClass): diff --git a/tensorneat/algorithm/neat/gene/conn/default.py b/tensorneat/algorithm/neat/gene/conn/default.py index 263b9ac..21dbb55 100644 --- a/tensorneat/algorithm/neat/gene/conn/default.py +++ b/tensorneat/algorithm/neat/gene/conn/default.py @@ -2,7 +2,7 @@ import jax.numpy as jnp import jax.random import numpy as np import sympy as sp -from utils import mutate_float +from tensorneat.common import mutate_float from . import BaseConnGene diff --git a/tensorneat/algorithm/neat/gene/node/default.py b/tensorneat/algorithm/neat/gene/node/default.py index ba5fee9..25b4193 100644 --- a/tensorneat/algorithm/neat/gene/node/default.py +++ b/tensorneat/algorithm/neat/gene/node/default.py @@ -4,7 +4,7 @@ import numpy as np import jax, jax.numpy as jnp import sympy as sp -from utils import ( +from tensorneat.common import ( Act, Agg, act_func, diff --git a/tensorneat/algorithm/neat/gene/node/default_without_response.py b/tensorneat/algorithm/neat/gene/node/default_without_response.py index db2eba6..b7f9a0b 100644 --- a/tensorneat/algorithm/neat/gene/node/default_without_response.py +++ b/tensorneat/algorithm/neat/gene/node/default_without_response.py @@ -3,7 +3,7 @@ from typing import Tuple import jax, jax.numpy as jnp import numpy as np import sympy as sp -from utils import ( +from tensorneat.common import ( Act, Agg, act_func, diff --git a/tensorneat/algorithm/neat/gene/node/kan_node.py b/tensorneat/algorithm/neat/gene/node/kan_node.py index 300a6a2..298f888 100644 --- a/tensorneat/algorithm/neat/gene/node/kan_node.py +++ 
b/tensorneat/algorithm/neat/gene/node/kan_node.py @@ -1,6 +1,6 @@ import jax.numpy as jnp from . import BaseNodeGene -from utils import Agg +from tensorneat.common import Agg class KANNode(BaseNodeGene): diff --git a/tensorneat/algorithm/neat/gene/node/min_max_node.py b/tensorneat/algorithm/neat/gene/node/min_max_node.py index caf270c..9560f6c 100644 --- a/tensorneat/algorithm/neat/gene/node/min_max_node.py +++ b/tensorneat/algorithm/neat/gene/node/min_max_node.py @@ -2,7 +2,7 @@ from typing import Tuple import jax, jax.numpy as jnp -from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float +from tensorneat.common import Act, Agg, act_func, agg_func, mutate_int, mutate_float from . import BaseNodeGene diff --git a/tensorneat/algorithm/neat/gene/node/normalized.py b/tensorneat/algorithm/neat/gene/node/normalized.py index 717aeb8..342a14c 100644 --- a/tensorneat/algorithm/neat/gene/node/normalized.py +++ b/tensorneat/algorithm/neat/gene/node/normalized.py @@ -2,7 +2,7 @@ from typing import Tuple import jax, jax.numpy as jnp -from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float +from tensorneat.common import Act, Agg, act_func, agg_func, mutate_int, mutate_float from . 
import BaseNodeGene diff --git a/tensorneat/algorithm/neat/genome/__init__.py b/tensorneat/algorithm/neat/genome/__init__.py index 3f45f6c..5c10584 100644 --- a/tensorneat/algorithm/neat/genome/__init__.py +++ b/tensorneat/algorithm/neat/genome/__init__.py @@ -1,5 +1,4 @@ from .base import BaseGenome from .default import DefaultGenome from .recurrent import RecurrentGenome -from .hidden import HiddenInitialize -from .dense import DenseInitialize + diff --git a/tensorneat/algorithm/neat/genome/base.py b/tensorneat/algorithm/neat/genome/base.py index a73122d..e892f3f 100644 --- a/tensorneat/algorithm/neat/genome/base.py +++ b/tensorneat/algorithm/neat/genome/base.py @@ -1,8 +1,16 @@ +from typing import Callable, Sequence + import numpy as np -import jax, jax.numpy as jnp +import jax +from jax import vmap, numpy as jnp from ..gene import BaseNodeGene, BaseConnGene -from ..ga import BaseMutation, BaseCrossover -from utils import State, StatefulBaseClass, topological_sort_python, hash_array +from .operations import BaseMutation, BaseCrossover, BaseDistance +from tensorneat.common import ( + State, + StatefulBaseClass, + hash_array, +) +from .utils import valid_cnt class BaseGenome(StatefulBaseClass): @@ -18,120 +26,159 @@ class BaseGenome(StatefulBaseClass): conn_gene: BaseConnGene, mutation: BaseMutation, crossover: BaseCrossover, + distance: BaseDistance, + output_transform: Callable = None, + input_transform: Callable = None, + init_hidden_layers: Sequence[int] = (), ): + + # check transform functions + if input_transform is not None: + try: + _ = input_transform(jnp.zeros(num_inputs)) + except Exception as e: + raise ValueError(f"Output transform function failed: {e}") + + if output_transform is not None: + try: + _ = output_transform(jnp.zeros(num_outputs)) + except Exception as e: + raise ValueError(f"Output transform function failed: {e}") + + # prepare for initialization + all_layers = [num_inputs] + list(init_hidden_layers) + [num_outputs] + layer_indices = [] 
+ next_index = 0 + for layer in all_layers: + layer_indices.append(list(range(next_index, next_index + layer))) + next_index += layer + + all_init_nodes = [] + all_init_conns_in_idx = [] + all_init_conns_out_idx = [] + for i in range(len(layer_indices) - 1): + in_layer = layer_indices[i] + out_layer = layer_indices[i + 1] + for in_idx in in_layer: + for out_idx in out_layer: + all_init_conns_in_idx.append(in_idx) + all_init_conns_out_idx.append(out_idx) + all_init_nodes.extend(in_layer) + + if max_nodes < len(all_init_nodes): + raise ValueError( + f"max_nodes={max_nodes} must be greater than or equal to the number of initial nodes={len(all_init_nodes)}" + ) + + if max_conns < len(all_init_conns_in_idx): + raise ValueError( + f"max_conns={max_conns} must be greater than or equal to the number of initial connections={len(all_init_conns_in_idx)}" + ) + self.num_inputs = num_inputs self.num_outputs = num_outputs - self.input_idx = np.arange(num_inputs) - self.output_idx = np.arange(num_inputs, num_inputs + num_outputs) self.max_nodes = max_nodes self.max_conns = max_conns self.node_gene = node_gene self.conn_gene = conn_gene self.mutation = mutation self.crossover = crossover + self.distance = distance + self.output_transform = output_transform + self.input_transform = input_transform + + self.input_idx = np.array(layer_indices[0]) + self.output_idx = np.array(layer_indices[-1]) + self.all_init_nodes = np.array(all_init_nodes) + self.all_init_conns = np.c_[all_init_conns_in_idx, all_init_conns_out_idx] def setup(self, state=State()): state = self.node_gene.setup(state) state = self.conn_gene.setup(state) - state = self.mutation.setup(state) - state = self.crossover.setup(state) + state = self.mutation.setup(state, self) + state = self.crossover.setup(state, self) + state = self.distance.setup(state, self) return state def transform(self, state, nodes, conns): raise NotImplementedError - def restore(self, state, transformed): - raise NotImplementedError - def 
forward(self, state, transformed, inputs): raise NotImplementedError + def sympy_func(self): + raise NotImplementedError + + def visualize(self): + raise NotImplementedError + def execute_mutation(self, state, randkey, nodes, conns, new_node_key): - return self.mutation(state, randkey, self, nodes, conns, new_node_key) + return self.mutation(state, randkey, nodes, conns, new_node_key) def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2): - return self.crossover(state, randkey, self, nodes1, conns1, nodes2, conns2) + return self.crossover(state, randkey, nodes1, conns1, nodes2, conns2) + + def execute_distance(self, state, nodes1, conns1, nodes2, conns2): + return self.distance(state, nodes1, conns1, nodes2, conns2) def initialize(self, state, randkey): - """ - Default initialization method for the genome. - Add an extra hidden node. - Make all input nodes and output nodes connected to the hidden node. - All attributes will be initialized randomly using gene.new_random_attrs method. 
- - For example, a network with 2 inputs and 1 output, the structure will be: - nodes: - [ - [0, attrs0], # input node 0 - [1, attrs1], # input node 1 - [2, attrs2], # output node 0 - [3, attrs3], # hidden node - [NaN, NaN], # empty node - ] - conns: - [ - [0, 3, attrs0], # input node 0 -> hidden node - [1, 3, attrs1], # input node 1 -> hidden node - [3, 2, attrs2], # hidden node -> output node 0 - [NaN, NaN], - [NaN, NaN], - ] - """ - k1, k2 = jax.random.split(randkey) # k1 for nodes, k2 for conns + + all_nodes_cnt = len(self.all_init_nodes) + all_conns_cnt = len(self.all_init_conns) + # initialize nodes - new_node_key = ( - max([*self.input_idx, *self.output_idx]) + 1 - ) # the key for the hidden node - node_keys = jnp.concatenate( - [self.input_idx, self.output_idx, jnp.array([new_node_key])] - ) # the list of all node keys - - # initialize nodes and connections with NaN nodes = jnp.full((self.max_nodes, self.node_gene.length), jnp.nan) - conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan) + # create node indices + node_indices = self.all_init_nodes + # create node attrs + rand_keys_n = jax.random.split(k1, num=all_nodes_cnt) + node_attr_func = vmap(self.node_gene.new_random_attrs, in_axes=(None, 0)) + node_attrs = node_attr_func(state, rand_keys_n) - # set keys for input nodes, output nodes and hidden node - nodes = nodes.at[node_keys, 0].set(node_keys) - - # generate random attributes for nodes - node_keys = jax.random.split(k1, len(node_keys)) - random_node_attrs = jax.vmap( - self.node_gene.new_random_attrs, in_axes=(None, 0) - )(state, node_keys) - nodes = nodes.at[: len(node_keys), 1:].set(random_node_attrs) + nodes = nodes.at[:all_nodes_cnt, 0].set(node_indices) # set node indices + nodes = nodes.at[:all_nodes_cnt, 1:].set(node_attrs) # set node attrs # initialize conns - # input-hidden connections - input_conns = jnp.c_[ - self.input_idx, jnp.full_like(self.input_idx, new_node_key) - ] - conns = conns.at[self.input_idx, 
:2].set(input_conns) # in-keys, out-keys + conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan) + # create input and output indices + conn_indices = self.all_init_conns + # create conn attrs + rand_keys_c = jax.random.split(k2, num=all_conns_cnt) + conns_attr_func = jax.vmap( + self.conn_gene.new_random_attrs, + in_axes=( + None, + 0, + ), + ) + conns_attrs = conns_attr_func(state, rand_keys_c) - # output-hidden connections - output_conns = jnp.c_[ - jnp.full_like(self.output_idx, new_node_key), self.output_idx - ] - conns = conns.at[self.output_idx, :2].set(output_conns) # in-keys, out-keys - - conn_keys = jax.random.split(k2, num=len(self.input_idx) + len(self.output_idx)) - # generate random attributes for conns - random_conn_attrs = jax.vmap( - self.conn_gene.new_random_attrs, in_axes=(None, 0) - )(state, conn_keys) - conns = conns.at[: len(conn_keys), 2:].set(random_conn_attrs) + conns = conns.at[:all_conns_cnt, :2].set(conn_indices) # set conn indices + conns = conns.at[:all_conns_cnt, 2:].set(conns_attrs) # set conn attrs return nodes, conns - def update_by_batch(self, state, batch_input, transformed): - """ - Update the genome by a batch of data. 
- """ - raise NotImplementedError + def network_dict(self, state, nodes, conns): + return { + "nodes": self._get_node_dict(state, nodes), + "conns": self._get_conn_dict(state, conns), + } + + def get_input_idx(self): + return self.input_idx.tolist() + + def get_output_idx(self): + return self.output_idx.tolist() + + def hash(self, nodes, conns): + nodes_hashs = vmap(hash_array)(nodes) + conns_hashs = vmap(hash_array)(conns) + return hash_array(jnp.concatenate([nodes_hashs, conns_hashs])) def repr(self, state, nodes, conns, precision=2): nodes, conns = jax.device_get([nodes, conns]) - nodes_cnt, conns_cnt = self.valid_cnt(nodes), self.valid_cnt(conns) + nodes_cnt, conns_cnt = valid_cnt(nodes), valid_cnt(conns) s = f"{self.__class__.__name__}(nodes={nodes_cnt}, conns={conns_cnt}):\n" s += f"\tNodes:\n" for node in nodes: @@ -152,11 +199,7 @@ class BaseGenome(StatefulBaseClass): s += f"\t\t{self.conn_gene.repr(state, conn, precision=precision)}\n" return s - @classmethod - def valid_cnt(cls, arr): - return jnp.sum(~jnp.isnan(arr[:, 0])) - - def get_conn_dict(self, state, conns): + def _get_conn_dict(self, state, conns): conns = jax.device_get(conns) conn_dict = {} for conn in conns: @@ -167,7 +210,7 @@ class BaseGenome(StatefulBaseClass): conn_dict[(in_idx, out_idx)] = cd return conn_dict - def get_node_dict(self, state, nodes): + def _get_node_dict(self, state, nodes): nodes = jax.device_get(nodes) node_dict = {} for node in nodes: @@ -177,92 +220,3 @@ class BaseGenome(StatefulBaseClass): idx = nd["idx"] node_dict[idx] = nd return node_dict - - def network_dict(self, state, nodes, conns): - return { - "nodes": self.get_node_dict(state, nodes), - "conns": self.get_conn_dict(state, conns), - } - - def get_input_idx(self): - return self.input_idx.tolist() - - def get_output_idx(self): - return self.output_idx.tolist() - - def sympy_func(self, state, network, sympy_output_transform=None): - raise NotImplementedError - - def visualize( - self, - network, - rotate=0, - 
reverse_node_order=False, - size=(300, 300, 300), - color=("blue", "blue", "blue"), - save_path="network.svg", - save_dpi=800, - **kwargs, - ): - import networkx as nx - from matplotlib import pyplot as plt - - nodes_list = list(network["nodes"]) - conns_list = list(network["conns"]) - input_idx = self.get_input_idx() - output_idx = self.get_output_idx() - topo_order, topo_layers = topological_sort_python(nodes_list, conns_list) - node2layer = { - node: layer for layer, nodes in enumerate(topo_layers) for node in nodes - } - if reverse_node_order: - topo_order = topo_order[::-1] - - G = nx.DiGraph() - - if not isinstance(size, tuple): - size = (size, size, size) - if not isinstance(color, tuple): - color = (color, color, color) - - for node in topo_order: - if node in input_idx: - G.add_node(node, subset=node2layer[node], size=size[0], color=color[0]) - elif node in output_idx: - G.add_node(node, subset=node2layer[node], size=size[2], color=color[2]) - else: - G.add_node(node, subset=node2layer[node], size=size[1], color=color[1]) - - for conn in conns_list: - G.add_edge(conn[0], conn[1]) - pos = nx.multipartite_layout(G) - - def rotate_layout(pos, angle): - angle_rad = np.deg2rad(angle) - cos_angle, sin_angle = np.cos(angle_rad), np.sin(angle_rad) - rotated_pos = {} - for node, (x, y) in pos.items(): - rotated_pos[node] = ( - cos_angle * x - sin_angle * y, - sin_angle * x + cos_angle * y, - ) - return rotated_pos - - rotated_pos = rotate_layout(pos, rotate) - - node_sizes = [n["size"] for n in G.nodes.values()] - node_colors = [n["color"] for n in G.nodes.values()] - - nx.draw( - G, - pos=rotated_pos, - node_size=node_sizes, - node_color=node_colors, - **kwargs, - ) - plt.savefig(save_path, dpi=save_dpi) - - def hash(self, nodes, conns): - nodes_hashs = jax.vmap(hash_array)(nodes) - conns_hashs = jax.vmap(hash_array)(conns) - return hash_array(jnp.concatenate([nodes_hashs, conns_hashs])) diff --git a/tensorneat/algorithm/neat/genome/default.py 
b/tensorneat/algorithm/neat/genome/default.py index 8f61cfd..38dc38b 100644 --- a/tensorneat/algorithm/neat/genome/default.py +++ b/tensorneat/algorithm/neat/genome/default.py @@ -1,25 +1,23 @@ import warnings -from typing import Callable -import jax, jax.numpy as jnp +import jax +from jax import vmap, numpy as jnp import numpy as np import sympy as sp -from utils import ( - unflatten_conns, + +from . import BaseGenome +from ..gene import DefaultNodeGene, DefaultConnGene +from .operations import DefaultMutation, DefaultCrossover, DefaultDistance +from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs + +from tensorneat.common import ( topological_sort, topological_sort_python, I_INF, - extract_node_attrs, - extract_conn_attrs, - set_node_attrs, - set_conn_attrs, attach_with_inf, SYMPY_FUNCS_MODULE_NP, SYMPY_FUNCS_MODULE_JNP, ) -from . import BaseGenome -from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene -from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover class DefaultGenome(BaseGenome): @@ -31,15 +29,18 @@ class DefaultGenome(BaseGenome): self, num_inputs: int, num_outputs: int, - max_nodes=5, - max_conns=4, - node_gene: BaseNodeGene = DefaultNodeGene(), - conn_gene: BaseConnGene = DefaultConnGene(), - mutation: BaseMutation = DefaultMutation(), - crossover: BaseCrossover = DefaultCrossover(), - output_transform: Callable = None, - input_transform: Callable = None, + max_nodes=50, + max_conns=100, + node_gene=DefaultNodeGene(), + conn_gene=DefaultConnGene(), + mutation=DefaultMutation(), + crossover=DefaultCrossover(), + distance=DefaultDistance(), + output_transform=None, + input_transform=None, + init_hidden_layers=(), ): + super().__init__( num_inputs, num_outputs, @@ -49,22 +50,12 @@ class DefaultGenome(BaseGenome): conn_gene, mutation, crossover, + distance, + output_transform, + input_transform, + init_hidden_layers, ) - if input_transform is not None: - try: - _ = 
input_transform(np.zeros(num_inputs)) - except Exception as e: - raise ValueError(f"Output transform function failed: {e}") - self.input_transform = input_transform - - if output_transform is not None: - try: - _ = output_transform(np.zeros(num_outputs)) - except Exception as e: - raise ValueError(f"Output transform function failed: {e}") - self.output_transform = output_transform - def transform(self, state, nodes, conns): u_conns = unflatten_conns(nodes, conns) conn_exist = u_conns != I_INF @@ -73,10 +64,6 @@ class DefaultGenome(BaseGenome): return seqs, nodes, conns, u_conns - def restore(self, state, transformed): - seqs, nodes, conns, u_conns = transformed - return nodes, conns - def forward(self, state, transformed, inputs): if self.input_transform is not None: @@ -86,8 +73,8 @@ class DefaultGenome(BaseGenome): ini_vals = jnp.full((self.max_nodes,), jnp.nan) ini_vals = ini_vals.at[self.input_idx].set(inputs) - nodes_attrs = jax.vmap(extract_node_attrs)(nodes) - conns_attrs = jax.vmap(extract_conn_attrs)(conns) + nodes_attrs = vmap(extract_node_attrs)(nodes) + conns_attrs = vmap(extract_conn_attrs)(conns) def cond_fun(carry): values, idx = carry @@ -105,7 +92,7 @@ class DefaultGenome(BaseGenome): def otherwise(): conn_indices = u_conns[:, i] hit_attrs = attach_with_inf(conns_attrs, conn_indices) - ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 0, 0))( + ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))( state, hit_attrs, values ) @@ -130,85 +117,14 @@ class DefaultGenome(BaseGenome): else: return self.output_transform(vals[self.output_idx]) - def update_by_batch(self, state, batch_input, transformed): - - if self.input_transform is not None: - batch_input = jax.vmap(self.input_transform)(batch_input) - - cal_seqs, nodes, conns, u_conns = transformed - - batch_size = batch_input.shape[0] - batch_ini_vals = jnp.full((batch_size, self.max_nodes), jnp.nan) - batch_ini_vals = batch_ini_vals.at[:, self.input_idx].set(batch_input) - nodes_attrs = 
jax.vmap(extract_node_attrs)(nodes) - conns_attrs = jax.vmap(extract_conn_attrs)(conns) - - def cond_fun(carry): - batch_values, nodes_attrs_, conns_attrs_, idx = carry - return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF) - - def body_func(carry): - batch_values, nodes_attrs_, conns_attrs_, idx = carry - i = cal_seqs[idx] - - def input_node(): - batch, new_attrs = self.node_gene.update_input_transform( - state, nodes_attrs_[i], batch_values[:, i] - ) - return ( - batch_values.at[:, i].set(batch), - nodes_attrs_.at[i].set(new_attrs), - conns_attrs_, - ) - - def otherwise(): - - conn_indices = u_conns[:, i] - hit_attrs = attach_with_inf(conns_attrs, conn_indices) - batch_ins, new_conn_attrs = jax.vmap( - self.conn_gene.update_by_batch, - in_axes=(None, 0, 1), - out_axes=(1, 0), - )(state, hit_attrs, batch_values) - - batch_z, new_node_attrs = self.node_gene.update_by_batch( - state, - nodes_attrs_[i], - batch_ins, - is_output_node=jnp.isin(i, self.output_idx), - ) - - return ( - batch_values.at[:, i].set(batch_z), - nodes_attrs_.at[i].set(new_node_attrs), - conns_attrs_.at[conn_indices].set(new_conn_attrs), - ) - - # the val of input nodes is obtained by the task, not by calculation - (batch_values, nodes_attrs_, conns_attrs_) = jax.lax.cond( - jnp.isin(i, self.input_idx), - input_node, - otherwise, - ) - - return batch_values, nodes_attrs_, conns_attrs_, idx + 1 - - batch_vals, nodes_attrs, conns_attrs, _ = jax.lax.while_loop( - cond_fun, body_func, (batch_ini_vals, nodes_attrs, conns_attrs, 0) + def network_dict(self, state, nodes, conns): + network = super().network_dict(state, nodes, conns) + topo_order, topo_layers = topological_sort_python( + set(network["nodes"]), set(network["conns"]) ) - - nodes = jax.vmap(set_node_attrs)(nodes, nodes_attrs) - conns = jax.vmap(set_conn_attrs)(conns, conns_attrs) - - new_transformed = (cal_seqs, nodes, conns, u_conns) - - if self.output_transform is None: - return batch_vals[:, self.output_idx], new_transformed - else: 
- return ( - jax.vmap(self.output_transform)(batch_vals[:, self.output_idx]), - new_transformed, - ) + network["topo_order"] = topo_order + network["topo_layers"] = topo_layers + return network def sympy_func( self, @@ -241,7 +157,8 @@ class DefaultGenome(BaseGenome): input_idx = self.get_input_idx() output_idx = self.get_output_idx() - order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"])) + order = network["topo_order"] + hidden_idx = [ i for i in network["nodes"] if i not in input_idx and i not in output_idx ] @@ -260,8 +177,12 @@ class DefaultGenome(BaseGenome): for i in order: if i in input_idx: - nodes_exprs[symbols[-i - 1]] = symbols[-i - 1] # origin equal to its symbol - nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[-i - 1]) # normed i + nodes_exprs[symbols[-i - 1]] = symbols[ + -i - 1 + ] # origin equal to its symbol + nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)]( + symbols[-i - 1] + ) # normed i else: in_conns = [c for c in network["conns"] if c[1] == i] @@ -325,3 +246,73 @@ class DefaultGenome(BaseGenome): output_exprs, forward_func, ) + + def visualize( + self, + network, + rotate=0, + reverse_node_order=False, + size=(300, 300, 300), + color=("blue", "blue", "blue"), + save_path="network.svg", + save_dpi=800, + **kwargs, + ): + import networkx as nx + from matplotlib import pyplot as plt + + nodes_list = list(network["nodes"]) + conns_list = list(network["conns"]) + input_idx = self.get_input_idx() + output_idx = self.get_output_idx() + + topo_order, topo_layers = network["topo_order"], network["topo_layers"] + node2layer = { + node: layer for layer, nodes in enumerate(topo_layers) for node in nodes + } + if reverse_node_order: + topo_order = topo_order[::-1] + + G = nx.DiGraph() + + if not isinstance(size, tuple): + size = (size, size, size) + if not isinstance(color, tuple): + color = (color, color, color) + + for node in topo_order: + if node in input_idx: + 
G.add_node(node, subset=node2layer[node], size=size[0], color=color[0]) + elif node in output_idx: + G.add_node(node, subset=node2layer[node], size=size[2], color=color[2]) + else: + G.add_node(node, subset=node2layer[node], size=size[1], color=color[1]) + + for conn in conns_list: + G.add_edge(conn[0], conn[1]) + pos = nx.multipartite_layout(G) + + def rotate_layout(pos, angle): + angle_rad = np.deg2rad(angle) + cos_angle, sin_angle = np.cos(angle_rad), np.sin(angle_rad) + rotated_pos = {} + for node, (x, y) in pos.items(): + rotated_pos[node] = ( + cos_angle * x - sin_angle * y, + sin_angle * x + cos_angle * y, + ) + return rotated_pos + + rotated_pos = rotate_layout(pos, rotate) + + node_sizes = [n["size"] for n in G.nodes.values()] + node_colors = [n["color"] for n in G.nodes.values()] + + nx.draw( + G, + pos=rotated_pos, + node_size=node_sizes, + node_color=node_colors, + **kwargs, + ) + plt.savefig(save_path, dpi=save_dpi) diff --git a/tensorneat/algorithm/neat/genome/dense.py b/tensorneat/algorithm/neat/genome/dense.py deleted file mode 100644 index 1c47ef6..0000000 --- a/tensorneat/algorithm/neat/genome/dense.py +++ /dev/null @@ -1,56 +0,0 @@ -import jax, jax.numpy as jnp -from .default import DefaultGenome - - -class DenseInitialize(DefaultGenome): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - assert self.max_nodes >= self.num_inputs + self.num_outputs - assert self.max_conns >= self.num_inputs * self.num_outputs - - def initialize(self, state, randkey): - - k1, k2 = jax.random.split(randkey, num=2) - - input_idx, output_idx = self.input_idx, self.output_idx - input_size = len(input_idx) - output_size = len(output_idx) - - nodes = jnp.full( - (self.max_nodes, self.node_gene.length), jnp.nan, dtype=jnp.float32 - ) - - nodes = nodes.at[input_idx, 0].set(input_idx) - nodes = nodes.at[output_idx, 0].set(output_idx) - - total_idx = input_size + output_size - rand_keys_n = jax.random.split(k1, num=total_idx) - - node_attr_func = 
jax.vmap(self.node_gene.new_random_attrs, in_axes=(None, 0)) - node_attrs = node_attr_func(state, rand_keys_n) - nodes = nodes.at[:total_idx, 1:].set(node_attrs) - - conns = jnp.full( - (self.max_conns, self.conn_gene.length), jnp.nan, dtype=jnp.float32 - ) - - input_to_output_ids, output_ids = jnp.meshgrid( - input_idx, output_idx, indexing="ij" - ) - total_conns = input_size * output_size - conns = conns.at[:total_conns, :2].set( - jnp.column_stack([input_to_output_ids.flatten(), output_ids.flatten()]) - ) - - rand_keys_c = jax.random.split(k2, num=total_conns) - conns_attr_func = jax.vmap( - self.conn_gene.new_random_attrs, - in_axes=( - None, - 0, - ), - ) - conns_attrs = conns_attr_func(state, rand_keys_c) - conns = conns.at[:total_conns, 2:].set(conns_attrs) - - return nodes, conns diff --git a/tensorneat/algorithm/neat/genome/hidden.py b/tensorneat/algorithm/neat/genome/hidden.py deleted file mode 100644 index 5ca82a0..0000000 --- a/tensorneat/algorithm/neat/genome/hidden.py +++ /dev/null @@ -1,70 +0,0 @@ -import jax, jax.numpy as jnp -from .default import DefaultGenome - - -class HiddenInitialize(DefaultGenome): - def __init__(self, hidden_cnt=8, *args, **kwargs): - super().__init__(*args, **kwargs) - self.hidden_cnt = hidden_cnt - - def initialize(self, state, randkey): - - k1, k2 = jax.random.split(randkey, num=2) - - input_idx, output_idx = self.input_idx, self.output_idx - input_size = len(input_idx) - output_size = len(output_idx) - - hidden_idx = jnp.arange( - input_size + output_size, input_size + output_size + self.hidden_cnt - ) - nodes = jnp.full( - (self.max_nodes, self.node_gene.length), jnp.nan, dtype=jnp.float32 - ) - - nodes = nodes.at[input_idx, 0].set(input_idx) - nodes = nodes.at[output_idx, 0].set(output_idx) - nodes = nodes.at[hidden_idx, 0].set(hidden_idx) - - total_idx = input_size + output_size + self.hidden_cnt - rand_keys_n = jax.random.split(k1, num=total_idx) - - node_attr_func = jax.vmap(self.node_gene.new_random_attrs, 
in_axes=(None, 0)) - node_attrs = node_attr_func(state, rand_keys_n) - nodes = nodes.at[:total_idx, 1:].set(node_attrs) - - conns = jnp.full( - (self.max_conns, self.conn_gene.length), jnp.nan, dtype=jnp.float32 - ) - - input_to_hidden_ids, hidden_ids = jnp.meshgrid( - input_idx, hidden_idx, indexing="ij" - ) - total_input_to_hidden_conns = input_size * self.hidden_cnt - conns = conns.at[:total_input_to_hidden_conns, :2].set( - jnp.column_stack([input_to_hidden_ids.flatten(), hidden_ids.flatten()]) - ) - - hidden_to_output_ids, output_ids = jnp.meshgrid( - hidden_idx, output_idx, indexing="ij" - ) - total_hidden_to_output_conns = self.hidden_cnt * output_size - conns = conns.at[ - total_input_to_hidden_conns : total_input_to_hidden_conns - + total_hidden_to_output_conns, - :2, - ].set(jnp.column_stack([hidden_to_output_ids.flatten(), output_ids.flatten()])) - - total_conns = total_input_to_hidden_conns + total_hidden_to_output_conns - rand_keys_c = jax.random.split(k2, num=total_conns) - conns_attr_func = jax.vmap( - self.conn_gene.new_random_attrs, - in_axes=( - None, - 0, - ), - ) - conns_attrs = conns_attr_func(state, rand_keys_c) - conns = conns.at[:total_conns, 2:].set(conns_attrs) - - return nodes, conns diff --git a/tensorneat/algorithm/neat/ga/__init__.py b/tensorneat/algorithm/neat/genome/operations/__init__.py similarity index 67% rename from tensorneat/algorithm/neat/ga/__init__.py rename to tensorneat/algorithm/neat/genome/operations/__init__.py index 198f8ac..a7abd68 100644 --- a/tensorneat/algorithm/neat/ga/__init__.py +++ b/tensorneat/algorithm/neat/genome/operations/__init__.py @@ -1,2 +1,3 @@ from .crossover import BaseCrossover, DefaultCrossover from .mutation import BaseMutation, DefaultMutation +from .distance import BaseDistance, DefaultDistance diff --git a/tensorneat/algorithm/neat/ga/crossover/__init__.py b/tensorneat/algorithm/neat/genome/operations/crossover/__init__.py similarity index 100% rename from 
tensorneat/algorithm/neat/ga/crossover/__init__.py rename to tensorneat/algorithm/neat/genome/operations/crossover/__init__.py diff --git a/tensorneat/algorithm/neat/genome/operations/crossover/base.py b/tensorneat/algorithm/neat/genome/operations/crossover/base.py new file mode 100644 index 0000000..143d519 --- /dev/null +++ b/tensorneat/algorithm/neat/genome/operations/crossover/base.py @@ -0,0 +1,12 @@ +from tensorneat.common import StatefulBaseClass, State + + +class BaseCrossover(StatefulBaseClass): + + def setup(self, state=State(), genome = None): + assert genome is not None, "genome should not be None" + self.genome = genome + return state + + def __call__(self, state, randkey, nodes1, conns1, nodes2, conns2): + raise NotImplementedError diff --git a/tensorneat/algorithm/neat/ga/crossover/default.py b/tensorneat/algorithm/neat/genome/operations/crossover/default.py similarity index 76% rename from tensorneat/algorithm/neat/ga/crossover/default.py rename to tensorneat/algorithm/neat/genome/operations/crossover/default.py index f55203c..1548152 100644 --- a/tensorneat/algorithm/neat/ga/crossover/default.py +++ b/tensorneat/algorithm/neat/genome/operations/crossover/default.py @@ -1,7 +1,8 @@ -import jax, jax.numpy as jnp +import jax +from jax import vmap, numpy as jnp from .base import BaseCrossover -from utils.tools import ( +from ...utils import ( extract_node_attrs, extract_conn_attrs, set_node_attrs, @@ -10,14 +11,14 @@ from utils.tools import ( class DefaultCrossover(BaseCrossover): - def __call__(self, state, randkey, genome, nodes1, conns1, nodes2, conns2): + def __call__(self, state, randkey, nodes1, conns1, nodes2, conns2): """ use genome1 and genome2 to generate a new genome notice that genome1 should have higher fitness than genome2 (genome1 is winner!) 
""" randkey1, randkey2 = jax.random.split(randkey, 2) - randkeys1 = jax.random.split(randkey1, genome.max_nodes) - randkeys2 = jax.random.split(randkey2, genome.max_conns) + randkeys1 = jax.random.split(randkey1, self.genome.max_nodes) + randkeys2 = jax.random.split(randkey2, self.genome.max_conns) # crossover nodes keys1, keys2 = nodes1[:, 0], nodes2[:, 0] @@ -26,33 +27,33 @@ class DefaultCrossover(BaseCrossover): # For not homologous genes, use the value of nodes1(winner) # For homologous genes, use the crossover result between nodes1 and nodes2 - node_attrs1 = jax.vmap(extract_node_attrs)(nodes1) - node_attrs2 = jax.vmap(extract_node_attrs)(nodes2) + node_attrs1 = vmap(extract_node_attrs)(nodes1) + node_attrs2 = vmap(extract_node_attrs)(nodes2) new_node_attrs = jnp.where( jnp.isnan(node_attrs1) | jnp.isnan(node_attrs2), # one of them is nan node_attrs1, # not homologous genes or both nan, use the value of nodes1(winner) - jax.vmap(genome.node_gene.crossover, in_axes=(None, 0, 0, 0))( + vmap(self.genome.node_gene.crossover, in_axes=(None, 0, 0, 0))( state, randkeys1, node_attrs1, node_attrs2 ), # homologous or both nan ) - new_nodes = jax.vmap(set_node_attrs)(nodes1, new_node_attrs) + new_nodes = vmap(set_node_attrs)(nodes1, new_node_attrs) # crossover connections con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2] conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True) - conns_attrs1 = jax.vmap(extract_conn_attrs)(conns1) - conns_attrs2 = jax.vmap(extract_conn_attrs)(conns2) + conns_attrs1 = vmap(extract_conn_attrs)(conns1) + conns_attrs2 = vmap(extract_conn_attrs)(conns2) new_conn_attrs = jnp.where( jnp.isnan(conns_attrs1) | jnp.isnan(conns_attrs2), conns_attrs1, # not homologous genes or both nan, use the value of conns1(winner) - jax.vmap(genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))( + vmap(self.genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))( state, randkeys2, conns_attrs1, conns_attrs2 ), # homologous or both nan ) - new_conns = 
jax.vmap(set_conn_attrs)(conns1, new_conn_attrs) + new_conns = vmap(set_conn_attrs)(conns1, new_conn_attrs) return new_nodes, new_conns diff --git a/tensorneat/algorithm/neat/genome/operations/distance/__init__.py b/tensorneat/algorithm/neat/genome/operations/distance/__init__.py new file mode 100644 index 0000000..7c78fdc --- /dev/null +++ b/tensorneat/algorithm/neat/genome/operations/distance/__init__.py @@ -0,0 +1,2 @@ +from .base import BaseDistance +from .default import DefaultDistance diff --git a/tensorneat/algorithm/neat/genome/operations/distance/base.py b/tensorneat/algorithm/neat/genome/operations/distance/base.py new file mode 100644 index 0000000..f3f6a65 --- /dev/null +++ b/tensorneat/algorithm/neat/genome/operations/distance/base.py @@ -0,0 +1,15 @@ +from tensorneat.common import StatefulBaseClass, State + + +class BaseDistance(StatefulBaseClass): + + def setup(self, state=State(), genome = None): + assert genome is not None, "genome should not be None" + self.genome = genome + return state + + def __call__(self, state, nodes1, conns1, nodes2, conns2): + """ + The distance between two genomes + """ + raise NotImplementedError diff --git a/tensorneat/algorithm/neat/genome/operations/distance/default.py b/tensorneat/algorithm/neat/genome/operations/distance/default.py new file mode 100644 index 0000000..3e78916 --- /dev/null +++ b/tensorneat/algorithm/neat/genome/operations/distance/default.py @@ -0,0 +1,105 @@ +from jax import vmap, numpy as jnp + +from .base import BaseDistance +from ...utils import extract_node_attrs, extract_conn_attrs + + +class DefaultDistance(BaseDistance): + def __init__( + self, + compatibility_disjoint: float = 1.0, + compatibility_weight: float = 0.4, + ): + self.compatibility_disjoint = compatibility_disjoint + self.compatibility_weight = compatibility_weight + + def __call__(self, state, nodes1, conns1, nodes2, conns2): + """ + The distance between two genomes + """ + d = self.node_distance(state, nodes1, nodes2) + 
self.conn_distance( + state, conns1, conns2 + ) + return d + + def node_distance(self, state, nodes1, nodes2): + """ + The distance of the nodes part for two genomes + """ + node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0])) + node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0])) + max_cnt = jnp.maximum(node_cnt1, node_cnt2) + + # align homologous nodes + # this process is similar to np.intersect1d. + nodes = jnp.concatenate((nodes1, nodes2), axis=0) + keys = nodes[:, 0] + sorted_indices = jnp.argsort(keys, axis=0) + nodes = nodes[sorted_indices] + nodes = jnp.concatenate( + [nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0 + ) # add a nan row to the end + fr, sr = nodes[:-1], nodes[1:] # first row, second row + + # flag location of homologous nodes + intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0]) + + # calculate the count of non_homologous of two genomes + non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask) + + # calculate the distance of homologous nodes + fr_attrs = vmap(extract_node_attrs)(fr) + sr_attrs = vmap(extract_node_attrs)(sr) + hnd = vmap(self.genome.node_gene.distance, in_axes=(None, 0, 0))( + state, fr_attrs, sr_attrs + ) # homologous node distance + hnd = jnp.where(jnp.isnan(hnd), 0, hnd) + homologous_distance = jnp.sum(hnd * intersect_mask) + + val = ( + non_homologous_cnt * self.compatibility_disjoint + + homologous_distance * self.compatibility_weight + ) + + val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize + + return val + + def conn_distance(self, state, conns1, conns2): + """ + The distance of the conns part for two genomes + """ + con_cnt1 = jnp.sum(~jnp.isnan(conns1[:, 0])) + con_cnt2 = jnp.sum(~jnp.isnan(conns2[:, 0])) + max_cnt = jnp.maximum(con_cnt1, con_cnt2) + + cons = jnp.concatenate((conns1, conns2), axis=0) + keys = cons[:, :2] + sorted_indices = jnp.lexsort(keys.T[::-1]) + cons = cons[sorted_indices] + cons = jnp.concatenate( + [cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0 + ) # 
add a nan row to the end + fr, sr = cons[:-1], cons[1:] # first row, second row + + # both genome has such connection + intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0]) + + non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask) + + fr_attrs = vmap(extract_conn_attrs)(fr) + sr_attrs = vmap(extract_conn_attrs)(sr) + hcd = vmap(self.genome.conn_gene.distance, in_axes=(None, 0, 0))( + state, fr_attrs, sr_attrs + ) # homologous connection distance + hcd = jnp.where(jnp.isnan(hcd), 0, hcd) + homologous_distance = jnp.sum(hcd * intersect_mask) + + val = ( + non_homologous_cnt * self.compatibility_disjoint + + homologous_distance * self.compatibility_weight + ) + + val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize + + return val diff --git a/tensorneat/algorithm/neat/ga/mutation/__init__.py b/tensorneat/algorithm/neat/genome/operations/mutation/__init__.py similarity index 100% rename from tensorneat/algorithm/neat/ga/mutation/__init__.py rename to tensorneat/algorithm/neat/genome/operations/mutation/__init__.py diff --git a/tensorneat/algorithm/neat/genome/operations/mutation/base.py b/tensorneat/algorithm/neat/genome/operations/mutation/base.py new file mode 100644 index 0000000..15c0d4a --- /dev/null +++ b/tensorneat/algorithm/neat/genome/operations/mutation/base.py @@ -0,0 +1,12 @@ +from tensorneat.common import StatefulBaseClass, State + + +class BaseMutation(StatefulBaseClass): + + def setup(self, state=State(), genome = None): + assert genome is not None, "genome should not be None" + self.genome = genome + return state + + def __call__(self, state, randkey, nodes, conns, new_node_key): + raise NotImplementedError diff --git a/tensorneat/algorithm/neat/ga/mutation/default.py b/tensorneat/algorithm/neat/genome/operations/mutation/default.py similarity index 94% rename from tensorneat/algorithm/neat/ga/mutation/default.py rename to tensorneat/algorithm/neat/genome/operations/mutation/default.py index 
83f994e..e7100bc 100644 --- a/tensorneat/algorithm/neat/ga/mutation/default.py +++ b/tensorneat/algorithm/neat/genome/operations/mutation/default.py @@ -1,11 +1,14 @@ -import jax, jax.numpy as jnp +import jax +from jax import vmap, numpy as jnp from . import BaseMutation -from utils import ( +from tensorneat.common import ( fetch_first, fetch_random, I_INF, - unflatten_conns, check_cycles, +) +from ...utils import ( + unflatten_conns, add_node, add_conn, delete_node_by_pos, @@ -225,17 +228,17 @@ class DefaultMutation(BaseMutation): nodes_randkeys = jax.random.split(k1, num=genome.max_nodes) conns_randkeys = jax.random.split(k2, num=genome.max_conns) - node_attrs = jax.vmap(extract_node_attrs)(nodes) - new_node_attrs = jax.vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))( + node_attrs = vmap(extract_node_attrs)(nodes) + new_node_attrs = vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))( state, nodes_randkeys, node_attrs ) - new_nodes = jax.vmap(set_node_attrs)(nodes, new_node_attrs) + new_nodes = vmap(set_node_attrs)(nodes, new_node_attrs) - conn_attrs = jax.vmap(extract_conn_attrs)(conns) - new_conn_attrs = jax.vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))( + conn_attrs = vmap(extract_conn_attrs)(conns) + new_conn_attrs = vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))( state, conns_randkeys, conn_attrs ) - new_conns = jax.vmap(set_conn_attrs)(conns, new_conn_attrs) + new_conns = vmap(set_conn_attrs)(conns, new_conn_attrs) # nan nodes not changed new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes) diff --git a/tensorneat/algorithm/neat/genome/recurrent.py b/tensorneat/algorithm/neat/genome/recurrent.py index faed575..6509a99 100644 --- a/tensorneat/algorithm/neat/genome/recurrent.py +++ b/tensorneat/algorithm/neat/genome/recurrent.py @@ -1,11 +1,11 @@ from typing import Callable import jax, jax.numpy as jnp -from utils import unflatten_conns +from .utils import unflatten_conns from . 
import BaseGenome -from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene -from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover +from ..gene import DefaultNodeGene, DefaultConnGene +from .operations import DefaultMutation, DefaultCrossover class RecurrentGenome(BaseGenome): @@ -17,13 +17,13 @@ class RecurrentGenome(BaseGenome): self, num_inputs: int, num_outputs: int, - max_nodes: int, - max_conns: int, - node_gene: BaseNodeGene = DefaultNodeGene(), - conn_gene: BaseConnGene = DefaultConnGene(), - mutation: BaseMutation = DefaultMutation(), - crossover: BaseCrossover = DefaultCrossover(), - activate_time: int = 10, + max_nodes = 50, + max_conns = 100, + node_gene=DefaultNodeGene(), + conn_gene=DefaultConnGene(), + mutation=DefaultMutation(), + crossover=DefaultCrossover(), + activate_time=10, output_transform: Callable = None, ): super().__init__( diff --git a/tensorneat/algorithm/neat/genome/utils.py b/tensorneat/algorithm/neat/genome/utils.py new file mode 100644 index 0000000..822566d --- /dev/null +++ b/tensorneat/algorithm/neat/genome/utils.py @@ -0,0 +1,109 @@ +import jax +from jax import vmap, numpy as jnp + +from tensorneat.common import fetch_first, I_INF + + +def unflatten_conns(nodes, conns): + """ + transform the (C, CL) connections to (N, N), which contains the idx of the connection in conns + connection length, N means the number of nodes, C means the number of connections + returns the unflatten connection indices with shape (N, N) + """ + N = nodes.shape[0] # max_nodes + C = conns.shape[0] # max_conns + node_keys = nodes[:, 0] + i_keys, o_keys = conns[:, 0], conns[:, 1] + + def key_to_indices(key, keys): + return fetch_first(key == keys) + + i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys) + o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys) + + # Is interesting that jax use clip when attach data in array + # however, it will do nothing when setting values in an 
array + # put the index of connections in the unflatten array + unflatten = ( + jnp.full((N, N), I_INF, dtype=jnp.int32) + .at[i_idxs, o_idxs] + .set(jnp.arange(C, dtype=jnp.int32)) + ) + + return unflatten + + +def valid_cnt(nodes_or_conns): + return jnp.sum(~jnp.isnan(nodes_or_conns[:, 0])) + + +def extract_node_attrs(node): + """ + node: Array(NL, ) + extract the attributes of a node + """ + return node[1:] # 0 is for idx + + +def set_node_attrs(node, attrs): + """ + node: Array(NL, ) + attrs: Array(NL-1, ) + set the attributes of a node + """ + return node.at[1:].set(attrs) # 0 is for idx + + +def extract_conn_attrs(conn): + """ + conn: Array(CL, ) + extract the attributes of a connection + """ + return conn[2:] # 0, 1 is for in-idx and out-idx + + +def set_conn_attrs(conn, attrs): + """ + conn: Array(CL, ) + attrs: Array(CL-2, ) + set the attributes of a connection + """ + return conn.at[2:].set(attrs) # 0, 1 is for in-idx and out-idx + + +def add_node(nodes, new_key: int, attrs): + """ + Add a new node to the genome. + The new node will place at the first NaN row. + """ + exist_keys = nodes[:, 0] + pos = fetch_first(jnp.isnan(exist_keys)) + new_nodes = nodes.at[pos, 0].set(new_key) + return new_nodes.at[pos, 1:].set(attrs) + + +def delete_node_by_pos(nodes, pos): + """ + Delete a node from the genome. + Delete the node by its pos in nodes. + """ + return nodes.at[pos].set(jnp.nan) + + +def add_conn(conns, i_key, o_key, attrs): + """ + Add a new connection to the genome. + The new connection will place at the first NaN row. + """ + con_keys = conns[:, 0] + pos = fetch_first(jnp.isnan(con_keys)) + new_conns = conns.at[pos, 0:2].set(jnp.array([i_key, o_key])) + return new_conns.at[pos, 2:].set(attrs) + + +def delete_conn_by_pos(conns, pos): + """ + Delete a connection from the genome. + Delete the connection by its idx. 
+ """ + return conns.at[pos].set(jnp.nan) diff --git a/tensorneat/algorithm/neat/neat.py b/tensorneat/algorithm/neat/neat.py index c12d49d..6c04d8e 100644 --- a/tensorneat/algorithm/neat/neat.py +++ b/tensorneat/algorithm/neat/neat.py @@ -1,5 +1,5 @@ import jax, jax.numpy as jnp -from utils import State +from tensorneat.common import State from .. import BaseAlgorithm from .species import * diff --git a/tensorneat/algorithm/neat/species/base.py b/tensorneat/algorithm/neat/species/base.py index f53b8a5..f6175de 100644 --- a/tensorneat/algorithm/neat/species/base.py +++ b/tensorneat/algorithm/neat/species/base.py @@ -1,4 +1,4 @@ -from utils import State, StatefulBaseClass +from tensorneat.common import State, StatefulBaseClass from ..genome import BaseGenome diff --git a/tensorneat/algorithm/neat/species/default.py b/tensorneat/algorithm/neat/species/default.py index 809a19d..310e070 100644 --- a/tensorneat/algorithm/neat/species/default.py +++ b/tensorneat/algorithm/neat/species/default.py @@ -1,9 +1,11 @@ import jax, jax.numpy as jnp -from utils import ( +from tensorneat.common import ( State, rank_elements, argmin_with_mask, fetch_first, +) +from ..genome.utils import ( extract_conn_attrs, extract_node_attrs, ) @@ -635,7 +637,9 @@ class DefaultSpecies(BaseSpecies): # find next node key all_nodes_keys = state.pop_nodes[:, :, 0] - max_node_key = jnp.max(all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0) + max_node_key = jnp.max( + all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0 + ) next_node_key = max_node_key + 1 new_node_keys = jnp.arange(self.pop_size) + next_node_key @@ -669,4 +673,4 @@ class DefaultSpecies(BaseSpecies): randkey=randkey, pop_nodes=pop_nodes, pop_conns=pop_conns, - ) \ No newline at end of file + ) diff --git a/tensorneat/utils/__init__.py b/tensorneat/common/__init__.py similarity index 95% rename from tensorneat/utils/__init__.py rename to tensorneat/common/__init__.py index f61b9bd..c67429f 100644 --- 
a/tensorneat/utils/__init__.py +++ b/tensorneat/common/__init__.py @@ -1,4 +1,4 @@ -from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL +from tensorneat.common.aggregation.agg_jnp import Agg, agg_func, AGG_ALL from .tools import * from .graph import * from .state import State diff --git a/tensorneat/utils/activation/__init__.py b/tensorneat/common/activation/__init__.py similarity index 100% rename from tensorneat/utils/activation/__init__.py rename to tensorneat/common/activation/__init__.py diff --git a/tensorneat/utils/activation/act_jnp.py b/tensorneat/common/activation/act_jnp.py similarity index 100% rename from tensorneat/utils/activation/act_jnp.py rename to tensorneat/common/activation/act_jnp.py diff --git a/tensorneat/utils/activation/act_sympy.py b/tensorneat/common/activation/act_sympy.py similarity index 100% rename from tensorneat/utils/activation/act_sympy.py rename to tensorneat/common/activation/act_sympy.py diff --git a/tensorneat/utils/aggregation/__init__.py b/tensorneat/common/aggregation/__init__.py similarity index 100% rename from tensorneat/utils/aggregation/__init__.py rename to tensorneat/common/aggregation/__init__.py diff --git a/tensorneat/utils/aggregation/agg_jnp.py b/tensorneat/common/aggregation/agg_jnp.py similarity index 100% rename from tensorneat/utils/aggregation/agg_jnp.py rename to tensorneat/common/aggregation/agg_jnp.py diff --git a/tensorneat/utils/aggregation/agg_sympy.py b/tensorneat/common/aggregation/agg_sympy.py similarity index 100% rename from tensorneat/utils/aggregation/agg_sympy.py rename to tensorneat/common/aggregation/agg_sympy.py diff --git a/tensorneat/utils/graph.py b/tensorneat/common/graph.py similarity index 100% rename from tensorneat/utils/graph.py rename to tensorneat/common/graph.py diff --git a/tensorneat/utils/state.py b/tensorneat/common/state.py similarity index 100% rename from tensorneat/utils/state.py rename to tensorneat/common/state.py diff --git 
a/tensorneat/utils/stateful_class.py b/tensorneat/common/stateful_class.py similarity index 100% rename from tensorneat/utils/stateful_class.py rename to tensorneat/common/stateful_class.py diff --git a/tensorneat/utils/tools.py b/tensorneat/common/tools.py similarity index 54% rename from tensorneat/utils/tools.py rename to tensorneat/common/tools.py index d1f3f22..b26ebe3 100644 --- a/tensorneat/utils/tools.py +++ b/tensorneat/common/tools.py @@ -6,36 +6,6 @@ from jax import numpy as jnp, Array, jit, vmap I_INF = np.iinfo(jnp.int32).max # infinite int - -def unflatten_conns(nodes, conns): - """ - transform the (C, CL) connections to (N, N), which contains the idx of the connection in conns - connection length, N means the number of nodes, C means the number of connections - returns the unflatten connection indices with shape (N, N) - """ - N = nodes.shape[0] # max_nodes - C = conns.shape[0] # max_conns - node_keys = nodes[:, 0] - i_keys, o_keys = conns[:, 0], conns[:, 1] - - def key_to_indices(key, keys): - return fetch_first(key == keys) - - i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys) - o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys) - - # Is interesting that jax use clip when attach data in array - # however, it will do nothing when setting values in an array - # put the index of connections in the unflatten array - unflatten = ( - jnp.full((N, N), I_INF, dtype=jnp.int32) - .at[i_idxs, o_idxs] - .set(jnp.arange(C, dtype=jnp.int32)) - ) - - return unflatten - - # TODO: strange implementation def attach_with_inf(arr, idx): expand_size = arr.ndim - idx.ndim @@ -45,40 +15,6 @@ def attach_with_inf(arr, idx): return jnp.where(expand_idx == I_INF, jnp.nan, arr[idx]) -def extract_node_attrs(node): - """ - node: Array(NL, ) - extract the attributes of a node - """ - return node[1:] # 0 is for idx - - -def set_node_attrs(node, attrs): - """ - node: Array(NL, ) - attrs: Array(NL-1, ) - set the attributes of a node - """ - return 
node.at[1:].set(attrs) # 0 is for idx - - -def extract_conn_attrs(conn): - """ - conn: Array(CL, ) - extract the attributes of a connection - """ - return conn[2:] # 0, 1 is for in-idx and out-idx - - -def set_conn_attrs(conn, attrs): - """ - conn: Array(CL, ) - attrs: Array(CL-2, ) - set the attributes of a connection - """ - return conn.at[2:].set(attrs) # 0, 1 is for in-idx and out-idx - - @jit def fetch_first(mask, default=I_INF) -> Array: """ @@ -164,44 +100,6 @@ def argmin_with_mask(arr, mask): return min_idx -def add_node(nodes, new_key: int, attrs): - """ - Add a new node to the genome. - The new node will place at the first NaN row. - """ - exist_keys = nodes[:, 0] - pos = fetch_first(jnp.isnan(exist_keys)) - new_nodes = nodes.at[pos, 0].set(new_key) - return new_nodes.at[pos, 1:].set(attrs) - - -def delete_node_by_pos(nodes, pos): - """ - Delete a node from the genome. - Delete the node by its pos in nodes. - """ - return nodes.at[pos].set(jnp.nan) - - -def add_conn(conns, i_key, o_key, attrs): - """ - Add a new connection to the genome. - The new connection will place at the first NaN row. - """ - con_keys = conns[:, 0] - pos = fetch_first(jnp.isnan(con_keys)) - new_conns = conns.at[pos, 0:2].set(jnp.array([i_key, o_key])) - return new_conns.at[pos, 2:].set(attrs) - - -def delete_conn_by_pos(conns, pos): - """ - Delete a connection from the genome. - Delete the connection by its idx. 
- """ - return conns.at[pos].set(jnp.nan) - - def hash_array(arr: Array): arr = jax.lax.bitcast_convert_type(arr, jnp.uint32) diff --git a/tensorneat/pipeline.py b/tensorneat/pipeline.py index 53d2b6b..5c2bae3 100644 --- a/tensorneat/pipeline.py +++ b/tensorneat/pipeline.py @@ -9,7 +9,7 @@ from algorithm import BaseAlgorithm from problem import BaseProblem from problem.rl_env import RLEnv from problem.func_fit import FuncFit -from utils import State, StatefulBaseClass +from tensorneat.common import State, StatefulBaseClass class Pipeline(StatefulBaseClass): diff --git a/tensorneat/problem/base.py b/tensorneat/problem/base.py index bf78f78..6bf6bc9 100644 --- a/tensorneat/problem/base.py +++ b/tensorneat/problem/base.py @@ -1,6 +1,6 @@ from typing import Callable -from utils import State, StatefulBaseClass +from tensorneat.common import State, StatefulBaseClass class BaseProblem(StatefulBaseClass): diff --git a/tensorneat/problem/func_fit/func_fit.py b/tensorneat/problem/func_fit/func_fit.py index dd19ac6..67ff86a 100644 --- a/tensorneat/problem/func_fit/func_fit.py +++ b/tensorneat/problem/func_fit/func_fit.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp -from utils import State +from tensorneat.common import State from .. 
import BaseProblem diff --git a/tensorneat/problem/rl_env/jumanji/jumanji_2048.py b/tensorneat/problem/rl_env/jumanji/jumanji_2048.py index 48d4bf3..97d3e85 100644 --- a/tensorneat/problem/rl_env/jumanji/jumanji_2048.py +++ b/tensorneat/problem/rl_env/jumanji/jumanji_2048.py @@ -1,7 +1,7 @@ import jax, jax.numpy as jnp import jumanji -from utils import State +from tensorneat.common import State from ..rl_jit import RLEnv diff --git a/tensorneat/problem/rl_env/rl_jit.py b/tensorneat/problem/rl_env/rl_jit.py index 709c613..00439c7 100644 --- a/tensorneat/problem/rl_env/rl_jit.py +++ b/tensorneat/problem/rl_env/rl_jit.py @@ -4,7 +4,7 @@ import jax import jax.numpy as jnp import numpy as np -from utils import State +from tensorneat.common import State from .. import BaseProblem diff --git a/tensorneat/test/crossover_mutation.py b/tensorneat/test/crossover_mutation.py index 2e6e0bf..9ae8ab5 100644 --- a/tensorneat/test/crossover_mutation.py +++ b/tensorneat/test/crossover_mutation.py @@ -1,5 +1,5 @@ import jax, jax.numpy as jnp -from utils import Act +from tensorneat.common import Act from algorithm.neat import * import numpy as np diff --git a/tensorneat/test/nan_fitness.py b/tensorneat/test/nan_fitness.py index 4b59f31..e0df6b7 100644 --- a/tensorneat/test/nan_fitness.py +++ b/tensorneat/test/nan_fitness.py @@ -1,5 +1,5 @@ import jax, jax.numpy as jnp -from utils import Act +from tensorneat.common import Act from algorithm.neat import * import numpy as np diff --git a/tensorneat/test/test_kan.ipynb b/tensorneat/test/test_kan.ipynb index e7b40ac..4508515 100644 --- a/tensorneat/test/test_kan.ipynb +++ b/tensorneat/test/test_kan.ipynb @@ -27,7 +27,7 @@ "from algorithm.neat.gene.node.kan_node import KANNode\n", "from algorithm.neat.gene.conn.bspline import BSplineConn\n", "from problem.func_fit import XOR3d\n", - "from utils import Act\n", + "from tensorneat.utils import Act\n", "\n", "import jax, jax.numpy as jnp\n", "\n", diff --git 
a/tensorneat/test/test_nan_fitness.py b/tensorneat/test/test_nan_fitness.py index c2575ef..45660aa 100644 --- a/tensorneat/test/test_nan_fitness.py +++ b/tensorneat/test/test_nan_fitness.py @@ -1,5 +1,5 @@ import jax, jax.numpy as jnp -from utils import Act +from tensorneat.common import Act from algorithm.neat import * import numpy as np diff --git a/tensorneat/test/test_record_episode.ipynb b/tensorneat/test/test_record_episode.ipynb index 08a300b..5d39d6e 100644 --- a/tensorneat/test/test_record_episode.ipynb +++ b/tensorneat/test/test_record_episode.ipynb @@ -14,7 +14,7 @@ "outputs": [], "source": [ "import jax, jax.numpy as jnp\n", - "from utils import State\n", + "from tensorneat.utils import State\n", "from problem.rl_env import BraxEnv\n", "\n", "\n", diff --git a/tensorneat/test/test_update_by_batch.ipynb b/tensorneat/test/test_update_by_batch.ipynb index 7c427a3..40e3967 100644 --- a/tensorneat/test/test_update_by_batch.ipynb +++ b/tensorneat/test/test_update_by_batch.ipynb @@ -145,7 +145,7 @@ "source": [ "from algorithm.neat.gene.node.normalized import NormalizedNode\n", "from algorithm.neat.gene.conn import DefaultConnGene\n", - "from utils import Act\n", + "from tensorneat.utils import Act\n", "\n", "genome = DefaultGenome(num_inputs=3, num_outputs=2, max_nodes=10, max_conns=10,\n", " node_gene=NormalizedNode(activation_default=Act.identity, activation_options=(Act.identity,)),\n",