update functions. Visualize, Interpretable and with evox
This commit is contained in:
@@ -20,16 +20,16 @@ if __name__ == "__main__":
|
||||
survival_threshold=0.1,
|
||||
compatibility_threshold=1.0,
|
||||
genome=DefaultGenome(
|
||||
max_nodes=100,
|
||||
max_nodes=50,
|
||||
max_conns=200,
|
||||
num_inputs=17,
|
||||
num_outputs=6,
|
||||
init_hidden_layers=(),
|
||||
node_gene=BiasNode(
|
||||
activation_options=ACT.tanh,
|
||||
activation_options=ACT.scaled_tanh,
|
||||
aggregation_options=AGG.sum,
|
||||
),
|
||||
output_transform=ACT.standard_tanh,
|
||||
output_transform=ACT.tanh,
|
||||
),
|
||||
),
|
||||
problem=BraxEnv(
|
||||
|
||||
@@ -20,7 +20,7 @@ if __name__ == "__main__":
|
||||
survival_threshold=0.1,
|
||||
compatibility_threshold=1.0,
|
||||
genome=DefaultGenome(
|
||||
max_nodes=100,
|
||||
max_nodes=50,
|
||||
max_conns=200,
|
||||
num_inputs=17,
|
||||
num_outputs=6,
|
||||
@@ -29,7 +29,7 @@ if __name__ == "__main__":
|
||||
activation_options=ACT.tanh,
|
||||
aggregation_options=AGG.sum,
|
||||
),
|
||||
output_transform=ACT.standard_tanh,
|
||||
output_transform=ACT.tanh,
|
||||
),
|
||||
),
|
||||
problem=BraxEnv(
|
||||
|
||||
@@ -6,7 +6,7 @@ from tensorneat.genome import DefaultGenome, DefaultNode, DefaultMutation, BiasN
|
||||
from tensorneat.problem.func_fit import CustomFuncFit
|
||||
from tensorneat.common import ACT, AGG
|
||||
|
||||
|
||||
# define a custom function fit problem
|
||||
def pagie_polynomial(inputs):
|
||||
x, y = inputs
|
||||
res = 1 / (1 + jnp.pow(x, -4)) + 1 / (1 + jnp.pow(y, -4))
|
||||
@@ -14,9 +14,12 @@ def pagie_polynomial(inputs):
|
||||
# important! returns an array, NOT a scalar
|
||||
return jnp.array([res])
|
||||
|
||||
# define custom activate function and register it
|
||||
def square(x):
|
||||
return x ** 2
|
||||
ACT.add_func("square", square)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
custom_problem = CustomFuncFit(
|
||||
func=pagie_polynomial,
|
||||
low_bounds=[-1, -1],
|
||||
|
||||
@@ -14,7 +14,7 @@ if __name__ == "__main__":
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
init_hidden_layers=(),
|
||||
output_transform=ACT.standard_sigmoid,
|
||||
output_transform=ACT.sigmoid,
|
||||
),
|
||||
),
|
||||
problem=XOR3d(),
|
||||
|
||||
@@ -22,12 +22,12 @@ if __name__ == "__main__":
|
||||
num_inputs=4, # size of query coors
|
||||
num_outputs=1,
|
||||
init_hidden_layers=(),
|
||||
output_transform=ACT.standard_tanh,
|
||||
output_transform=ACT.tanh,
|
||||
),
|
||||
),
|
||||
activation=ACT.tanh,
|
||||
activate_time=10,
|
||||
output_transform=ACT.standard_sigmoid,
|
||||
output_transform=ACT.sigmoid,
|
||||
),
|
||||
problem=XOR3d(),
|
||||
generation_limit=300,
|
||||
|
||||
@@ -14,7 +14,7 @@ if __name__ == "__main__":
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
init_hidden_layers=(),
|
||||
output_transform=ACT.standard_sigmoid,
|
||||
output_transform=ACT.sigmoid,
|
||||
activate_time=10,
|
||||
),
|
||||
),
|
||||
|
||||
@@ -27,7 +27,7 @@ if __name__ == "__main__":
|
||||
num_inputs=4, # size of query coors
|
||||
num_outputs=1,
|
||||
init_hidden_layers=(),
|
||||
output_transform=ACT.standard_tanh,
|
||||
output_transform=ACT.tanh,
|
||||
),
|
||||
),
|
||||
activation=ACT.tanh,
|
||||
|
||||
@@ -24,7 +24,7 @@ if __name__ == "__main__":
|
||||
activation_options=ACT.tanh,
|
||||
aggregation_options=AGG.sum,
|
||||
),
|
||||
output_transform=ACT.standard_tanh,
|
||||
output_transform=ACT.tanh,
|
||||
),
|
||||
),
|
||||
problem=GymNaxEnv(
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from algorithm.neat import *
|
||||
from algorithm.neat.genome.dense import DenseInitialize
|
||||
from utils.graph import topological_sort_python
|
||||
from tensorneat.genome import DefaultGenome
|
||||
from tensorneat.common import *
|
||||
from tensorneat.common.functions import SympySigmoid
|
||||
|
||||
if __name__ == "__main__":
|
||||
genome = DenseInitialize(
|
||||
genome = DefaultGenome(
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
max_nodes=50,
|
||||
max_conns=500,
|
||||
output_transform=ACT.sigmoid,
|
||||
)
|
||||
|
||||
state = genome.setup()
|
||||
@@ -22,7 +22,7 @@ if __name__ == "__main__":
|
||||
|
||||
input_idx, output_idx = genome.get_input_idx(), genome.get_output_idx()
|
||||
|
||||
res = genome.sympy_func(state, network, sympy_input_transform=lambda x: 999999999*x, sympy_output_transform=SympyStandardSigmoid)
|
||||
res = genome.sympy_func(state, network, sympy_input_transform=lambda x: 999*x, sympy_output_transform=SympySigmoid)
|
||||
(symbols,
|
||||
args_symbols,
|
||||
input_symbols,
|
||||
@@ -35,3 +35,11 @@ if __name__ == "__main__":
|
||||
|
||||
inputs = jnp.zeros(3)
|
||||
print(forward_func(inputs))
|
||||
|
||||
print(genome.forward(state, genome.transform(state, nodes, conns), inputs))
|
||||
|
||||
print(AGG.sympy_module("jax"))
|
||||
print(AGG.sympy_module("numpy"))
|
||||
|
||||
print(ACT.sympy_module("jax"))
|
||||
print(ACT.sympy_module("numpy"))
|
||||
@@ -1,30 +1,30 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from evox import workflows, algorithms, problems
|
||||
from evox import workflows, problems
|
||||
|
||||
from tensorneat.examples.with_evox.evox_algorithm_adaptor import EvoXAlgorithmAdaptor
|
||||
from tensorneat.examples.with_evox.tensorneat_monitor import TensorNEATMonitor
|
||||
from tensorneat.common.evox_adaptors import EvoXAlgorithmAdaptor, TensorNEATMonitor
|
||||
from tensorneat.algorithm import NEAT
|
||||
from tensorneat.algorithm.neat import DefaultSpecies, DefaultGenome, DefaultNodeGene
|
||||
from tensorneat.common import ACT
|
||||
from tensorneat.genome import DefaultGenome, BiasNode
|
||||
from tensorneat.common import ACT, AGG
|
||||
|
||||
neat_algorithm = NEAT(
|
||||
species=DefaultSpecies(
|
||||
pop_size=1000,
|
||||
species_size=20,
|
||||
survival_threshold=0.1,
|
||||
compatibility_threshold=1.0,
|
||||
genome=DefaultGenome(
|
||||
max_nodes=50,
|
||||
max_conns=200,
|
||||
num_inputs=17,
|
||||
num_outputs=6,
|
||||
max_nodes=200,
|
||||
max_conns=500,
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_options=(ACT.standard_tanh,),
|
||||
activation_default=ACT.standard_tanh,
|
||||
init_hidden_layers=(),
|
||||
node_gene=BiasNode(
|
||||
activation_options=ACT.tanh,
|
||||
aggregation_options=AGG.sum,
|
||||
),
|
||||
output_transform=ACT.tanh,
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
),
|
||||
)
|
||||
evox_algorithm = EvoXAlgorithmAdaptor(neat_algorithm)
|
||||
|
||||
@@ -37,12 +37,13 @@ problem = problems.neuroevolution.Brax(
|
||||
policy=evox_algorithm.forward,
|
||||
max_episode_length=1000,
|
||||
num_episodes=1,
|
||||
backend="mjx"
|
||||
)
|
||||
|
||||
|
||||
def nan2inf(x):
|
||||
return jnp.where(jnp.isnan(x), -jnp.inf, x)
|
||||
|
||||
|
||||
# create a workflow
|
||||
workflow = workflows.StdWorkflow(
|
||||
algorithm=evox_algorithm,
|
||||
@@ -55,11 +56,11 @@ workflow = workflows.StdWorkflow(
|
||||
|
||||
# init the workflow
|
||||
state = workflow.init(workflow_key)
|
||||
# state = workflow.enable_multi_devices(state)
|
||||
# run the workflow for 100 steps
|
||||
import time
|
||||
|
||||
# enable multi devices
|
||||
state = workflow.enable_multi_devices(state)
|
||||
|
||||
# run the workflow for 100 steps
|
||||
for i in range(100):
|
||||
tic = time.time()
|
||||
train_info, state = workflow.step(state)
|
||||
monitor.show()
|
||||
@@ -35,6 +35,9 @@ dependencies = [
|
||||
"flax >= 0.8.4",
|
||||
"mujoco >= 3.1.4",
|
||||
"mujoco-mjx >= 3.1.4",
|
||||
"networkx >= 3.3",
|
||||
"matplotlib >= 3.9.0",
|
||||
"sympy >= 1.12.1",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
||||
@@ -1,180 +0,0 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: tensorneat
|
||||
Version: 0.1.0
|
||||
Summary: tensorneat
|
||||
Author-email: Lishuang Wang <wanglishuang22@gmail.com>
|
||||
License: BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2024, EMI-Group
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
Project-URL: Homepage, https://github.com/EMI-Group/tensorneat
|
||||
Project-URL: Bug Tracker, https://github.com/EMI-Group/tensorneat/issues
|
||||
Classifier: Programming Language :: Python :: 3
|
||||
Classifier: License :: OSI Approved :: BSD License
|
||||
Classifier: Intended Audience :: Science/Research
|
||||
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
||||
Requires-Python: >=3.9
|
||||
Description-Content-Type: text/markdown
|
||||
License-File: LICENSE
|
||||
Requires-Dist: brax>=0.10.3
|
||||
Requires-Dist: jax>=0.4.28
|
||||
Requires-Dist: gymnax>=0.0.8
|
||||
Requires-Dist: jaxopt>=0.8.3
|
||||
Requires-Dist: optax>=0.2.2
|
||||
Requires-Dist: flax>=0.8.4
|
||||
Requires-Dist: mujoco>=3.1.4
|
||||
Requires-Dist: mujoco-mjx>=3.1.4
|
||||
|
||||
<h1 align="center">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="./imgs/evox_logo_dark.png">
|
||||
<source media="(prefers-color-scheme: light)" srcset="./imgs/evox_logo_light.png">
|
||||
<a href="https://github.com/EMI-Group/evox">
|
||||
<img alt="EvoX Logo" height="50" src="./imgs/evox_logo_light.png">
|
||||
</a>
|
||||
</picture>
|
||||
<br>
|
||||
</h1>
|
||||
|
||||
<p align="center">
|
||||
🌟 TensorNEAT: Tensorized NEAT Implementation in JAX 🌟
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://arxiv.org/abs/2404.01817">
|
||||
<img src="https://img.shields.io/badge/paper-arxiv-red?style=for-the-badge" alt="TensorNEAT Paper on arXiv">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
## Introduction
|
||||
TensorNEAT is a JAX-based libaray for NeuroEvolution of Augmenting Topologies (NEAT) algorithms, focused on harnessing GPU acceleration to enhance the efficiency of evolving neural network structures for complex tasks. Its core mechanism involves the tensorization of network topologies, enabling parallel processing and significantly boosting computational speed and scalability by leveraging modern hardware accelerators. TensorNEAT is compatible with the [EvoX](https://github.com/EMI-Group/evox/) framewrok.
|
||||
|
||||
## Requirements
|
||||
Due to the rapid iteration of JAX versions, configuring the runtime environment for TensorNEAT can be challenging. We recommend the following versions for the relevant libraries:
|
||||
|
||||
- jax (0.4.28)
|
||||
- jaxlib (0.4.28+cuda12.cudnn89)
|
||||
- brax (0.10.3)
|
||||
- gymnax (0.0.8)
|
||||
|
||||
We provide detailed JAX-related environment references in [recommend_environment](recommend_environment.txt). If you encounter any issues while configuring the environment yourself, you can use this as a reference.
|
||||
|
||||
## Example
|
||||
Simple Example for XOR problem:
|
||||
```python
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.func_fit import XOR3d
|
||||
|
||||
if __name__ == '__main__':
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
genome=DefaultGenome(
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
max_nodes=50,
|
||||
max_conns=100,
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
compatibility_threshold=3.5,
|
||||
),
|
||||
),
|
||||
problem=XOR3d(),
|
||||
generation_limit=10000,
|
||||
fitness_target=-1e-8
|
||||
)
|
||||
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
# print(state)
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
# show result
|
||||
pipeline.show(state, best)
|
||||
```
|
||||
|
||||
Simple Example for RL envs in Brax (Ant):
|
||||
```python
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import BraxEnv
|
||||
from tensorneat.utils import ACT
|
||||
|
||||
if __name__ == '__main__':
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
genome=DefaultGenome(
|
||||
num_inputs=27,
|
||||
num_outputs=8,
|
||||
max_nodes=50,
|
||||
max_conns=100,
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_options=(ACT.tanh,),
|
||||
activation_default=ACT.tanh,
|
||||
)
|
||||
),
|
||||
pop_size=1000,
|
||||
species_size=10,
|
||||
),
|
||||
),
|
||||
problem=BraxEnv(
|
||||
env_name='ant',
|
||||
),
|
||||
generation_limit=10000,
|
||||
fitness_target=5000
|
||||
)
|
||||
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
# print(state)
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
```
|
||||
|
||||
more examples are in `tensorneat/examples`.
|
||||
|
||||
## Community & Support
|
||||
|
||||
- Engage in discussions and share your experiences on [GitHub Discussion Board](https://github.com/EMI-Group/evox/discussions).
|
||||
- Join our QQ group (ID: 297969717).
|
||||
|
||||
## Citing TensorNEAT
|
||||
|
||||
If you use TensorNEAT in your research and want to cite it in your work, please use:
|
||||
```
|
||||
@article{tensorneat,
|
||||
title = {{Tensorized} {NeuroEvolution} of {Augmenting} {Topologies} for {GPU} {Acceleration}},
|
||||
author = {Wang, Lishuang and Zhao, Mengfei and Liu, Enyu and Sun, Kebin and Cheng, Ran},
|
||||
booktitle = {Proceedings of the Genetic and Evolutionary Computation Conference (GECCO)},
|
||||
year = {2024}
|
||||
}
|
||||
@@ -1,69 +0,0 @@
|
||||
LICENSE
|
||||
README.md
|
||||
pyproject.toml
|
||||
src/tensorneat/pipeline.py
|
||||
src/tensorneat.egg-info/PKG-INFO
|
||||
src/tensorneat.egg-info/SOURCES.txt
|
||||
src/tensorneat.egg-info/dependency_links.txt
|
||||
src/tensorneat.egg-info/requires.txt
|
||||
src/tensorneat.egg-info/top_level.txt
|
||||
src/tensorneat/algorithm/__init__.py
|
||||
src/tensorneat/algorithm/base.py
|
||||
src/tensorneat/algorithm/hyperneat/__init__.py
|
||||
src/tensorneat/algorithm/hyperneat/hyperneat.py
|
||||
src/tensorneat/algorithm/hyperneat/substrate/__init__.py
|
||||
src/tensorneat/algorithm/hyperneat/substrate/base.py
|
||||
src/tensorneat/algorithm/hyperneat/substrate/default.py
|
||||
src/tensorneat/algorithm/hyperneat/substrate/full.py
|
||||
src/tensorneat/algorithm/neat/__init__.py
|
||||
src/tensorneat/algorithm/neat/neat.py
|
||||
src/tensorneat/algorithm/neat/species.py
|
||||
src/tensorneat/common/__init__.py
|
||||
src/tensorneat/common/graph.py
|
||||
src/tensorneat/common/state.py
|
||||
src/tensorneat/common/stateful_class.py
|
||||
src/tensorneat/common/tools.py
|
||||
src/tensorneat/common/functions/__init__.py
|
||||
src/tensorneat/common/functions/act_jnp.py
|
||||
src/tensorneat/common/functions/act_sympy.py
|
||||
src/tensorneat/common/functions/agg_jnp.py
|
||||
src/tensorneat/common/functions/agg_sympy.py
|
||||
src/tensorneat/common/functions/manager.py
|
||||
src/tensorneat/genome/__init__.py
|
||||
src/tensorneat/genome/base.py
|
||||
src/tensorneat/genome/default.py
|
||||
src/tensorneat/genome/recurrent.py
|
||||
src/tensorneat/genome/utils.py
|
||||
src/tensorneat/genome/gene/__init__.py
|
||||
src/tensorneat/genome/gene/base.py
|
||||
src/tensorneat/genome/gene/conn/__init__.py
|
||||
src/tensorneat/genome/gene/conn/base.py
|
||||
src/tensorneat/genome/gene/conn/default.py
|
||||
src/tensorneat/genome/gene/node/__init__.py
|
||||
src/tensorneat/genome/gene/node/base.py
|
||||
src/tensorneat/genome/gene/node/bias.py
|
||||
src/tensorneat/genome/gene/node/default.py
|
||||
src/tensorneat/genome/operations/__init__.py
|
||||
src/tensorneat/genome/operations/crossover/__init__.py
|
||||
src/tensorneat/genome/operations/crossover/base.py
|
||||
src/tensorneat/genome/operations/crossover/default.py
|
||||
src/tensorneat/genome/operations/distance/__init__.py
|
||||
src/tensorneat/genome/operations/distance/base.py
|
||||
src/tensorneat/genome/operations/distance/default.py
|
||||
src/tensorneat/genome/operations/mutation/__init__.py
|
||||
src/tensorneat/genome/operations/mutation/base.py
|
||||
src/tensorneat/genome/operations/mutation/default.py
|
||||
src/tensorneat/problem/__init__.py
|
||||
src/tensorneat/problem/base.py
|
||||
src/tensorneat/problem/func_fit/__init__.py
|
||||
src/tensorneat/problem/func_fit/custom.py
|
||||
src/tensorneat/problem/func_fit/func_fit.py
|
||||
src/tensorneat/problem/func_fit/xor.py
|
||||
src/tensorneat/problem/func_fit/xor3d.py
|
||||
src/tensorneat/problem/rl/__init__.py
|
||||
src/tensorneat/problem/rl/brax.py
|
||||
src/tensorneat/problem/rl/gymnax.py
|
||||
src/tensorneat/problem/rl/rl_jit.py
|
||||
test/test_genome.py
|
||||
test/test_nan_fitness.py
|
||||
test/test_record_episode.py
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
brax>=0.10.3
|
||||
jax>=0.4.28
|
||||
gymnax>=0.0.8
|
||||
jaxopt>=0.8.3
|
||||
optax>=0.2.2
|
||||
flax>=0.8.4
|
||||
mujoco>=3.1.4
|
||||
mujoco-mjx>=3.1.4
|
||||
@@ -1 +0,0 @@
|
||||
tensorneat
|
||||
@@ -19,7 +19,7 @@ class HyperNEAT(BaseAlgorithm):
|
||||
aggregation: Callable = AGG.sum,
|
||||
activation: Callable = ACT.sigmoid,
|
||||
activate_time: int = 10,
|
||||
output_transform: Callable = ACT.standard_sigmoid,
|
||||
output_transform: Callable = ACT.sigmoid,
|
||||
):
|
||||
assert (
|
||||
substrate.query_coors.shape[1] == neat.num_inputs
|
||||
|
||||
@@ -3,4 +3,4 @@ from .graph import *
|
||||
from .state import State
|
||||
from .stateful_class import StatefulBaseClass
|
||||
|
||||
from .functions import ACT, AGG, apply_activation, apply_aggregation
|
||||
from .functions import ACT, AGG, apply_activation, apply_aggregation, get_func_name
|
||||
|
||||
2
src/tensorneat/common/evox_adaptors/__init__.py
Normal file
2
src/tensorneat/common/evox_adaptors/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .algorithm_adaptor import EvoXAlgorithmAdaptor
|
||||
from .tensorneat_monitor import TensorNEATMonitor
|
||||
@@ -16,12 +16,12 @@ class TensorNEATMonitor(Monitor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
neat_algorithm: TensorNEATAlgorithm,
|
||||
tensorneat_algorithm: TensorNEATAlgorithm,
|
||||
save_dir: str = None,
|
||||
is_save: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.neat_algorithm = neat_algorithm
|
||||
self.tensorneat_algorithm = tensorneat_algorithm
|
||||
|
||||
self.generation_timestamp = time.time()
|
||||
self.alg_state: TensorNEATState = None
|
||||
@@ -60,7 +60,9 @@ class TensorNEATMonitor(Monitor):
|
||||
self.fitness = jax.device_get(fitness)
|
||||
|
||||
def show(self):
|
||||
pop = self.neat_algorithm.ask(self.alg_state)
|
||||
pop = self.tensorneat_algorithm.ask(self.alg_state)
|
||||
generation = int(self.alg_state.generation)
|
||||
|
||||
valid_fitnesses = self.fitness[~np.isinf(self.fitness)]
|
||||
|
||||
max_f, min_f, mean_f, std_f = (
|
||||
@@ -73,7 +75,7 @@ class TensorNEATMonitor(Monitor):
|
||||
new_timestamp = time.time()
|
||||
|
||||
cost_time = new_timestamp - self.generation_timestamp
|
||||
self.generation_timestamp = new_timestamp
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
max_idx = np.argmax(self.fitness)
|
||||
if self.fitness[max_idx] > self.best_fitness:
|
||||
@@ -81,14 +83,12 @@ class TensorNEATMonitor(Monitor):
|
||||
self.best_genome = pop[0][max_idx], pop[1][max_idx]
|
||||
|
||||
if self.is_save:
|
||||
# save best
|
||||
best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx]))
|
||||
with open(
|
||||
os.path.join(
|
||||
self.genome_dir,
|
||||
f"{int(self.neat_algorithm.generation(self.alg_state))}.npz",
|
||||
),
|
||||
"wb",
|
||||
) as f:
|
||||
file_name = os.path.join(
|
||||
self.genome_dir, f"{generation}.npz"
|
||||
)
|
||||
with open(file_name, "wb") as f:
|
||||
np.savez(
|
||||
f,
|
||||
nodes=best_genome[0],
|
||||
@@ -96,38 +96,15 @@ class TensorNEATMonitor(Monitor):
|
||||
fitness=self.best_fitness,
|
||||
)
|
||||
|
||||
# save best if save path is not None
|
||||
member_count = jax.device_get(self.neat_algorithm.member_count(self.alg_state))
|
||||
species_sizes = [int(i) for i in member_count if i > 0]
|
||||
|
||||
pop = jax.device_get(pop)
|
||||
pop_nodes, pop_conns = pop # (P, N, NL), (P, C, CL)
|
||||
nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1) # (P,)
|
||||
conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1) # (P,)
|
||||
|
||||
max_node_cnt, min_node_cnt, mean_node_cnt = (
|
||||
max(nodes_cnt),
|
||||
min(nodes_cnt),
|
||||
np.mean(nodes_cnt),
|
||||
)
|
||||
|
||||
max_conn_cnt, min_conn_cnt, mean_conn_cnt = (
|
||||
max(conns_cnt),
|
||||
min(conns_cnt),
|
||||
np.mean(conns_cnt),
|
||||
# append log
|
||||
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
|
||||
f.write(
|
||||
f"{generation},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Generation: {self.neat_algorithm.generation(self.alg_state)}, Cost time: {cost_time * 1000:.2f}ms\n",
|
||||
f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n",
|
||||
f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n",
|
||||
f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
|
||||
f"Generation: {generation}, Cost time: {cost_time * 1000:.2f}ms\n",
|
||||
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
|
||||
)
|
||||
|
||||
# append log
|
||||
if self.is_save:
|
||||
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
|
||||
f.write(
|
||||
f"{self.neat_algorithm.generation(self.alg_state)},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
|
||||
)
|
||||
self.tensorneat_algorithm.show_details(self.alg_state, self.fitness)
|
||||
@@ -1,3 +1,5 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from .act_jnp import *
|
||||
from .act_sympy import *
|
||||
from .agg_jnp import *
|
||||
@@ -32,6 +34,7 @@ act_name2sympy = {
|
||||
"log": SympyLog,
|
||||
"exp": SympyExp,
|
||||
"abs": SympyAbs,
|
||||
"clip": SympyClip,
|
||||
}
|
||||
|
||||
agg_name2jnp = {
|
||||
@@ -40,7 +43,6 @@ agg_name2jnp = {
|
||||
"max": max_,
|
||||
"min": min_,
|
||||
"maxabs": maxabs_,
|
||||
"median": median_,
|
||||
"mean": mean_,
|
||||
}
|
||||
|
||||
@@ -50,9 +52,42 @@ agg_name2sympy = {
|
||||
"max": SympyMax,
|
||||
"min": SympyMin,
|
||||
"maxabs": SympyMaxabs,
|
||||
"median": SympyMedian,
|
||||
"mean": SympyMean,
|
||||
}
|
||||
|
||||
ACT = FunctionManager(act_name2jnp, act_name2sympy)
|
||||
AGG = FunctionManager(agg_name2jnp, agg_name2sympy)
|
||||
|
||||
def apply_activation(idx, z, act_funcs):
|
||||
"""
|
||||
calculate activation function for each node
|
||||
"""
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
# change idx from float to int
|
||||
|
||||
# -1 means identity activation
|
||||
res = jax.lax.cond(
|
||||
idx == -1,
|
||||
lambda: z,
|
||||
lambda: jax.lax.switch(idx, act_funcs, z),
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
def apply_aggregation(idx, z, agg_funcs):
|
||||
"""
|
||||
calculate activation function for inputs of node
|
||||
"""
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
|
||||
return jax.lax.cond(
|
||||
jnp.all(jnp.isnan(z)),
|
||||
lambda: jnp.nan, # all inputs are nan
|
||||
lambda: jax.lax.switch(idx, agg_funcs, z), # otherwise
|
||||
)
|
||||
|
||||
def get_func_name(func):
|
||||
name = func.__name__
|
||||
if name.endswith("_"):
|
||||
name = name[:-1]
|
||||
return name
|
||||
@@ -1,6 +1,6 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
SCALE = 5
|
||||
SCALE = 3
|
||||
|
||||
|
||||
def scaled_sigmoid_(z):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import sympy as sp
|
||||
import numpy as np
|
||||
|
||||
SCALE = 5
|
||||
SCALE = 3
|
||||
|
||||
class SympySigmoid(sp.Function):
|
||||
@classmethod
|
||||
|
||||
@@ -23,18 +23,6 @@ def maxabs_(z):
|
||||
max_abs_index = jnp.argmax(abs_z)
|
||||
return z[max_abs_index]
|
||||
|
||||
|
||||
def median_(z):
|
||||
n = jnp.sum(~jnp.isnan(z), axis=0)
|
||||
|
||||
z = jnp.sort(z) # sort
|
||||
|
||||
idx1, idx2 = (n - 1) // 2, n // 2
|
||||
median = (z[idx1] + z[idx2]) / 2
|
||||
|
||||
return median
|
||||
|
||||
|
||||
def mean_(z):
|
||||
sumation = sum_(z)
|
||||
valid_count = jnp.sum(~jnp.isnan(z), axis=0)
|
||||
|
||||
@@ -7,30 +7,18 @@ class SympySum(sp.Function):
|
||||
def eval(cls, z):
|
||||
return sp.Add(*z)
|
||||
|
||||
@classmethod
|
||||
def numerical_eval(cls, z, backend=np):
|
||||
return backend.sum(z)
|
||||
|
||||
|
||||
class SympyProduct(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.Mul(*z)
|
||||
|
||||
@classmethod
|
||||
def numerical_eval(cls, z, backend=np):
|
||||
return backend.product(z)
|
||||
|
||||
|
||||
class SympyMax(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.Max(*z)
|
||||
|
||||
@classmethod
|
||||
def numerical_eval(cls, z, backend=np):
|
||||
return backend.max(z)
|
||||
|
||||
|
||||
class SympyMin(sp.Function):
|
||||
@classmethod
|
||||
@@ -48,26 +36,3 @@ class SympyMean(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.Add(*z) / len(z)
|
||||
|
||||
|
||||
class SympyMedian(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, args):
|
||||
|
||||
if all(arg.is_number for arg in args):
|
||||
sorted_args = sorted(args)
|
||||
n = len(sorted_args)
|
||||
if n % 2 == 1:
|
||||
return sorted_args[n // 2]
|
||||
else:
|
||||
return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2
|
||||
|
||||
return None
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"median({', '.join(map(str, self.args))})"
|
||||
|
||||
def _latex(self, printer):
|
||||
return (
|
||||
r"\mathrm{median}\left(" + ", ".join(map(sp.latex, self.args)) + r"\right)"
|
||||
)
|
||||
|
||||
@@ -1,28 +1,32 @@
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
from typing import Union, Callable
|
||||
import sympy as sp
|
||||
|
||||
|
||||
class FunctionManager:
|
||||
|
||||
def __init__(self, name2jnp, name2sympy):
|
||||
self.name2jnp = name2jnp
|
||||
self.name2sympy = name2sympy
|
||||
for name, func in name2jnp.items():
|
||||
setattr(self, name, func)
|
||||
|
||||
def get_all_funcs(self):
|
||||
all_funcs = []
|
||||
for name in self.names:
|
||||
for name in self.name2jnp:
|
||||
all_funcs.append(getattr(self, name))
|
||||
return all_funcs
|
||||
|
||||
def __getattribute__(self, name: str):
|
||||
return self.name2jnp[name]
|
||||
|
||||
def add_func(self, name, func):
|
||||
if not callable(func):
|
||||
raise ValueError("The provided function is not callable")
|
||||
if name in self.names:
|
||||
if name in self.name2jnp:
|
||||
raise ValueError(f"The provided name={name} is already in use")
|
||||
|
||||
self.name2jnp[name] = func
|
||||
setattr(self, name, func)
|
||||
|
||||
def update_sympy(self, name, sympy_cls: sp.Function):
|
||||
self.name2sympy[name] = sympy_cls
|
||||
@@ -47,3 +51,16 @@ class FunctionManager:
|
||||
if name not in self.name2sympy:
|
||||
raise ValueError(f"Func {name} doesn't have a sympy representation.")
|
||||
return self.name2sympy[name]
|
||||
|
||||
def sympy_module(self, backend: str):
|
||||
assert backend in ["jax", "numpy"]
|
||||
if backend == "jax":
|
||||
backend = jnp
|
||||
elif backend == "numpy":
|
||||
backend = np
|
||||
module = {}
|
||||
for sympy_cls in self.name2sympy.values():
|
||||
if hasattr(sympy_cls, "numerical_eval"):
|
||||
module[sympy_cls.__name__] = partial(sympy_cls.numerical_eval, backend)
|
||||
|
||||
return module
|
||||
|
||||
@@ -15,8 +15,8 @@ from tensorneat.common import (
|
||||
topological_sort_python,
|
||||
I_INF,
|
||||
attach_with_inf,
|
||||
SYMPY_FUNCS_MODULE_NP,
|
||||
SYMPY_FUNCS_MODULE_JNP,
|
||||
ACT,
|
||||
AGG
|
||||
)
|
||||
|
||||
|
||||
@@ -92,7 +92,9 @@ class DefaultGenome(BaseGenome):
|
||||
def otherwise():
|
||||
# calculate connections
|
||||
conn_indices = u_conns[:, i]
|
||||
hit_attrs = attach_with_inf(conns_attrs, conn_indices) # fetch conn attrs
|
||||
hit_attrs = attach_with_inf(
|
||||
conns_attrs, conn_indices
|
||||
) # fetch conn attrs
|
||||
ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
|
||||
state, hit_attrs, values
|
||||
)
|
||||
@@ -102,7 +104,9 @@ class DefaultGenome(BaseGenome):
|
||||
state,
|
||||
nodes_attrs[i],
|
||||
ins,
|
||||
is_output_node=jnp.isin(nodes[i, 0], self.output_idx), # nodes[0] -> the key of nodes
|
||||
is_output_node=jnp.isin(
|
||||
nodes[i, 0], self.output_idx
|
||||
), # nodes[0] -> the key of nodes
|
||||
)
|
||||
|
||||
# set new value
|
||||
@@ -139,7 +143,6 @@ class DefaultGenome(BaseGenome):
|
||||
):
|
||||
|
||||
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
|
||||
module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
|
||||
|
||||
if sympy_input_transform is None and self.input_transform is not None:
|
||||
warnings.warn(
|
||||
@@ -224,7 +227,7 @@ class DefaultGenome(BaseGenome):
|
||||
sp.lambdify(
|
||||
input_symbols + list(args_symbols.keys()),
|
||||
exprs,
|
||||
modules=[backend, module],
|
||||
modules=[backend, AGG.sympy_module(backend), ACT.sympy_module(backend)],
|
||||
)
|
||||
for exprs in output_exprs
|
||||
]
|
||||
@@ -256,7 +259,12 @@ class DefaultGenome(BaseGenome):
|
||||
rotate=0,
|
||||
reverse_node_order=False,
|
||||
size=(300, 300, 300),
|
||||
color=("blue", "blue", "blue"),
|
||||
color=("yellow", "white", "blue"),
|
||||
with_labels=False,
|
||||
edgecolors="k",
|
||||
arrowstyle="->",
|
||||
arrowsize=3,
|
||||
edge_color=(0.3, 0.3, 0.3),
|
||||
save_path="network.svg",
|
||||
save_dpi=800,
|
||||
**kwargs,
|
||||
@@ -264,7 +272,6 @@ class DefaultGenome(BaseGenome):
|
||||
import networkx as nx
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
nodes_list = list(network["nodes"])
|
||||
conns_list = list(network["conns"])
|
||||
input_idx = self.get_input_idx()
|
||||
output_idx = self.get_output_idx()
|
||||
@@ -316,6 +323,11 @@ class DefaultGenome(BaseGenome):
|
||||
pos=rotated_pos,
|
||||
node_size=node_sizes,
|
||||
node_color=node_colors,
|
||||
with_labels=with_labels,
|
||||
edgecolors=edgecolors,
|
||||
arrowstyle=arrowstyle,
|
||||
arrowsize=arrowsize,
|
||||
edge_color=edge_color,
|
||||
**kwargs,
|
||||
)
|
||||
plt.savefig(save_path, dpi=save_dpi)
|
||||
|
||||
@@ -10,7 +10,7 @@ from tensorneat.common import (
|
||||
apply_aggregation,
|
||||
mutate_int,
|
||||
mutate_float,
|
||||
convert_to_sympy,
|
||||
get_func_name
|
||||
)
|
||||
|
||||
from . import BaseNode
|
||||
@@ -141,8 +141,8 @@ class BiasNode(BaseNode):
|
||||
self.__class__.__name__,
|
||||
idx,
|
||||
bias,
|
||||
self.aggregation_options[agg].__name__,
|
||||
act_func.__name__,
|
||||
get_func_name(self.aggregation_options[agg]),
|
||||
get_func_name(act_func),
|
||||
idx_width=idx_width,
|
||||
float_width=precision + 3,
|
||||
func_width=func_width,
|
||||
@@ -165,21 +165,19 @@ class BiasNode(BaseNode):
|
||||
return {
|
||||
"idx": idx,
|
||||
"bias": bias,
|
||||
"agg": self.aggregation_options[int(agg)].__name__,
|
||||
"act": act_func.__name__,
|
||||
"agg": get_func_name(self.aggregation_options[agg]),
|
||||
"act": get_func_name(act_func),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
|
||||
nd = node_dict
|
||||
bias = sp.symbols(f"n_{node_dict['idx']}_b")
|
||||
|
||||
bias = sp.symbols(f"n_{nd['idx']}_b")
|
||||
|
||||
z = convert_to_sympy(nd["agg"])(inputs)
|
||||
z = AGG.obtain_sympy(node_dict["agg"])(inputs)
|
||||
|
||||
z = bias + z
|
||||
if is_output_node:
|
||||
pass
|
||||
else:
|
||||
z = convert_to_sympy(nd["act"])(z)
|
||||
z = ACT.obtain_sympy(node_dict["act"])(z)
|
||||
|
||||
return z, {bias: nd["bias"]}
|
||||
return z, {bias: node_dict["bias"]}
|
||||
|
||||
@@ -11,7 +11,7 @@ from tensorneat.common import (
|
||||
apply_aggregation,
|
||||
mutate_int,
|
||||
mutate_float,
|
||||
convert_to_sympy,
|
||||
get_func_name
|
||||
)
|
||||
|
||||
from .base import BaseNode
|
||||
@@ -176,8 +176,8 @@ class DefaultNode(BaseNode):
|
||||
idx,
|
||||
bias,
|
||||
res,
|
||||
self.aggregation_options[agg].__name__,
|
||||
act_func.__name__,
|
||||
get_func_name(self.aggregation_options[agg]),
|
||||
get_func_name(act_func),
|
||||
idx_width=idx_width,
|
||||
float_width=precision + 3,
|
||||
func_width=func_width,
|
||||
@@ -200,8 +200,8 @@ class DefaultNode(BaseNode):
|
||||
"idx": idx,
|
||||
"bias": bias,
|
||||
"res": res,
|
||||
"agg": self.aggregation_options[int(agg)].__name__,
|
||||
"act": act_func.__name__,
|
||||
"agg": get_func_name(self.aggregation_options[agg]),
|
||||
"act": get_func_name(act_func),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
|
||||
@@ -209,12 +209,13 @@ class DefaultNode(BaseNode):
|
||||
bias = sp.symbols(f"n_{nd['idx']}_b")
|
||||
res = sp.symbols(f"n_{nd['idx']}_r")
|
||||
|
||||
z = convert_to_sympy(nd["agg"])(inputs)
|
||||
print(nd["agg"])
|
||||
z = AGG.obtain_sympy(nd["agg"])(inputs)
|
||||
z = bias + res * z
|
||||
|
||||
if is_output_node:
|
||||
pass
|
||||
else:
|
||||
z = convert_to_sympy(nd["act"])(z)
|
||||
z = ACT.obtain_sympy(nd["act"])(z)
|
||||
|
||||
return z, {bias: nd["bias"], res: nd["res"]}
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"cells": [],
|
||||
"metadata": {},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
23
tutorials/tutorial-01-genome.ipynb
Normal file
23
tutorials/tutorial-01-genome.ipynb
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "286b5fc2-e629-4461-ad95-b69d3d606a78",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "",
|
||||
"name": ""
|
||||
},
|
||||
"language_info": {
|
||||
"name": ""
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user