odify genome for the official release

This commit is contained in:
root
2024-07-10 11:24:11 +08:00
parent 075460f896
commit ee8ec84202
83 changed files with 588 additions and 611 deletions

View File

@@ -1,25 +1,23 @@
import warnings
from typing import Callable
import jax, jax.numpy as jnp
import jax
from jax import vmap, numpy as jnp
import numpy as np
import sympy as sp
from utils import (
unflatten_conns,
from . import BaseGenome
from ..gene import DefaultNodeGene, DefaultConnGene
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
from tensorneat.common import (
topological_sort,
topological_sort_python,
I_INF,
extract_node_attrs,
extract_conn_attrs,
set_node_attrs,
set_conn_attrs,
attach_with_inf,
SYMPY_FUNCS_MODULE_NP,
SYMPY_FUNCS_MODULE_JNP,
)
from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover
class DefaultGenome(BaseGenome):
@@ -31,15 +29,18 @@ class DefaultGenome(BaseGenome):
self,
num_inputs: int,
num_outputs: int,
max_nodes=5,
max_conns=4,
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
mutation: BaseMutation = DefaultMutation(),
crossover: BaseCrossover = DefaultCrossover(),
output_transform: Callable = None,
input_transform: Callable = None,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(),
conn_gene=DefaultConnGene(),
mutation=DefaultMutation(),
crossover=DefaultCrossover(),
distance=DefaultDistance(),
output_transform=None,
input_transform=None,
init_hidden_layers=(),
):
super().__init__(
num_inputs,
num_outputs,
@@ -49,22 +50,12 @@ class DefaultGenome(BaseGenome):
conn_gene,
mutation,
crossover,
distance,
output_transform,
input_transform,
init_hidden_layers,
)
if input_transform is not None:
try:
_ = input_transform(np.zeros(num_inputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.input_transform = input_transform
if output_transform is not None:
try:
_ = output_transform(np.zeros(num_outputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform
def transform(self, state, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
conn_exist = u_conns != I_INF
@@ -73,10 +64,6 @@ class DefaultGenome(BaseGenome):
return seqs, nodes, conns, u_conns
def restore(self, state, transformed):
seqs, nodes, conns, u_conns = transformed
return nodes, conns
def forward(self, state, transformed, inputs):
if self.input_transform is not None:
@@ -86,8 +73,8 @@ class DefaultGenome(BaseGenome):
ini_vals = jnp.full((self.max_nodes,), jnp.nan)
ini_vals = ini_vals.at[self.input_idx].set(inputs)
nodes_attrs = jax.vmap(extract_node_attrs)(nodes)
conns_attrs = jax.vmap(extract_conn_attrs)(conns)
nodes_attrs = vmap(extract_node_attrs)(nodes)
conns_attrs = vmap(extract_conn_attrs)(conns)
def cond_fun(carry):
values, idx = carry
@@ -105,7 +92,7 @@ class DefaultGenome(BaseGenome):
def otherwise():
conn_indices = u_conns[:, i]
hit_attrs = attach_with_inf(conns_attrs, conn_indices)
ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
state, hit_attrs, values
)
@@ -130,85 +117,14 @@ class DefaultGenome(BaseGenome):
else:
return self.output_transform(vals[self.output_idx])
def update_by_batch(self, state, batch_input, transformed):
if self.input_transform is not None:
batch_input = jax.vmap(self.input_transform)(batch_input)
cal_seqs, nodes, conns, u_conns = transformed
batch_size = batch_input.shape[0]
batch_ini_vals = jnp.full((batch_size, self.max_nodes), jnp.nan)
batch_ini_vals = batch_ini_vals.at[:, self.input_idx].set(batch_input)
nodes_attrs = jax.vmap(extract_node_attrs)(nodes)
conns_attrs = jax.vmap(extract_conn_attrs)(conns)
def cond_fun(carry):
batch_values, nodes_attrs_, conns_attrs_, idx = carry
return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF)
def body_func(carry):
batch_values, nodes_attrs_, conns_attrs_, idx = carry
i = cal_seqs[idx]
def input_node():
batch, new_attrs = self.node_gene.update_input_transform(
state, nodes_attrs_[i], batch_values[:, i]
)
return (
batch_values.at[:, i].set(batch),
nodes_attrs_.at[i].set(new_attrs),
conns_attrs_,
)
def otherwise():
conn_indices = u_conns[:, i]
hit_attrs = attach_with_inf(conns_attrs, conn_indices)
batch_ins, new_conn_attrs = jax.vmap(
self.conn_gene.update_by_batch,
in_axes=(None, 0, 1),
out_axes=(1, 0),
)(state, hit_attrs, batch_values)
batch_z, new_node_attrs = self.node_gene.update_by_batch(
state,
nodes_attrs_[i],
batch_ins,
is_output_node=jnp.isin(i, self.output_idx),
)
return (
batch_values.at[:, i].set(batch_z),
nodes_attrs_.at[i].set(new_node_attrs),
conns_attrs_.at[conn_indices].set(new_conn_attrs),
)
# the val of input nodes is obtained by the task, not by calculation
(batch_values, nodes_attrs_, conns_attrs_) = jax.lax.cond(
jnp.isin(i, self.input_idx),
input_node,
otherwise,
)
return batch_values, nodes_attrs_, conns_attrs_, idx + 1
batch_vals, nodes_attrs, conns_attrs, _ = jax.lax.while_loop(
cond_fun, body_func, (batch_ini_vals, nodes_attrs, conns_attrs, 0)
def network_dict(self, state, nodes, conns):
network = super().network_dict(state, nodes, conns)
topo_order, topo_layers = topological_sort_python(
set(network["nodes"]), set(network["conns"])
)
nodes = jax.vmap(set_node_attrs)(nodes, nodes_attrs)
conns = jax.vmap(set_conn_attrs)(conns, conns_attrs)
new_transformed = (cal_seqs, nodes, conns, u_conns)
if self.output_transform is None:
return batch_vals[:, self.output_idx], new_transformed
else:
return (
jax.vmap(self.output_transform)(batch_vals[:, self.output_idx]),
new_transformed,
)
network["topo_order"] = topo_order
network["topo_layers"] = topo_layers
return network
def sympy_func(
self,
@@ -241,7 +157,8 @@ class DefaultGenome(BaseGenome):
input_idx = self.get_input_idx()
output_idx = self.get_output_idx()
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
order = network["topo_order"]
hidden_idx = [
i for i in network["nodes"] if i not in input_idx and i not in output_idx
]
@@ -260,8 +177,12 @@ class DefaultGenome(BaseGenome):
for i in order:
if i in input_idx:
nodes_exprs[symbols[-i - 1]] = symbols[-i - 1] # origin equal to its symbol
nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[-i - 1]) # normed i
nodes_exprs[symbols[-i - 1]] = symbols[
-i - 1
] # origin equal to its symbol
nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](
symbols[-i - 1]
) # normed i
else:
in_conns = [c for c in network["conns"] if c[1] == i]
@@ -325,3 +246,73 @@ class DefaultGenome(BaseGenome):
output_exprs,
forward_func,
)
def visualize(
self,
network,
rotate=0,
reverse_node_order=False,
size=(300, 300, 300),
color=("blue", "blue", "blue"),
save_path="network.svg",
save_dpi=800,
**kwargs,
):
import networkx as nx
from matplotlib import pyplot as plt
nodes_list = list(network["nodes"])
conns_list = list(network["conns"])
input_idx = self.get_input_idx()
output_idx = self.get_output_idx()
topo_order, topo_layers = network["topo_order"], network["topo_layers"]
node2layer = {
node: layer for layer, nodes in enumerate(topo_layers) for node in nodes
}
if reverse_node_order:
topo_order = topo_order[::-1]
G = nx.DiGraph()
if not isinstance(size, tuple):
size = (size, size, size)
if not isinstance(color, tuple):
color = (color, color, color)
for node in topo_order:
if node in input_idx:
G.add_node(node, subset=node2layer[node], size=size[0], color=color[0])
elif node in output_idx:
G.add_node(node, subset=node2layer[node], size=size[2], color=color[2])
else:
G.add_node(node, subset=node2layer[node], size=size[1], color=color[1])
for conn in conns_list:
G.add_edge(conn[0], conn[1])
pos = nx.multipartite_layout(G)
def rotate_layout(pos, angle):
angle_rad = np.deg2rad(angle)
cos_angle, sin_angle = np.cos(angle_rad), np.sin(angle_rad)
rotated_pos = {}
for node, (x, y) in pos.items():
rotated_pos[node] = (
cos_angle * x - sin_angle * y,
sin_angle * x + cos_angle * y,
)
return rotated_pos
rotated_pos = rotate_layout(pos, rotate)
node_sizes = [n["size"] for n in G.nodes.values()]
node_colors = [n["color"] for n in G.nodes.values()]
nx.draw(
G,
pos=rotated_pos,
node_size=node_sizes,
node_color=node_colors,
**kwargs,
)
plt.savefig(save_path, dpi=save_dpi)