add feedforward substrate and hyperneat and related example;
see https://github.com/EMI-Group/tensorneat/issues/9; fix bugs in genome visualize (add plt.close())
This commit is contained in:
63
examples/func_fit/xor_hyperneat_feedforward.py
Normal file
63
examples/func_fit/xor_hyperneat_feedforward.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from tensorneat.pipeline import Pipeline
|
||||
from tensorneat.algorithm.neat import NEAT
|
||||
from tensorneat.algorithm.hyperneat import HyperNEATFeedForward, MLPSubstrate
|
||||
from tensorneat.genome import DefaultGenome
|
||||
from tensorneat.common import ACT
|
||||
|
||||
from tensorneat.problem.func_fit import XOR3d
|
||||
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=HyperNEATFeedForward(
|
||||
substrate=MLPSubstrate(
|
||||
layers=[4, 5, 5, 5, 1], coor_range=(-5.0, 5.0, -5.0, 5.0)
|
||||
),
|
||||
neat=NEAT(
|
||||
pop_size=10000,
|
||||
species_size=20,
|
||||
survival_threshold=0.01,
|
||||
genome=DefaultGenome(
|
||||
num_inputs=4, # size of query coors
|
||||
num_outputs=1,
|
||||
init_hidden_layers=(),
|
||||
output_transform=ACT.tanh,
|
||||
),
|
||||
),
|
||||
activation=ACT.tanh,
|
||||
output_transform=ACT.sigmoid,
|
||||
),
|
||||
problem=XOR3d(),
|
||||
generation_limit=1000,
|
||||
fitness_target=-1e-5,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
# print(state)
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
# show result
|
||||
pipeline.show(state, best)
|
||||
|
||||
# visualize cppn
|
||||
cppn_genome = pipeline.algorithm.neat.genome
|
||||
cppn_network = cppn_genome.network_dict(state, *best)
|
||||
cppn_genome.visualize(cppn_network, save_path="./imgs/cppn_network.svg")
|
||||
|
||||
# visualize hyperneat genome
|
||||
hyperneat_genome = pipeline.algorithm.hyper_genome
|
||||
# use cppn to calculate the weights in hyperneat genome
|
||||
# return seqs, nodes, conns, u_conns
|
||||
_, hyperneat_nodes, hyperneat_conns, _ = pipeline.algorithm.transform(state, best)
|
||||
# mutate the connection with weight 0 (to visualize the network rather the substrate)
|
||||
hyperneat_conns = jnp.where(
|
||||
hyperneat_conns[:, 2][:, None] == 0, jnp.nan, hyperneat_conns
|
||||
)
|
||||
hyperneat_network = hyperneat_genome.network_dict(
|
||||
state, hyperneat_nodes, hyperneat_conns
|
||||
)
|
||||
hyperneat_genome.visualize(
|
||||
hyperneat_network, save_path="./imgs/hyperneat_network.svg"
|
||||
)
|
||||
Reference in New Issue
Block a user