update functions. Visualize, Interpretable and with evox

This commit is contained in:
root
2024-07-12 04:35:22 +08:00
parent 5fc63fdaf1
commit 0d6e7477bf
32 changed files with 207 additions and 427 deletions

View File

@@ -20,16 +20,16 @@ if __name__ == "__main__":
survival_threshold=0.1, survival_threshold=0.1,
compatibility_threshold=1.0, compatibility_threshold=1.0,
genome=DefaultGenome( genome=DefaultGenome(
max_nodes=100, max_nodes=50,
max_conns=200, max_conns=200,
num_inputs=17, num_inputs=17,
num_outputs=6, num_outputs=6,
init_hidden_layers=(), init_hidden_layers=(),
node_gene=BiasNode( node_gene=BiasNode(
activation_options=ACT.tanh, activation_options=ACT.scaled_tanh,
aggregation_options=AGG.sum, aggregation_options=AGG.sum,
), ),
output_transform=ACT.standard_tanh, output_transform=ACT.tanh,
), ),
), ),
problem=BraxEnv( problem=BraxEnv(

View File

@@ -20,7 +20,7 @@ if __name__ == "__main__":
survival_threshold=0.1, survival_threshold=0.1,
compatibility_threshold=1.0, compatibility_threshold=1.0,
genome=DefaultGenome( genome=DefaultGenome(
max_nodes=100, max_nodes=50,
max_conns=200, max_conns=200,
num_inputs=17, num_inputs=17,
num_outputs=6, num_outputs=6,
@@ -29,7 +29,7 @@ if __name__ == "__main__":
activation_options=ACT.tanh, activation_options=ACT.tanh,
aggregation_options=AGG.sum, aggregation_options=AGG.sum,
), ),
output_transform=ACT.standard_tanh, output_transform=ACT.tanh,
), ),
), ),
problem=BraxEnv( problem=BraxEnv(

View File

@@ -6,7 +6,7 @@ from tensorneat.genome import DefaultGenome, DefaultNode, DefaultMutation, BiasN
from tensorneat.problem.func_fit import CustomFuncFit from tensorneat.problem.func_fit import CustomFuncFit
from tensorneat.common import ACT, AGG from tensorneat.common import ACT, AGG
# define a custom function fit problem
def pagie_polynomial(inputs): def pagie_polynomial(inputs):
x, y = inputs x, y = inputs
res = 1 / (1 + jnp.pow(x, -4)) + 1 / (1 + jnp.pow(y, -4)) res = 1 / (1 + jnp.pow(x, -4)) + 1 / (1 + jnp.pow(y, -4))
@@ -14,9 +14,12 @@ def pagie_polynomial(inputs):
# important! returns an array, NOT a scalar # important! returns an array, NOT a scalar
return jnp.array([res]) return jnp.array([res])
# define custom activate function and register it
def square(x):
return x ** 2
ACT.add_func("square", square)
if __name__ == "__main__": if __name__ == "__main__":
custom_problem = CustomFuncFit( custom_problem = CustomFuncFit(
func=pagie_polynomial, func=pagie_polynomial,
low_bounds=[-1, -1], low_bounds=[-1, -1],

View File

@@ -14,7 +14,7 @@ if __name__ == "__main__":
num_inputs=3, num_inputs=3,
num_outputs=1, num_outputs=1,
init_hidden_layers=(), init_hidden_layers=(),
output_transform=ACT.standard_sigmoid, output_transform=ACT.sigmoid,
), ),
), ),
problem=XOR3d(), problem=XOR3d(),

View File

@@ -22,12 +22,12 @@ if __name__ == "__main__":
num_inputs=4, # size of query coors num_inputs=4, # size of query coors
num_outputs=1, num_outputs=1,
init_hidden_layers=(), init_hidden_layers=(),
output_transform=ACT.standard_tanh, output_transform=ACT.tanh,
), ),
), ),
activation=ACT.tanh, activation=ACT.tanh,
activate_time=10, activate_time=10,
output_transform=ACT.standard_sigmoid, output_transform=ACT.sigmoid,
), ),
problem=XOR3d(), problem=XOR3d(),
generation_limit=300, generation_limit=300,

View File

@@ -14,7 +14,7 @@ if __name__ == "__main__":
num_inputs=3, num_inputs=3,
num_outputs=1, num_outputs=1,
init_hidden_layers=(), init_hidden_layers=(),
output_transform=ACT.standard_sigmoid, output_transform=ACT.sigmoid,
activate_time=10, activate_time=10,
), ),
), ),

View File

