see https://github.com/EMI-Group/tensorneat/issues/9; fix bugs in genome visualize (add plt.close())
64 lines
2.1 KiB
Python
64 lines
2.1 KiB
Python
import jax.numpy as jnp
|
|
|
|
from tensorneat.pipeline import Pipeline
|
|
from tensorneat.algorithm.neat import NEAT
|
|
from tensorneat.algorithm.hyperneat import HyperNEATFeedForward, MLPSubstrate
|
|
from tensorneat.genome import DefaultGenome
|
|
from tensorneat.common import ACT
|
|
|
|
from tensorneat.problem.func_fit import XOR3d
|
|
|
|
if __name__ == "__main__":
|
|
pipeline = Pipeline(
|
|
algorithm=HyperNEATFeedForward(
|
|
substrate=MLPSubstrate(
|
|
layers=[4, 5, 5, 5, 1], coor_range=(-5.0, 5.0, -5.0, 5.0)
|
|
),
|
|
neat=NEAT(
|
|
pop_size=10000,
|
|
species_size=20,
|
|
survival_threshold=0.01,
|
|
genome=DefaultGenome(
|
|
num_inputs=4, # size of query coors
|
|
num_outputs=1,
|
|
init_hidden_layers=(),
|
|
output_transform=ACT.tanh,
|
|
),
|
|
),
|
|
activation=ACT.tanh,
|
|
output_transform=ACT.sigmoid,
|
|
),
|
|
problem=XOR3d(),
|
|
generation_limit=1000,
|
|
fitness_target=-1e-5,
|
|
)
|
|
|
|
# initialize state
|
|
state = pipeline.setup()
|
|
# print(state)
|
|
# run until terminate
|
|
state, best = pipeline.auto_run(state)
|
|
# show result
|
|
pipeline.show(state, best)
|
|
|
|
# visualize cppn
|
|
cppn_genome = pipeline.algorithm.neat.genome
|
|
cppn_network = cppn_genome.network_dict(state, *best)
|
|
cppn_genome.visualize(cppn_network, save_path="./imgs/cppn_network.svg")
|
|
|
|
# visualize hyperneat genome
|
|
hyperneat_genome = pipeline.algorithm.hyper_genome
|
|
# use cppn to calculate the weights in hyperneat genome
|
|
# return seqs, nodes, conns, u_conns
|
|
_, hyperneat_nodes, hyperneat_conns, _ = pipeline.algorithm.transform(state, best)
|
|
# mutate the connection with weight 0 (to visualize the network rather the substrate)
|
|
hyperneat_conns = jnp.where(
|
|
hyperneat_conns[:, 2][:, None] == 0, jnp.nan, hyperneat_conns
|
|
)
|
|
hyperneat_network = hyperneat_genome.network_dict(
|
|
state, hyperneat_nodes, hyperneat_conns
|
|
)
|
|
hyperneat_genome.visualize(
|
|
hyperneat_network, save_path="./imgs/hyperneat_network.svg"
|
|
)
|