Files
tensorneat-mend/tensorneat/genome/default.py
2024-07-10 16:50:36 +08:00

322 lines
10 KiB
Python

import warnings
import jax
from jax import vmap, numpy as jnp
import numpy as np
import sympy as sp
from .base import BaseGenome
from .gene import DefaultNodeGene, DefaultConnGene
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
from tensorneat.common import (
topological_sort,
topological_sort_python,
I_INF,
attach_with_inf,
SYMPY_FUNCS_MODULE_NP,
SYMPY_FUNCS_MODULE_JNP,
)
class DefaultGenome(BaseGenome):
    """Default genome class, with the same behavior as NEAT-Python."""

    network_type = "feedforward"

    def __init__(
        self,
        num_inputs: int,
        num_outputs: int,
        max_nodes=50,
        max_conns=100,
        node_gene=None,
        conn_gene=None,
        mutation=None,
        crossover=None,
        distance=None,
        output_transform=None,
        input_transform=None,
        init_hidden_layers=(),
    ):
        """Create a feedforward genome.

        Args:
            num_inputs: number of input nodes.
            num_outputs: number of output nodes.
            max_nodes: capacity of the node table.
            max_conns: capacity of the connection table.
            node_gene: node gene; ``None`` means a new ``DefaultNodeGene()``.
            conn_gene: connection gene; ``None`` means a new ``DefaultConnGene()``.
            mutation: mutation operator; ``None`` means a new ``DefaultMutation()``.
            crossover: crossover operator; ``None`` means a new ``DefaultCrossover()``.
            distance: distance operator; ``None`` means a new ``DefaultDistance()``.
            output_transform: optional callable applied to the outputs in ``forward``.
            input_transform: optional callable applied to the inputs in ``forward``.
            init_hidden_layers: forwarded unchanged to ``BaseGenome.__init__``.
        """
        # Build default components per genome instead of once at class
        # definition time, so separate genomes never share these objects.
        if node_gene is None:
            node_gene = DefaultNodeGene()
        if conn_gene is None:
            conn_gene = DefaultConnGene()
        if mutation is None:
            mutation = DefaultMutation()
        if crossover is None:
            crossover = DefaultCrossover()
        if distance is None:
            distance = DefaultDistance()

        super().__init__(
            num_inputs,
            num_outputs,
            max_nodes,
            max_conns,
            node_gene,
            conn_gene,
            mutation,
            crossover,
            distance,
            output_transform,
            input_transform,
            init_hidden_layers,
        )

    def transform(self, state, nodes, conns):
        """Precompute the data ``forward`` needs for one genome.

        Returns:
            ``(seqs, nodes, conns, u_conns)``. ``u_conns`` comes from
            ``unflatten_conns``; entries equal to ``I_INF`` mark missing
            connections. ``seqs`` is the node evaluation order returned by
            ``topological_sort`` over the existing connections. ``forward``
            stops at the first ``I_INF`` entry of ``seqs``.
        """
        u_conns = unflatten_conns(nodes, conns)
        conn_exist = u_conns != I_INF
        seqs = topological_sort(nodes, conn_exist)
        return seqs, nodes, conns, u_conns

    def forward(self, state, transformed, inputs):
        """Evaluate the network on one input vector.

        Args:
            state: passed through to the gene ``forward`` calls.
            transformed: the tuple returned by ``transform``.
            inputs: written into the input-node slots, after
                ``input_transform`` if one is set.

        Returns:
            The values at ``self.output_idx``, passed through
            ``output_transform`` if one is set.
        """
        if self.input_transform is not None:
            inputs = self.input_transform(inputs)

        cal_seqs, nodes, conns, u_conns = transformed

        # Every slot starts as NaN; the input slots are filled up front.
        ini_vals = jnp.full((self.max_nodes,), jnp.nan)
        ini_vals = ini_vals.at[self.input_idx].set(inputs)
        nodes_attrs = vmap(extract_node_attrs)(nodes)
        conns_attrs = vmap(extract_conn_attrs)(conns)

        def cond_fun(carry):
            values, idx = carry
            # not out of bounds and next node exists
            return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF)

        def body_func(carry):
            values, idx = carry
            i = cal_seqs[idx]

            def input_node():
                # Input values were already written into ini_vals.
                return values

            def otherwise():
                # calculate connections
                conn_indices = u_conns[:, i]
                hit_attrs = attach_with_inf(conns_attrs, conn_indices)  # fetch conn attrs
                ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
                    state, hit_attrs, values
                )
                # calculate nodes
                z = self.node_gene.forward(
                    state,
                    nodes_attrs[i],
                    ins,
                    # nodes[i, 0] is the key of node i
                    is_output_node=jnp.isin(nodes[i, 0], self.output_idx),
                )
                # set new value
                return values.at[i].set(z)

            values = jax.lax.cond(jnp.isin(i, self.input_idx), input_node, otherwise)
            return values, idx + 1

        vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
        outputs = vals[self.output_idx]
        if self.output_transform is None:
            return outputs
        return self.output_transform(outputs)

    def network_dict(self, state, nodes, conns):
        """Return the base network dict, extended with topological information.

        Adds ``"topo_order"`` and ``"topo_layers"``, both computed by
        ``topological_sort_python`` from the node keys and connection keys.
        """
        network = super().network_dict(state, nodes, conns)
        topo_order, topo_layers = topological_sort_python(
            set(network["nodes"]), set(network["conns"])
        )
        network["topo_order"] = topo_order
        network["topo_layers"] = topo_layers
        return network

    def sympy_func(
        self,
        state,
        network,
        sympy_input_transform=None,
        sympy_output_transform=None,
        backend="jax",
    ):
        """Build SymPy expressions for the network and a numeric forward function.

        Args:
            state: passed through to the genes' ``sympy_func`` calls.
            network: dict produced by ``network_dict`` (needs ``"nodes"``,
                ``"conns"`` and ``"topo_order"``).
            sympy_input_transform: a callable, or a list with one callable per
                input, mapping an input symbol to its normalized expression.
                Defaults to identity.
            sympy_output_transform: optional callable applied to each output
                node's expression.
            backend: ``"jax"`` or ``"numpy"``. Selects the lambdify modules
                and the array type returned by ``forward_func``.

        Returns:
            ``(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs,
            forward_func)``. ``forward_func(inputs)`` evaluates every output
            expression, with the gene argument symbols bound to the values in
            ``args_symbols``.
        """
        assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
        if backend == "jax":
            module, array_fn = SYMPY_FUNCS_MODULE_JNP, jnp.array
        else:
            module, array_fn = SYMPY_FUNCS_MODULE_NP, np.array

        if sympy_input_transform is None and self.input_transform is not None:
            warnings.warn(
                "genome.input_transform is not None but sympy_input_transform is None!"
            )
        if sympy_input_transform is None:
            sympy_input_transform = lambda x: x
        if not isinstance(sympy_input_transform, list):
            sympy_input_transform = [sympy_input_transform] * self.num_inputs

        if sympy_output_transform is None and self.output_transform is not None:
            warnings.warn(
                "genome.output_transform is not None but sympy_output_transform is None!"
            )

        input_idx = self.get_input_idx()
        output_idx = self.get_output_idx()
        order = network["topo_order"]

        hidden_idx = [
            i for i in network["nodes"] if i not in input_idx and i not in output_idx
        ]
        # Offsets so that symbol numbering starts at 0 within each category.
        # The defaults are only used when a category is empty, and then no
        # node ever reads them.
        min_input = min(input_idx, default=0)
        min_output = min(output_idx, default=0)
        min_hidden = min(hidden_idx, default=0)

        symbols = {}
        for i in network["nodes"]:
            if i in input_idx:
                symbols[-i - 1] = sp.Symbol(f"i{i - min_input}")  # origin_i
                symbols[i] = sp.Symbol(f"norm{i - min_input}")
            elif i in output_idx:
                symbols[i] = sp.Symbol(f"o{i - min_output}")
            else:  # hidden
                symbols[i] = sp.Symbol(f"h{i - min_hidden}")

        # Group incoming connections by destination node once, instead of
        # rescanning every connection for every node. Within each group the
        # iteration order of network["conns"] is preserved.
        incoming = {}
        for conn in network["conns"]:
            incoming.setdefault(conn[1], []).append(conn)

        nodes_exprs = {}
        args_symbols = {}
        for i in order:
            if i in input_idx:
                origin = symbols[-i - 1]
                nodes_exprs[origin] = origin  # origin equal to its symbol
                nodes_exprs[symbols[i]] = sympy_input_transform[i - min_input](
                    origin
                )  # normed i
                continue

            node_inputs = []
            for conn in incoming.get(i, []):
                # a_s -> args_symbols
                val, a_s = self.conn_gene.sympy_func(
                    state,
                    network["conns"][conn],
                    symbols[conn[0]],
                )
                args_symbols.update(a_s)
                node_inputs.append(val)

            nodes_exprs[symbols[i]], a_s = self.node_gene.sympy_func(
                state,
                network["nodes"][i],
                node_inputs,
                is_output_node=(i in output_idx),
            )
            args_symbols.update(a_s)
            if i in output_idx and sympy_output_transform is not None:
                nodes_exprs[symbols[i]] = sympy_output_transform(
                    nodes_exprs[symbols[i]]
                )

        input_symbols = [symbols[-i - 1] for i in input_idx]

        # Walk in topological order, so each node's predecessors are already
        # rewritten in terms of the origin input symbols when it is processed.
        reduced_exprs = nodes_exprs.copy()
        for i in order:
            reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs)
        output_exprs = [reduced_exprs[symbols[i]] for i in output_idx]

        lambdify_output_funcs = [
            sp.lambdify(
                input_symbols + list(args_symbols.keys()),
                exprs,
                modules=[backend, module],
            )
            for exprs in output_exprs
        ]

        # Bind the gene argument values; fn=fn avoids the late-binding closure trap.
        fixed_args_output_funcs = [
            (lambda inputs, fn=fn: fn(*inputs, *args_symbols.values()))
            for fn in lambdify_output_funcs
        ]

        def forward_func(inputs):
            return array_fn([f(inputs) for f in fixed_args_output_funcs])

        return (
            symbols,
            args_symbols,
            input_symbols,
            nodes_exprs,
            output_exprs,
            forward_func,
        )

    def visualize(
        self,
        network,
        rotate=0,
        reverse_node_order=False,
        size=(300, 300, 300),
        color=("blue", "blue", "blue"),
        save_path="network.svg",
        save_dpi=800,
        close_figure=False,
        **kwargs,
    ):
        """Draw the network with networkx and save it to ``save_path``.

        Nodes are placed with ``nx.multipartite_layout``, one column per
        topological layer.

        Args:
            network: dict produced by ``network_dict``.
            rotate: rotation of the layout, in degrees.
            reverse_node_order: add nodes to the graph in reversed topological
                order.
            size: node size, either a single value or an
                ``(input, hidden, output)`` tuple.
            color: node color, either a single value or an
                ``(input, hidden, output)`` tuple.
            save_path: output file path passed to ``savefig``.
            save_dpi: dpi passed to ``savefig``.
            close_figure: if True, close the figure this call created after
                saving it. A figure behind a caller-supplied ``ax`` is never
                closed.
            **kwargs: forwarded to ``nx.draw``.
        """
        import networkx as nx
        from matplotlib import pyplot as plt

        input_idx = self.get_input_idx()
        output_idx = self.get_output_idx()
        topo_order, topo_layers = network["topo_order"], network["topo_layers"]
        node2layer = {
            node: layer for layer, nodes in enumerate(topo_layers) for node in nodes
        }
        if reverse_node_order:
            topo_order = topo_order[::-1]

        if not isinstance(size, tuple):
            size = (size, size, size)
        if not isinstance(color, tuple):
            color = (color, color, color)

        G = nx.DiGraph()
        for node in topo_order:
            # slot in the (input, hidden, output) size/color tuples
            if node in input_idx:
                slot = 0
            elif node in output_idx:
                slot = 2
            else:
                slot = 1
            G.add_node(
                node, subset=node2layer[node], size=size[slot], color=color[slot]
            )
        for conn in network["conns"]:
            G.add_edge(conn[0], conn[1])

        pos = nx.multipartite_layout(G)

        def rotate_layout(pos, angle):
            angle_rad = np.deg2rad(angle)
            cos_angle, sin_angle = np.cos(angle_rad), np.sin(angle_rad)
            return {
                node: (cos_angle * x - sin_angle * y, sin_angle * x + cos_angle * y)
                for node, (x, y) in pos.items()
            }

        rotated_pos = rotate_layout(pos, rotate)

        node_sizes = [n["size"] for n in G.nodes.values()]
        node_colors = [n["color"] for n in G.nodes.values()]

        # Without an explicit ax, nx.draw draws onto the current pyplot
        # figure. Start a fresh one so repeated calls do not overlay earlier
        # networks in the saved image.
        fig = plt.figure() if kwargs.get("ax") is None else None
        nx.draw(
            G,
            pos=rotated_pos,
            node_size=node_sizes,
            node_color=node_colors,
            **kwargs,
        )
        plt.savefig(save_path, dpi=save_dpi)
        if close_figure and fig is not None:
            plt.close(fig)