265 lines
8.9 KiB
Python
265 lines
8.9 KiB
Python
import numpy as np
|
|
import jax, jax.numpy as jnp
|
|
from ..gene import BaseNodeGene, BaseConnGene
|
|
from ..ga import BaseMutation, BaseCrossover
|
|
from utils import State, StatefulBaseClass, topological_sort_python
|
|
|
|
|
|
class BaseGenome(StatefulBaseClass):
|
|
network_type = None
|
|
|
|
def __init__(
|
|
self,
|
|
num_inputs: int,
|
|
num_outputs: int,
|
|
max_nodes: int,
|
|
max_conns: int,
|
|
node_gene: BaseNodeGene,
|
|
conn_gene: BaseConnGene,
|
|
mutation: BaseMutation,
|
|
crossover: BaseCrossover,
|
|
):
|
|
self.num_inputs = num_inputs
|
|
self.num_outputs = num_outputs
|
|
self.input_idx = np.arange(num_inputs)
|
|
self.output_idx = np.arange(num_inputs, num_inputs + num_outputs)
|
|
self.max_nodes = max_nodes
|
|
self.max_conns = max_conns
|
|
self.node_gene = node_gene
|
|
self.conn_gene = conn_gene
|
|
self.mutation = mutation
|
|
self.crossover = crossover
|
|
|
|
def setup(self, state=State()):
|
|
state = self.node_gene.setup(state)
|
|
state = self.conn_gene.setup(state)
|
|
state = self.mutation.setup(state)
|
|
state = self.crossover.setup(state)
|
|
return state
|
|
|
|
def transform(self, state, nodes, conns):
|
|
raise NotImplementedError
|
|
|
|
def restore(self, state, transformed):
|
|
raise NotImplementedError
|
|
|
|
def forward(self, state, transformed, inputs):
|
|
raise NotImplementedError
|
|
|
|
def execute_mutation(self, state, randkey, nodes, conns, new_node_key):
|
|
return self.mutation(state, randkey, self, nodes, conns, new_node_key)
|
|
|
|
def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2):
|
|
return self.crossover(state, randkey, self, nodes1, conns1, nodes2, conns2)
|
|
|
|
def initialize(self, state, randkey):
|
|
"""
|
|
Default initialization method for the genome.
|
|
Add an extra hidden node.
|
|
Make all input nodes and output nodes connected to the hidden node.
|
|
All attributes will be initialized randomly using gene.new_random_attrs method.
|
|
|
|
For example, a network with 2 inputs and 1 output, the structure will be:
|
|
nodes:
|
|
[
|
|
[0, attrs0], # input node 0
|
|
[1, attrs1], # input node 1
|
|
[2, attrs2], # output node 0
|
|
[3, attrs3], # hidden node
|
|
[NaN, NaN], # empty node
|
|
]
|
|
conns:
|
|
[
|
|
[0, 3, attrs0], # input node 0 -> hidden node
|
|
[1, 3, attrs1], # input node 1 -> hidden node
|
|
[3, 2, attrs2], # hidden node -> output node 0
|
|
[NaN, NaN],
|
|
[NaN, NaN],
|
|
]
|
|
"""
|
|
|
|
k1, k2 = jax.random.split(randkey) # k1 for nodes, k2 for conns
|
|
# initialize nodes
|
|
new_node_key = (
|
|
max([*self.input_idx, *self.output_idx]) + 1
|
|
) # the key for the hidden node
|
|
node_keys = jnp.concatenate(
|
|
[self.input_idx, self.output_idx, jnp.array([new_node_key])]
|
|
) # the list of all node keys
|
|
|
|
# initialize nodes and connections with NaN
|
|
nodes = jnp.full((self.max_nodes, self.node_gene.length), jnp.nan)
|
|
conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan)
|
|
|
|
# set keys for input nodes, output nodes and hidden node
|
|
nodes = nodes.at[node_keys, 0].set(node_keys)
|
|
|
|
# generate random attributes for nodes
|
|
node_keys = jax.random.split(k1, len(node_keys))
|
|
random_node_attrs = jax.vmap(
|
|
self.node_gene.new_random_attrs, in_axes=(None, 0)
|
|
)(state, node_keys)
|
|
nodes = nodes.at[: len(node_keys), 1:].set(random_node_attrs)
|
|
|
|
# initialize conns
|
|
# input-hidden connections
|
|
input_conns = jnp.c_[
|
|
self.input_idx, jnp.full_like(self.input_idx, new_node_key)
|
|
]
|
|
conns = conns.at[self.input_idx, :2].set(input_conns) # in-keys, out-keys
|
|
|
|
# output-hidden connections
|
|
output_conns = jnp.c_[
|
|
jnp.full_like(self.output_idx, new_node_key), self.output_idx
|
|
]
|
|
conns = conns.at[self.output_idx, :2].set(output_conns) # in-keys, out-keys
|
|
|
|
conn_keys = jax.random.split(k2, num=len(self.input_idx) + len(self.output_idx))
|
|
# generate random attributes for conns
|
|
random_conn_attrs = jax.vmap(
|
|
self.conn_gene.new_random_attrs, in_axes=(None, 0)
|
|
)(state, conn_keys)
|
|
conns = conns.at[: len(conn_keys), 2:].set(random_conn_attrs)
|
|
|
|
return nodes, conns
|
|
|
|
def update_by_batch(self, state, batch_input, transformed):
|
|
"""
|
|
Update the genome by a batch of data.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
def repr(self, state, nodes, conns, precision=2):
|
|
nodes, conns = jax.device_get([nodes, conns])
|
|
nodes_cnt, conns_cnt = self.valid_cnt(nodes), self.valid_cnt(conns)
|
|
s = f"{self.__class__.__name__}(nodes={nodes_cnt}, conns={conns_cnt}):\n"
|
|
s += f"\tNodes:\n"
|
|
for node in nodes:
|
|
if np.isnan(node[0]):
|
|
break
|
|
s += f"\t\t{self.node_gene.repr(state, node, precision=precision)}"
|
|
node_idx = int(node[0])
|
|
if np.isin(node_idx, self.input_idx):
|
|
s += " (input)"
|
|
elif np.isin(node_idx, self.output_idx):
|
|
s += " (output)"
|
|
s += "\n"
|
|
|
|
s += f"\tConns:\n"
|
|
for conn in conns:
|
|
if np.isnan(conn[0]):
|
|
break
|
|
s += f"\t\t{self.conn_gene.repr(state, conn, precision=precision)}\n"
|
|
return s
|
|
|
|
@classmethod
|
|
def valid_cnt(cls, arr):
|
|
return jnp.sum(~jnp.isnan(arr[:, 0]))
|
|
|
|
def get_conn_dict(self, state, conns):
|
|
conns = jax.device_get(conns)
|
|
conn_dict = {}
|
|
for conn in conns:
|
|
if np.isnan(conn[0]):
|
|
continue
|
|
cd = self.conn_gene.to_dict(state, conn)
|
|
in_idx, out_idx = cd["in"], cd["out"]
|
|
conn_dict[(in_idx, out_idx)] = cd
|
|
return conn_dict
|
|
|
|
def get_node_dict(self, state, nodes):
|
|
nodes = jax.device_get(nodes)
|
|
node_dict = {}
|
|
for node in nodes:
|
|
if np.isnan(node[0]):
|
|
continue
|
|
nd = self.node_gene.to_dict(state, node)
|
|
idx = nd["idx"]
|
|
node_dict[idx] = nd
|
|
return node_dict
|
|
|
|
def network_dict(self, state, nodes, conns):
|
|
return {
|
|
"nodes": self.get_node_dict(state, nodes),
|
|
"conns": self.get_conn_dict(state, conns),
|
|
}
|
|
|
|
def get_input_idx(self):
|
|
return self.input_idx.tolist()
|
|
|
|
def get_output_idx(self):
|
|
return self.output_idx.tolist()
|
|
|
|
def sympy_func(self, state, network, sympy_output_transform=None):
|
|
raise NotImplementedError
|
|
|
|
def visualize(
|
|
self,
|
|
network,
|
|
rotate=0,
|
|
reverse_node_order=False,
|
|
size=(300, 300, 300),
|
|
color=("blue", "blue", "blue"),
|
|
save_path="network.svg",
|
|
save_dpi=800,
|
|
**kwargs,
|
|
):
|
|
import networkx as nx
|
|
from matplotlib import pyplot as plt
|
|
|
|
nodes_list = list(network["nodes"])
|
|
conns_list = list(network["conns"])
|
|
input_idx = self.get_input_idx()
|
|
output_idx = self.get_output_idx()
|
|
topo_order, topo_layers = topological_sort_python(nodes_list, conns_list)
|
|
node2layer = {
|
|
node: layer for layer, nodes in enumerate(topo_layers) for node in nodes
|
|
}
|
|
if reverse_node_order:
|
|
topo_order = topo_order[::-1]
|
|
|
|
G = nx.DiGraph()
|
|
|
|
if not isinstance(size, tuple):
|
|
size = (size, size, size)
|
|
if not isinstance(color, tuple):
|
|
color = (color, color, color)
|
|
|
|
for node in topo_order:
|
|
if node in input_idx:
|
|
G.add_node(node, subset=node2layer[node], size=size[0], color=color[0])
|
|
elif node in output_idx:
|
|
G.add_node(node, subset=node2layer[node], size=size[2], color=color[2])
|
|
else:
|
|
G.add_node(node, subset=node2layer[node], size=size[1], color=color[1])
|
|
|
|
for conn in conns_list:
|
|
G.add_edge(conn[0], conn[1])
|
|
pos = nx.multipartite_layout(G)
|
|
|
|
def rotate_layout(pos, angle):
|
|
angle_rad = np.deg2rad(angle)
|
|
cos_angle, sin_angle = np.cos(angle_rad), np.sin(angle_rad)
|
|
rotated_pos = {}
|
|
for node, (x, y) in pos.items():
|
|
rotated_pos[node] = (
|
|
cos_angle * x - sin_angle * y,
|
|
sin_angle * x + cos_angle * y,
|
|
)
|
|
return rotated_pos
|
|
|
|
rotated_pos = rotate_layout(pos, rotate)
|
|
|
|
node_sizes = [n["size"] for n in G.nodes.values()]
|
|
node_colors = [n["color"] for n in G.nodes.values()]
|
|
|
|
nx.draw(
|
|
G,
|
|
with_labels=True,
|
|
pos=rotated_pos,
|
|
node_size=node_sizes,
|
|
node_color=node_colors,
|
|
**kwargs,
|
|
)
|
|
plt.savefig(save_path, dpi=save_dpi)
|