Files
tensorneat-mend/examples/func_fit/xor_hyperneat_feedforward.py
2024-11-23 15:32:28 +08:00

64 lines
2.1 KiB
Python

import jax.numpy as jnp
from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
from tensorneat.algorithm.hyperneat import HyperNEATFeedForward, MLPSubstrate
from tensorneat.genome import DefaultGenome
from tensorneat.common import ACT
from tensorneat.problem.func_fit import XOR3d
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=HyperNEATFeedForward(
substrate=MLPSubstrate(
layers=[4, 5, 5, 5, 1], coor_range=(-5.0, 5.0, -5.0, 5.0)
),
neat=NEAT(
pop_size=10000,
species_size=20,
survival_threshold=0.01,
genome=DefaultGenome(
num_inputs=4, # size of query coors
num_outputs=1,
init_hidden_layers=(),
output_transform=ACT.tanh,
),
),
activation=ACT.tanh,
output_transform=ACT.sigmoid,
),
problem=XOR3d(),
generation_limit=1000,
fitness_target=-1e-5,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best)
# visualize cppn
cppn_genome = pipeline.algorithm.neat.genome
cppn_network = cppn_genome.network_dict(state, *best)
cppn_genome.visualize(cppn_network, save_path="./imgs/cppn_network.svg")
# visualize hyperneat genome
hyperneat_genome = pipeline.algorithm.hyper_genome
# use cppn to calculate the weights in hyperneat genome
# return seqs, nodes, conns, u_conns
_, hyperneat_nodes, hyperneat_conns, _ = pipeline.algorithm.transform(state, best)
# mutate the connection with weight 0 (to visualize the network rather the substrate)
hyperneat_conns = jnp.where(
hyperneat_conns[:, 2][:, None] == 0, jnp.nan, hyperneat_conns
)
hyperneat_network = hyperneat_genome.network_dict(
state, hyperneat_nodes, hyperneat_conns
)
hyperneat_genome.visualize(
hyperneat_network, save_path="./imgs/hyperneat_network.svg"
)