add feedforward substrate and hyperneat and related example;

see https://github.com/EMI-Group/tensorneat/issues/9;
fix bugs in genome visualize (add plt.close())
This commit is contained in:
wls2002
2024-11-23 15:32:28 +08:00
parent 5fdaf6806c
commit 1b7e49293a
13 changed files with 1565 additions and 1 deletions

View File

@@ -0,0 +1,63 @@
import jax.numpy as jnp
from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
from tensorneat.algorithm.hyperneat import HyperNEATFeedForward, MLPSubstrate
from tensorneat.genome import DefaultGenome
from tensorneat.common import ACT
from tensorneat.problem.func_fit import XOR3d
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=HyperNEATFeedForward(
substrate=MLPSubstrate(
layers=[4, 5, 5, 5, 1], coor_range=(-5.0, 5.0, -5.0, 5.0)
),
neat=NEAT(
pop_size=10000,
species_size=20,
survival_threshold=0.01,
genome=DefaultGenome(
num_inputs=4, # size of query coors
num_outputs=1,
init_hidden_layers=(),
output_transform=ACT.tanh,
),
),
activation=ACT.tanh,
output_transform=ACT.sigmoid,
),
problem=XOR3d(),
generation_limit=1000,
fitness_target=-1e-5,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best)
# visualize cppn
cppn_genome = pipeline.algorithm.neat.genome
cppn_network = cppn_genome.network_dict(state, *best)
cppn_genome.visualize(cppn_network, save_path="./imgs/cppn_network.svg")
# visualize hyperneat genome
hyperneat_genome = pipeline.algorithm.hyper_genome
# use cppn to calculate the weights in hyperneat genome
# return seqs, nodes, conns, u_conns
_, hyperneat_nodes, hyperneat_conns, _ = pipeline.algorithm.transform(state, best)
# mutate the connection with weight 0 (to visualize the network rather the substrate)
hyperneat_conns = jnp.where(
hyperneat_conns[:, 2][:, None] == 0, jnp.nan, hyperneat_conns
)
hyperneat_network = hyperneat_genome.network_dict(
state, hyperneat_nodes, hyperneat_conns
)
hyperneat_genome.visualize(
hyperneat_network, save_path="./imgs/hyperneat_network.svg"
)