@@ -27,7 +27,7 @@ if __name__ == "__main__":
num_inputs=4, # size of query coors num_inputs=4, # size of query coors
num_outputs=1, num_outputs=1,
init_hidden_layers=(), init_hidden_layers=(),
output_transform=ACT.standard_tanh, output_transform=ACT.tanh,
), ),
), ),
activation=ACT.tanh, activation=ACT.tanh,

View File

@@ -24,7 +24,7 @@ if __name__ == "__main__":
activation_options=ACT.tanh, activation_options=ACT.tanh,
aggregation_options=AGG.sum, aggregation_options=AGG.sum,
), ),
output_transform=ACT.standard_tanh, output_transform=ACT.tanh,
), ),
), ),
problem=GymNaxEnv( problem=GymNaxEnv(

View File

@@ -1,16 +1,16 @@
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from algorithm.neat import * from tensorneat.genome import DefaultGenome
from algorithm.neat.genome.dense import DenseInitialize
from utils.graph import topological_sort_python
from tensorneat.common import * from tensorneat.common import *
from tensorneat.common.functions import SympySigmoid
if __name__ == "__main__": if __name__ == "__main__":
genome = DenseInitialize( genome = DefaultGenome(
num_inputs=3, num_inputs=3,
num_outputs=1, num_outputs=1,
max_nodes=50, max_nodes=50,
max_conns=500, max_conns=500,
output_transform=ACT.sigmoid,
) )
state = genome.setup() state = genome.setup()
@@ -22,7 +22,7 @@ if __name__ == "__main__":
input_idx, output_idx = genome.get_input_idx(), genome.get_output_idx() input_idx, output_idx = genome.get_input_idx(), genome.get_output_idx()
res = genome.sympy_func(state, network, sympy_input_transform=lambda x: 999999999*x, sympy_output_transform=SympyStandardSigmoid) res = genome.sympy_func(state, network, sympy_input_transform=lambda x: 999*x, sympy_output_transform=SympySigmoid)
(symbols, (symbols,
args_symbols, args_symbols,
input_symbols, input_symbols,
@@ -35,3 +35,11 @@ if __name__ == "__main__":
inputs = jnp.zeros(3) inputs = jnp.zeros(3)
print(forward_func(inputs)) print(forward_func(inputs))
print(genome.forward(state, genome.transform(state, nodes, conns), inputs))
print(AGG.sympy_module("jax"))
print(AGG.sympy_module("numpy"))
print(ACT.sympy_module("jax"))
print(ACT.sympy_module("numpy"))

View File

@@ -1,30 +1,30 @@
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from evox import workflows, algorithms, problems from evox import workflows, problems
from tensorneat.examples.with_evox.evox_algorithm_adaptor import EvoXAlgorithmAdaptor from tensorneat.common.evox_adaptors import EvoXAlgorithmAdaptor, TensorNEATMonitor
from tensorneat.examples.with_evox.tensorneat_monitor import TensorNEATMonitor
from tensorneat.algorithm import NEAT from tensorneat.algorithm import NEAT
from tensorneat.algorithm.neat import DefaultSpecies, DefaultGenome, DefaultNodeGene from tensorneat.genome import DefaultGenome, BiasNode
from tensorneat.common import ACT from tensorneat.common import ACT, AGG
neat_algorithm = NEAT( neat_algorithm = NEAT(
species=DefaultSpecies( pop_size=1000,
species_size=20,
survival_threshold=0.1,
compatibility_threshold=1.0,
genome=DefaultGenome( genome=DefaultGenome(
max_nodes=50,
max_conns=200,
num_inputs=17, num_inputs=17,
num_outputs=6, num_outputs=6,
max_nodes=200, init_hidden_layers=(),
max_conns=500, node_gene=BiasNode(
node_gene=DefaultNodeGene( activation_options=ACT.tanh,
activation_options=(ACT.standard_tanh,), aggregation_options=AGG.sum,
activation_default=ACT.standard_tanh,
), ),
output_transform=ACT.tanh, output_transform=ACT.tanh,
), ),
pop_size=10000,
species_size=10,
),
) )
evox_algorithm = EvoXAlgorithmAdaptor(neat_algorithm) evox_algorithm = EvoXAlgorithmAdaptor(neat_algorithm)
@@ -37,12 +37,13 @@ problem = problems.neuroevolution.Brax(
policy=evox_algorithm.forward, policy=evox_algorithm.forward,
max_episode_length=1000, max_episode_length=1000,
num_episodes=1, num_episodes=1,
backend="mjx"
) )
def nan2inf(x): def nan2inf(x):
return jnp.where(jnp.isnan(x), -jnp.inf, x) return jnp.where(jnp.isnan(x), -jnp.inf, x)
# create a workflow # create a workflow
workflow = workflows.StdWorkflow( workflow = workflows.StdWorkflow(
algorithm=evox_algorithm, algorithm=evox_algorithm,
@@ -55,11 +56,11 @@ workflow = workflows.StdWorkflow(
# init the workflow # init the workflow
state = workflow.init(workflow_key) state = workflow.init(workflow_key)
# state = workflow.enable_multi_devices(state)
# run the workflow for 100 steps
import time
# enable multi devices
state = workflow.enable_multi_devices(state)
# run the workflow for 100 steps
for i in range(100): for i in range(100):
tic = time.time()
train_info, state = workflow.step(state) train_info, state = workflow.step(state)
monitor.show() monitor.show()

View File

@@ -35,6 +35,9 @@ dependencies = [
"flax >= 0.8.4", "flax >= 0.8.4",
"mujoco >= 3.1.4", "mujoco >= 3.1.4",
"mujoco-mjx >= 3.1.4", "mujoco-mjx >= 3.1.4",
"networkx >= 3.3",
"matplotlib >= 3.9.0",
"sympy >= 1.12.1",
] ]
[project.urls] [project.urls]

View File

@@ -1,180 +0,0 @@
Metadata-Version: 2.1
Name: tensorneat
Version: 0.1.0
Summary: tensorneat
Author-email: Lishuang Wang <wanglishuang22@gmail.com>
License: BSD 3-Clause License
Copyright (c) 2024, EMI-Group
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Project-URL: Homepage, https://github.com/EMI-Group/tensorneat
Project-URL: Bug Tracker, https://github.com/EMI-Group/tensorneat/issues
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: BSD License
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: brax>=0.10.3
Requires-Dist: jax>=0.4.28
Requires-Dist: gymnax>=0.0.8
Requires-Dist: jaxopt>=0.8.3
Requires-Dist: optax>=0.2.2
Requires-Dist: flax>=0.8.4
Requires-Dist: mujoco>=3.1.4
Requires-Dist: mujoco-mjx>=3.1.4
<h1 align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="./imgs/evox_logo_dark.png">
<source media="(prefers-color-scheme: light)" srcset="./imgs/evox_logo_light.png">
<a href="https://github.com/EMI-Group/evox">
<img alt="EvoX Logo" height="50" src="./imgs/evox_logo_light.png">
</a>
</picture>
<br>
</h1>
<p align="center">
🌟 TensorNEAT: Tensorized NEAT Implementation in JAX 🌟
</p>
<p align="center">
<a href="https://arxiv.org/abs/2404.01817">
<img src="https://img.shields.io/badge/paper-arxiv-red?style=for-the-badge" alt="TensorNEAT Paper on arXiv">
</a>
</p>
## Introduction
TensorNEAT is a JAX-based libaray for NeuroEvolution of Augmenting Topologies (NEAT) algorithms, focused on harnessing GPU acceleration to enhance the efficiency of evolving neural network structures for complex tasks. Its core mechanism involves the tensorization of network topologies, enabling parallel processing and significantly boosting computational speed and scalability by leveraging modern hardware accelerators. TensorNEAT is compatible with the [EvoX](https://github.com/EMI-Group/evox/) framewrok.
## Requirements
Due to the rapid iteration of JAX versions, configuring the runtime environment for TensorNEAT can be challenging. We recommend the following versions for the relevant libraries:
- jax (0.4.28)
- jaxlib (0.4.28+cuda12.cudnn89)
- brax (0.10.3)
- gymnax (0.0.8)
We provide detailed JAX-related environment references in [recommend_environment](recommend_environment.txt). If you encounter any issues while configuring the environment yourself, you can use this as a reference.
## Example
Simple Example for XOR problem:
```python
from pipeline import Pipeline
from algorithm.neat import *
from problem.func_fit import XOR3d
if __name__ == '__main__':
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=100,
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
),
),
problem=XOR3d(),
generation_limit=10000,
fitness_target=-1e-8
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best)
```
Simple Example for RL envs in Brax (Ant):
```python
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import BraxEnv
from tensorneat.utils import ACT
if __name__ == '__main__':
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=27,
num_outputs=8,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(ACT.tanh,),
activation_default=ACT.tanh,
)
),
pop_size=1000,
species_size=10,
),
),
problem=BraxEnv(
env_name='ant',
),
generation_limit=10000,
fitness_target=5000
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)
```
more examples are in `tensorneat/examples`.
## Community & Support
- Engage in discussions and share your experiences on [GitHub Discussion Board](https://github.com/EMI-Group/evox/discussions).
- Join our QQ group (ID: 297969717).
## Citing TensorNEAT
If you use TensorNEAT in your research and want to cite it in your work, please use:
```
@article{tensorneat,
title = {{Tensorized} {NeuroEvolution} of {Augmenting} {Topologies} for {GPU} {Acceleration}},
author = {Wang, Lishuang and Zhao, Mengfei and Liu, Enyu and Sun, Kebin and Cheng, Ran},
booktitle = {Proceedings of the Genetic and Evolutionary Computation Conference (GECCO)},
year = {2024}
}

View File

@@ -1,69 +0,0 @@
LICENSE
README.md
pyproject.toml
src/tensorneat/pipeline.py
src/tensorneat.egg-info/PKG-INFO
src/tensorneat.egg-info/SOURCES.txt
src/tensorneat.egg-info/dependency_links.txt
src/tensorneat.egg-info/requires.txt
src/tensorneat.egg-info/top_level.txt
src/tensorneat/algorithm/__init__.py
src/tensorneat/algorithm/base.py
src/tensorneat/algorithm/hyperneat/__init__.py
src/tensorneat/algorithm/hyperneat/hyperneat.py
src/tensorneat/algorithm/hyperneat/substrate/__init__.py
src/tensorneat/algorithm/hyperneat/substrate/base.py
src/tensorneat/algorithm/hyperneat/substrate/default.py
src/tensorneat/algorithm/hyperneat/substrate/full.py
src/tensorneat/algorithm/neat/__init__.py
src/tensorneat/algorithm/neat/neat.py
src/tensorneat/algorithm/neat/species.py
src/tensorneat/common/__init__.py
src/tensorneat/common/graph.py
src/tensorneat/common/state.py
src/tensorneat/common/stateful_class.py
src/tensorneat/common/tools.py
src/tensorneat/common/functions/__init__.py
src/tensorneat/common/functions/act_jnp.py
src/tensorneat/common/functions/act_sympy.py
src/tensorneat/common/functions/agg_jnp.py
src/tensorneat/common/functions/agg_sympy.py
src/tensorneat/common/functions/manager.py
src/tensorneat/genome/__init__.py
src/tensorneat/genome/base.py
src/tensorneat/genome/default.py
src/tensorneat/genome/recurrent.py
src/tensorneat/genome/utils.py
src/tensorneat/genome/gene/__init__.py
src/tensorneat/genome/gene/base.py
src/tensorneat/genome/gene/conn/__init__.py
src/tensorneat/genome/gene/conn/base.py
src/tensorneat/genome/gene/conn/default.py
src/tensorneat/genome/gene/node/__init__.py
src/tensorneat/genome/gene/node/base.py
src/tensorneat/genome/gene/node/bias.py
src/tensorneat/genome/gene/node/default.py
src/tensorneat/genome/operations/__init__.py
src/tensorneat/genome/operations/crossover/__init__.py
src/tensorneat/genome/operations/crossover/base.py
src/tensorneat/genome/operations/crossover/default.py
src/tensorneat/genome/operations/distance/__init__.py
src/tensorneat/genome/operations/distance/base.py
src/tensorneat/genome/operations/distance/default.py
src/tensorneat/genome/operations/mutation/__init__.py
src/tensorneat/genome/operations/mutation/base.py
src/tensorneat/genome/operations/mutation/default.py
src/tensorneat/problem/__init__.py
src/tensorneat/problem/base.py
src/tensorneat/problem/func_fit/__init__.py
src/tensorneat/problem/func_fit/custom.py
src/tensorneat/problem/func_fit/func_fit.py
src/tensorneat/problem/func_fit/xor.py
src/tensorneat/problem/func_fit/xor3d.py
src/tensorneat/problem/rl/__init__.py
src/tensorneat/problem/rl/brax.py
src/tensorneat/problem/rl/gymnax.py
src/tensorneat/problem/rl/rl_jit.py
test/test_genome.py
test/test_nan_fitness.py
test/test_record_episode.py

View File

@@ -1,8 +0,0 @@
brax>=0.10.3
jax>=0.4.28
gymnax>=0.0.8
jaxopt>=0.8.3
optax>=0.2.2
flax>=0.8.4
mujoco>=3.1.4
mujoco-mjx>=3.1.4

View File

@@ -1 +0,0 @@
tensorneat

View File

@@ -19,7 +19,7 @@ class HyperNEAT(BaseAlgorithm):
aggregation: Callable = AGG.sum, aggregation: Callable = AGG.sum,
activation: Callable = ACT.sigmoid, activation: Callable = ACT.sigmoid,
activate_time: int = 10, activate_time: int = 10,
output_transform: Callable = ACT.standard_sigmoid, output_transform: Callable = ACT.sigmoid,
): ):
assert ( assert (
substrate.query_coors.shape[1] == neat.num_inputs substrate.query_coors.shape[1] == neat.num_inputs

View File

@@ -3,4 +3,4 @@ from .graph import *
from .state import State from .state import State
from .stateful_class import StatefulBaseClass from .stateful_class import StatefulBaseClass
from .functions import ACT, AGG, apply_activation, apply_aggregation from .functions import ACT, AGG, apply_activation, apply_aggregation, get_func_name

View File

@@ -0,0 +1,2 @@
from .algorithm_adaptor import EvoXAlgorithmAdaptor
from .tensorneat_monitor import TensorNEATMonitor

View File

@@ -16,12 +16,12 @@ class TensorNEATMonitor(Monitor):
def __init__( def __init__(
self, self,
neat_algorithm: TensorNEATAlgorithm, tensorneat_algorithm: TensorNEATAlgorithm,
save_dir: str = None, save_dir: str = None,
is_save: bool = False, is_save: bool = False,
): ):
super().__init__() super().__init__()
self.neat_algorithm = neat_algorithm self.tensorneat_algorithm = tensorneat_algorithm
self.generation_timestamp = time.time() self.generation_timestamp = time.time()
self.alg_state: TensorNEATState = None self.alg_state: TensorNEATState = None
@@ -60,7 +60,9 @@ class TensorNEATMonitor(Monitor):
self.fitness = jax.device_get(fitness) self.fitness = jax.device_get(fitness)
def show(self): def show(self):
pop = self.neat_algorithm.ask(self.alg_state) pop = self.tensorneat_algorithm.ask(self.alg_state)
generation = int(self.alg_state.generation)
valid_fitnesses = self.fitness[~np.isinf(self.fitness)] valid_fitnesses = self.fitness[~np.isinf(self.fitness)]
max_f, min_f, mean_f, std_f = ( max_f, min_f, mean_f, std_f = (
@@ -73,7 +75,7 @@ class TensorNEATMonitor(Monitor):
new_timestamp = time.time() new_timestamp = time.time()
cost_time = new_timestamp - self.generation_timestamp cost_time = new_timestamp - self.generation_timestamp
self.generation_timestamp = new_timestamp self.generation_timestamp = time.time()
max_idx = np.argmax(self.fitness) max_idx = np.argmax(self.fitness)
if self.fitness[max_idx] > self.best_fitness: if self.fitness[max_idx] > self.best_fitness:
@@ -81,14 +83,12 @@ class TensorNEATMonitor(Monitor):
self.best_genome = pop[0][max_idx], pop[1][max_idx] self.best_genome = pop[0][max_idx], pop[1][max_idx]
if self.is_save: if self.is_save:
# save best
best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx])) best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx]))
with open( file_name = os.path.join(
os.path.join( self.genome_dir, f"{generation}.npz"
self.genome_dir, )
f"{int(self.neat_algorithm.generation(self.alg_state))}.npz", with open(file_name, "wb") as f:
),
"wb",
) as f:
np.savez( np.savez(
f, f,
nodes=best_genome[0], nodes=best_genome[0],
@@ -96,38 +96,15 @@ class TensorNEATMonitor(Monitor):
fitness=self.best_fitness, fitness=self.best_fitness,
) )
# save best if save path is not None # append log
member_count = jax.device_get(self.neat_algorithm.member_count(self.alg_state)) with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
species_sizes = [int(i) for i in member_count if i > 0] f.write(
f"{generation},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
pop = jax.device_get(pop)
pop_nodes, pop_conns = pop # (P, N, NL), (P, C, CL)
nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1) # (P,)
conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1) # (P,)
max_node_cnt, min_node_cnt, mean_node_cnt = (
max(nodes_cnt),
min(nodes_cnt),
np.mean(nodes_cnt),
)
max_conn_cnt, min_conn_cnt, mean_conn_cnt = (
max(conns_cnt),
min(conns_cnt),
np.mean(conns_cnt),
) )
print( print(
f"Generation: {self.neat_algorithm.generation(self.alg_state)}, Cost time: {cost_time * 1000:.2f}ms\n", f"Generation: {generation}, Cost time: {cost_time * 1000:.2f}ms\n",
f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n",
f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n",
f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n", f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
) )
# append log self.tensorneat_algorithm.show_details(self.alg_state, self.fitness)
if self.is_save:
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
f.write(
f"{self.neat_algorithm.generation(self.alg_state)},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
)

View File

@@ -1,3 +1,5 @@
import jax, jax.numpy as jnp
from .act_jnp import * from .act_jnp import *
from .act_sympy import * from .act_sympy import *
from .agg_jnp import * from .agg_jnp import *
@@ -32,6 +34,7 @@ act_name2sympy = {
"log": SympyLog, "log": SympyLog,
"exp": SympyExp, "exp": SympyExp,
"abs": SympyAbs, "abs": SympyAbs,
"clip": SympyClip,
} }
agg_name2jnp = { agg_name2jnp = {
@@ -40,7 +43,6 @@ agg_name2jnp = {
"max": max_, "max": max_,
"min": min_, "min": min_,
"maxabs": maxabs_, "maxabs": maxabs_,
"median": median_,
"mean": mean_, "mean": mean_,
} }
@@ -50,9 +52,42 @@ agg_name2sympy = {
"max": SympyMax, "max": SympyMax,
"min": SympyMin, "min": SympyMin,
"maxabs": SympyMaxabs, "maxabs": SympyMaxabs,
"median": SympyMedian,
"mean": SympyMean, "mean": SympyMean,
} }
ACT = FunctionManager(act_name2jnp, act_name2sympy) ACT = FunctionManager(act_name2jnp, act_name2sympy)
AGG = FunctionManager(agg_name2jnp, agg_name2sympy) AGG = FunctionManager(agg_name2jnp, agg_name2sympy)
def apply_activation(idx, z, act_funcs):
"""
calculate activation function for each node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
# -1 means identity activation
res = jax.lax.cond(
idx == -1,
lambda: z,
lambda: jax.lax.switch(idx, act_funcs, z),
)
return res
def apply_aggregation(idx, z, agg_funcs):
"""
calculate activation function for inputs of node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
return jax.lax.cond(
jnp.all(jnp.isnan(z)),
lambda: jnp.nan, # all inputs are nan
lambda: jax.lax.switch(idx, agg_funcs, z), # otherwise
)
def get_func_name(func):
name = func.__name__
if name.endswith("_"):
name = name[:-1]
return name

View File

@@ -1,6 +1,6 @@
import jax.numpy as jnp import jax.numpy as jnp
SCALE = 5 SCALE = 3
def scaled_sigmoid_(z): def scaled_sigmoid_(z):

View File

@@ -1,7 +1,7 @@
import sympy as sp import sympy as sp
import numpy as np import numpy as np
SCALE = 5 SCALE = 3
class SympySigmoid(sp.Function): class SympySigmoid(sp.Function):
@classmethod @classmethod

View File

@@ -23,18 +23,6 @@ def maxabs_(z):
max_abs_index = jnp.argmax(abs_z) max_abs_index = jnp.argmax(abs_z)
return z[max_abs_index] return z[max_abs_index]
def median_(z):
n = jnp.sum(~jnp.isnan(z), axis=0)
z = jnp.sort(z) # sort
idx1, idx2 = (n - 1) // 2, n // 2
median = (z[idx1] + z[idx2]) / 2
return median
def mean_(z): def mean_(z):
sumation = sum_(z) sumation = sum_(z)
valid_count = jnp.sum(~jnp.isnan(z), axis=0) valid_count = jnp.sum(~jnp.isnan(z), axis=0)

View File

@@ -7,30 +7,18 @@ class SympySum(sp.Function):
def eval(cls, z): def eval(cls, z):
return sp.Add(*z) return sp.Add(*z)
@classmethod
def numerical_eval(cls, z, backend=np):
return backend.sum(z)
class SympyProduct(sp.Function): class SympyProduct(sp.Function):
@classmethod @classmethod
def eval(cls, z): def eval(cls, z):
return sp.Mul(*z) return sp.Mul(*z)
@classmethod
def numerical_eval(cls, z, backend=np):
return backend.product(z)
class SympyMax(sp.Function): class SympyMax(sp.Function):
@classmethod @classmethod
def eval(cls, z): def eval(cls, z):
return sp.Max(*z) return sp.Max(*z)
@classmethod
def numerical_eval(cls, z, backend=np):
return backend.max(z)
class SympyMin(sp.Function): class SympyMin(sp.Function):
@classmethod @classmethod
@@ -48,26 +36,3 @@ class SympyMean(sp.Function):
@classmethod @classmethod
def eval(cls, z): def eval(cls, z):
return sp.Add(*z) / len(z) return sp.Add(*z) / len(z)
class SympyMedian(sp.Function):
@classmethod
def eval(cls, args):
if all(arg.is_number for arg in args):
sorted_args = sorted(args)
n = len(sorted_args)
if n % 2 == 1:
return sorted_args[n // 2]
else:
return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2
return None
def _sympystr(self, printer):
return f"median({', '.join(map(str, self.args))})"
def _latex(self, printer):
return (
r"\mathrm{median}\left(" + ", ".join(map(sp.latex, self.args)) + r"\right)"
)

View File

@@ -1,28 +1,32 @@
from functools import partial
import numpy as np
import jax.numpy as jnp
from typing import Union, Callable from typing import Union, Callable
import sympy as sp import sympy as sp
class FunctionManager: class FunctionManager:
def __init__(self, name2jnp, name2sympy): def __init__(self, name2jnp, name2sympy):
self.name2jnp = name2jnp self.name2jnp = name2jnp
self.name2sympy = name2sympy self.name2sympy = name2sympy
for name, func in name2jnp.items():
setattr(self, name, func)
def get_all_funcs(self): def get_all_funcs(self):
all_funcs = [] all_funcs = []
for name in self.names: for name in self.name2jnp:
all_funcs.append(getattr(self, name)) all_funcs.append(getattr(self, name))
return all_funcs return all_funcs
def __getattribute__(self, name: str):
return self.name2jnp[name]
def add_func(self, name, func): def add_func(self, name, func):
if not callable(func): if not callable(func):
raise ValueError("The provided function is not callable") raise ValueError("The provided function is not callable")
if name in self.names: if name in self.name2jnp:
raise ValueError(f"The provided name={name} is already in use") raise ValueError(f"The provided name={name} is already in use")
self.name2jnp[name] = func self.name2jnp[name] = func
setattr(self, name, func)
def update_sympy(self, name, sympy_cls: sp.Function): def update_sympy(self, name, sympy_cls: sp.Function):
self.name2sympy[name] = sympy_cls self.name2sympy[name] = sympy_cls
@@ -47,3 +51,16 @@ class FunctionManager:
if name not in self.name2sympy: if name not in self.name2sympy:
raise ValueError(f"Func {name} doesn't have a sympy representation.") raise ValueError(f"Func {name} doesn't have a sympy representation.")
return self.name2sympy[name] return self.name2sympy[name]
def sympy_module(self, backend: str):
assert backend in ["jax", "numpy"]
if backend == "jax":
backend = jnp
elif backend == "numpy":
backend = np
module = {}
for sympy_cls in self.name2sympy.values():
if hasattr(sympy_cls, "numerical_eval"):
module[sympy_cls.__name__] = partial(sympy_cls.numerical_eval, backend)
return module

View File

@@ -15,8 +15,8 @@ from tensorneat.common import (
topological_sort_python, topological_sort_python,
I_INF, I_INF,
attach_with_inf, attach_with_inf,
SYMPY_FUNCS_MODULE_NP, ACT,
SYMPY_FUNCS_MODULE_JNP, AGG
) )
@@ -92,7 +92,9 @@ class DefaultGenome(BaseGenome):
def otherwise(): def otherwise():
# calculate connections # calculate connections
conn_indices = u_conns[:, i] conn_indices = u_conns[:, i]
hit_attrs = attach_with_inf(conns_attrs, conn_indices) # fetch conn attrs hit_attrs = attach_with_inf(
conns_attrs, conn_indices
) # fetch conn attrs
ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))( ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
state, hit_attrs, values state, hit_attrs, values
) )
@@ -102,7 +104,9 @@ class DefaultGenome(BaseGenome):
state, state,
nodes_attrs[i], nodes_attrs[i],
ins, ins,
is_output_node=jnp.isin(nodes[i, 0], self.output_idx), # nodes[0] -> the key of nodes is_output_node=jnp.isin(
nodes[i, 0], self.output_idx
), # nodes[0] -> the key of nodes
) )
# set new value # set new value
@@ -139,7 +143,6 @@ class DefaultGenome(BaseGenome):
): ):
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'" assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
if sympy_input_transform is None and self.input_transform is not None: if sympy_input_transform is None and self.input_transform is not None:
warnings.warn( warnings.warn(
@@ -224,7 +227,7 @@ class DefaultGenome(BaseGenome):
sp.lambdify( sp.lambdify(
input_symbols + list(args_symbols.keys()), input_symbols + list(args_symbols.keys()),
exprs, exprs,
modules=[backend, module], modules=[backend, AGG.sympy_module(backend), ACT.sympy_module(backend)],
) )
for exprs in output_exprs for exprs in output_exprs
] ]
@@ -256,7 +259,12 @@ class DefaultGenome(BaseGenome):
rotate=0, rotate=0,
reverse_node_order=False, reverse_node_order=False,
size=(300, 300, 300), size=(300, 300, 300),
color=("blue", "blue", "blue"), color=("yellow", "white", "blue"),
with_labels=False,
edgecolors="k",
arrowstyle="->",
arrowsize=3,
edge_color=(0.3, 0.3, 0.3),
save_path="network.svg", save_path="network.svg",
save_dpi=800, save_dpi=800,
**kwargs, **kwargs,
@@ -264,7 +272,6 @@ class DefaultGenome(BaseGenome):
import networkx as nx import networkx as nx
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
nodes_list = list(network["nodes"])
conns_list = list(network["conns"]) conns_list = list(network["conns"])
input_idx = self.get_input_idx() input_idx = self.get_input_idx()
output_idx = self.get_output_idx() output_idx = self.get_output_idx()
@@ -316,6 +323,11 @@ class DefaultGenome(BaseGenome):
pos=rotated_pos, pos=rotated_pos,
node_size=node_sizes, node_size=node_sizes,
node_color=node_colors, node_color=node_colors,
with_labels=with_labels,
edgecolors=edgecolors,
arrowstyle=arrowstyle,
arrowsize=arrowsize,
edge_color=edge_color,
**kwargs, **kwargs,
) )
plt.savefig(save_path, dpi=save_dpi) plt.savefig(save_path, dpi=save_dpi)

View File

@@ -10,7 +10,7 @@ from tensorneat.common import (
apply_aggregation, apply_aggregation,
mutate_int, mutate_int,
mutate_float, mutate_float,
convert_to_sympy, get_func_name
) )
from . import BaseNode from . import BaseNode
@@ -141,8 +141,8 @@ class BiasNode(BaseNode):
self.__class__.__name__, self.__class__.__name__,
idx, idx,
bias, bias,
self.aggregation_options[agg].__name__, get_func_name(self.aggregation_options[agg]),
act_func.__name__, get_func_name(act_func),
idx_width=idx_width, idx_width=idx_width,
float_width=precision + 3, float_width=precision + 3,
func_width=func_width, func_width=func_width,
@@ -165,21 +165,19 @@ class BiasNode(BaseNode):
return { return {
"idx": idx, "idx": idx,
"bias": bias, "bias": bias,
"agg": self.aggregation_options[int(agg)].__name__, "agg": get_func_name(self.aggregation_options[agg]),
"act": act_func.__name__, "act": get_func_name(act_func),
} }
def sympy_func(self, state, node_dict, inputs, is_output_node=False): def sympy_func(self, state, node_dict, inputs, is_output_node=False):
nd = node_dict bias = sp.symbols(f"n_{node_dict['idx']}_b")
bias = sp.symbols(f"n_{nd['idx']}_b") z = AGG.obtain_sympy(node_dict["agg"])(inputs)
z = convert_to_sympy(nd["agg"])(inputs)
z = bias + z z = bias + z
if is_output_node: if is_output_node:
pass pass
else: else:
z = convert_to_sympy(nd["act"])(z) z = ACT.obtain_sympy(node_dict["act"])(z)
return z, {bias: nd["bias"]} return z, {bias: node_dict["bias"]}

View File

@@ -11,7 +11,7 @@ from tensorneat.common import (
apply_aggregation, apply_aggregation,
mutate_int, mutate_int,
mutate_float, mutate_float,
convert_to_sympy, get_func_name
) )
from .base import BaseNode from .base import BaseNode
@@ -176,8 +176,8 @@ class DefaultNode(BaseNode):
idx, idx,
bias, bias,
res, res,
self.aggregation_options[agg].__name__, get_func_name(self.aggregation_options[agg]),
act_func.__name__, get_func_name(act_func),
idx_width=idx_width, idx_width=idx_width,
float_width=precision + 3, float_width=precision + 3,
func_width=func_width, func_width=func_width,
@@ -200,8 +200,8 @@ class DefaultNode(BaseNode):
"idx": idx, "idx": idx,
"bias": bias, "bias": bias,
"res": res, "res": res,
"agg": self.aggregation_options[int(agg)].__name__, "agg": get_func_name(self.aggregation_options[agg]),
"act": act_func.__name__, "act": get_func_name(act_func),
} }
def sympy_func(self, state, node_dict, inputs, is_output_node=False): def sympy_func(self, state, node_dict, inputs, is_output_node=False):
@@ -209,12 +209,13 @@ class DefaultNode(BaseNode):
bias = sp.symbols(f"n_{nd['idx']}_b") bias = sp.symbols(f"n_{nd['idx']}_b")
res = sp.symbols(f"n_{nd['idx']}_r") res = sp.symbols(f"n_{nd['idx']}_r")
z = convert_to_sympy(nd["agg"])(inputs) print(nd["agg"])
z = AGG.obtain_sympy(nd["agg"])(inputs)
z = bias + res * z z = bias + res * z
if is_output_node: if is_output_node:
pass pass
else: else:
z = convert_to_sympy(nd["act"])(z) z = ACT.obtain_sympy(nd["act"])(z)
return z, {bias: nd["bias"], res: nd["res"]} return z, {bias: nd["bias"], res: nd["res"]}

View File

@@ -0,0 +1,6 @@
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,23 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "286b5fc2-e629-4461-ad95-b69d3d606a78",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "",
"name": ""
},
"language_info": {
"name": ""
}
},
"nbformat": 4,
"nbformat_minor": 5
}