Modify genome for the official release

This commit is contained in:
root
2024-07-10 11:24:11 +08:00
parent 075460f896
commit ee8ec84202
83 changed files with 588 additions and 611 deletions

View File

@@ -1,8 +1,16 @@
from typing import Callable, Sequence
import numpy as np
import jax, jax.numpy as jnp
import jax
from jax import vmap, numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene
from ..ga import BaseMutation, BaseCrossover
from utils import State, StatefulBaseClass, topological_sort_python, hash_array
from .operations import BaseMutation, BaseCrossover, BaseDistance
from tensorneat.common import (
State,
StatefulBaseClass,
hash_array,
)
from .utils import valid_cnt
class BaseGenome(StatefulBaseClass):
@@ -18,120 +26,159 @@ class BaseGenome(StatefulBaseClass):
conn_gene: BaseConnGene,
mutation: BaseMutation,
crossover: BaseCrossover,
distance: BaseDistance,
output_transform: Callable = None,
input_transform: Callable = None,
init_hidden_layers: Sequence[int] = (),
):
# Validate the optional transform callables up front by calling each one on a
# zero vector of the matching width, so a broken callable is reported at
# construction time instead of during evaluation.
if input_transform is not None:
    try:
        _ = input_transform(jnp.zeros(num_inputs))
    except Exception as e:
        # The message names the *input* transform so the caller can tell
        # which of the two callables failed; the cause is chained for debugging.
        raise ValueError(f"Input transform function failed: {e}") from e
if output_transform is not None:
    try:
        _ = output_transform(jnp.zeros(num_outputs))
    except Exception as e:
        raise ValueError(f"Output transform function failed: {e}") from e
# prepare for initialization
all_layers = [num_inputs] + list(init_hidden_layers) + [num_outputs]
layer_indices = []
next_index = 0
for layer in all_layers:
layer_indices.append(list(range(next_index, next_index + layer)))
next_index += layer
all_init_nodes = []
all_init_conns_in_idx = []
all_init_conns_out_idx = []
for i in range(len(layer_indices) - 1):
in_layer = layer_indices[i]
out_layer = layer_indices[i + 1]
for in_idx in in_layer:
for out_idx in out_layer:
all_init_conns_in_idx.append(in_idx)
all_init_conns_out_idx.append(out_idx)
all_init_nodes.extend(in_layer)
if max_nodes < len(all_init_nodes):
raise ValueError(
f"max_nodes={max_nodes} must be greater than or equal to the number of initial nodes={len(all_init_nodes)}"
)
if max_conns < len(all_init_conns_in_idx):
raise ValueError(
f"max_conns={max_conns} must be greater than or equal to the number of initial connections={len(all_init_conns_in_idx)}"
)
self.num_inputs = num_inputs
self.num_outputs = num_outputs
self.input_idx = np.arange(num_inputs)
self.output_idx = np.arange(num_inputs, num_inputs + num_outputs)
self.max_nodes = max_nodes
self.max_conns = max_conns
self.node_gene = node_gene
self.conn_gene = conn_gene
self.mutation = mutation
self.crossover = crossover
self.distance = distance
self.output_transform = output_transform
self.input_transform = input_transform
self.input_idx = np.array(layer_indices[0])
self.output_idx = np.array(layer_indices[-1])
self.all_init_nodes = np.array(all_init_nodes)
self.all_init_conns = np.c_[all_init_conns_in_idx, all_init_conns_out_idx]
# Thread a State object through the setup hooks of the genome's components and
# return the resulting state. Each hook receives the state produced by the
# previous one.
# NOTE(review): the default `State()` is evaluated once at definition time and
# shared across calls; confirm State is immutable.
def setup(self, state=State()):
state = self.node_gene.setup(state)
state = self.conn_gene.setup(state)
# NOTE(review): mutation/crossover setup appears twice, first with (state) and
# then with (state, self). This looks like removed and added diff lines shown
# together; confirm which pair is live.
state = self.mutation.setup(state)
state = self.crossover.setup(state)
state = self.mutation.setup(state, self)
state = self.crossover.setup(state, self)
state = self.distance.setup(state, self)
return state
def transform(self, state, nodes, conns):
    """Abstract hook taking ``state``, ``nodes`` and ``conns``.

    The base class provides no implementation and always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError
def restore(self, state, transformed):
    """Abstract hook taking ``state`` and a ``transformed`` value.

    The base class provides no implementation and always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError
def forward(self, state, transformed, inputs):
    """Abstract hook taking ``state``, ``transformed`` and ``inputs``.

    The base class provides no implementation and always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError
def sympy_func(self):
    """Abstract hook with no arguments besides ``self``.

    The base class provides no implementation and always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError
def visualize(self):
    """Abstract hook with no arguments besides ``self``.

    The base class provides no implementation and always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError
# Delegate to the configured mutation operator (self.mutation) and return
# whatever it returns.
def execute_mutation(self, state, randkey, nodes, conns, new_node_key):
# NOTE(review): two return statements appear back to back, so only the first
# can execute. They differ only in whether `self` is passed to the operator;
# this looks like an old and a new diff line shown together. Confirm which one
# is live.
return self.mutation(state, randkey, self, nodes, conns, new_node_key)
return self.mutation(state, randkey, nodes, conns, new_node_key)
# Delegate to the configured crossover operator (self.crossover) with the
# nodes/conns of both parents and return whatever it returns.
def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2):
# NOTE(review): two return statements appear back to back, so only the first
# can execute. They differ only in whether `self` is passed to the operator;
# this looks like an old and a new diff line shown together. Confirm which one
# is live.
return self.crossover(state, randkey, self, nodes1, conns1, nodes2, conns2)
return self.crossover(state, randkey, nodes1, conns1, nodes2, conns2)
def execute_distance(self, state, nodes1, conns1, nodes2, conns2):
    """Delegate to the configured distance operator.

    Calls ``self.distance`` with the state and the nodes/conns of both
    genomes, and returns its result unchanged.
    """
    distance_op = self.distance
    return distance_op(state, nodes1, conns1, nodes2, conns2)
# Build and return the initial (nodes, conns) arrays for one genome from a PRNG key.
# Both arrays are created as NaN-filled buffers sized (max_nodes, node_gene.length)
# and (max_conns, conn_gene.length) and then partially filled in.
# NOTE(review): this body interleaves two initialization schemes: one that adds a
# single hidden node (as the docstring describes) and one driven by
# self.all_init_nodes / self.all_init_conns. It looks like removed and added diff
# lines shown together; confirm which statements are live.
def initialize(self, state, randkey):
"""
Default initialization method for the genome.
Add an extra hidden node.
Make all input nodes and output nodes connected to the hidden node.
All attributes will be initialized randomly using gene.new_random_attrs method.
For example, a network with 2 inputs and 1 output, the structure will be:
nodes:
[
[0, attrs0], # input node 0
[1, attrs1], # input node 1
[2, attrs2], # output node 0
[3, attrs3], # hidden node
[NaN, NaN], # empty node
]
conns:
[
[0, 3, attrs0], # input node 0 -> hidden node
[1, 3, attrs1], # input node 1 -> hidden node
[3, 2, attrs2], # hidden node -> output node 0
[NaN, NaN],
[NaN, NaN],
]
"""
k1, k2 = jax.random.split(randkey) # k1 for nodes, k2 for conns
all_nodes_cnt = len(self.all_init_nodes)
all_conns_cnt = len(self.all_init_conns)
# initialize nodes
# The hidden node takes the first key after the largest input/output key.
new_node_key = (
max([*self.input_idx, *self.output_idx]) + 1
) # the key for the hidden node
node_keys = jnp.concatenate(
[self.input_idx, self.output_idx, jnp.array([new_node_key])]
) # the list of all node keys
# initialize nodes and connections with NaN
nodes = jnp.full((self.max_nodes, self.node_gene.length), jnp.nan)
conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan)
# create node indices
node_indices = self.all_init_nodes
# create node attrs
# One PRNG key per initial node; new_random_attrs is vmapped over the keys only.
rand_keys_n = jax.random.split(k1, num=all_nodes_cnt)
node_attr_func = vmap(self.node_gene.new_random_attrs, in_axes=(None, 0))
node_attrs = node_attr_func(state, rand_keys_n)
# set keys for input nodes, output nodes and hidden node
nodes = nodes.at[node_keys, 0].set(node_keys)
# generate random attributes for nodes
# NOTE(review): node_keys is rebound here from node indices to PRNG keys.
node_keys = jax.random.split(k1, len(node_keys))
random_node_attrs = jax.vmap(
self.node_gene.new_random_attrs, in_axes=(None, 0)
)(state, node_keys)
nodes = nodes.at[: len(node_keys), 1:].set(random_node_attrs)
nodes = nodes.at[:all_nodes_cnt, 0].set(node_indices) # set node indices
nodes = nodes.at[:all_nodes_cnt, 1:].set(node_attrs) # set node attrs
# initialize conns
# input-hidden connections
input_conns = jnp.c_[
self.input_idx, jnp.full_like(self.input_idx, new_node_key)
]
conns = conns.at[self.input_idx, :2].set(input_conns) # in-keys, out-keys
# NOTE(review): conns is re-created as an all-NaN buffer on the next line, which
# discards the input-hidden connections written just above.
conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan)
# create input and output indices
conn_indices = self.all_init_conns
# create conn attrs
rand_keys_c = jax.random.split(k2, num=all_conns_cnt)
conns_attr_func = jax.vmap(
self.conn_gene.new_random_attrs,
in_axes=(
None,
0,
),
)
conns_attrs = conns_attr_func(state, rand_keys_c)
# output-hidden connections
output_conns = jnp.c_[
jnp.full_like(self.output_idx, new_node_key), self.output_idx
]
conns = conns.at[self.output_idx, :2].set(output_conns) # in-keys, out-keys
conn_keys = jax.random.split(k2, num=len(self.input_idx) + len(self.output_idx))
# generate random attributes for conns
random_conn_attrs = jax.vmap(
self.conn_gene.new_random_attrs, in_axes=(None, 0)
)(state, conn_keys)
conns = conns.at[: len(conn_keys), 2:].set(random_conn_attrs)
conns = conns.at[:all_conns_cnt, :2].set(conn_indices) # set conn indices
conns = conns.at[:all_conns_cnt, 2:].set(conns_attrs) # set conn attrs
return nodes, conns
def update_by_batch(self, state, batch_input, transformed):
    """Batch-update hook for the genome.

    The base class provides no implementation and always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError
def network_dict(self, state, nodes, conns):
    """Return a dict describing the network.

    The ``"nodes"`` entry is the result of ``self._get_node_dict`` and the
    ``"conns"`` entry is the result of ``self._get_conn_dict``.
    """
    node_part = self._get_node_dict(state, nodes)
    conn_part = self._get_conn_dict(state, conns)
    return {"nodes": node_part, "conns": conn_part}
def get_input_idx(self):
    """Return ``self.input_idx`` converted with ``tolist()``."""
    indices = self.input_idx
    return indices.tolist()
def get_output_idx(self):
    """Return ``self.output_idx`` converted with ``tolist()``."""
    indices = self.output_idx
    return indices.tolist()
def hash(self, nodes, conns):
    """Combine ``nodes`` and ``conns`` into a single hash value.

    ``hash_array`` is mapped over the leading axis of each array, the two
    results are concatenated (nodes first), and the concatenation is hashed
    once more with ``hash_array``.
    """
    hash_rows = vmap(hash_array)
    per_row = [hash_rows(nodes), hash_rows(conns)]
    return hash_array(jnp.concatenate(per_row))
def repr(self, state, nodes, conns, precision=2):
nodes, conns = jax.device_get([nodes, conns])
nodes_cnt, conns_cnt = self.valid_cnt(nodes), self.valid_cnt(conns)
nodes_cnt, conns_cnt = valid_cnt(nodes), valid_cnt(conns)
s = f"{self.__class__.__name__}(nodes={nodes_cnt}, conns={conns_cnt}):\n"
s += f"\tNodes:\n"
for node in nodes:
@@ -152,11 +199,7 @@ class BaseGenome(StatefulBaseClass):
s += f"\t\t{self.conn_gene.repr(state, conn, precision=precision)}\n"
return s
@classmethod
def valid_cnt(cls, arr):
    """Count the rows of ``arr`` whose first column is not NaN."""
    first_col = arr[:, 0]
    is_filled = jnp.logical_not(jnp.isnan(first_col))
    return jnp.sum(is_filled)
def get_conn_dict(self, state, conns):
def _get_conn_dict(self, state, conns):
conns = jax.device_get(conns)
conn_dict = {}
for conn in conns:
@@ -167,7 +210,7 @@ class BaseGenome(StatefulBaseClass):
conn_dict[(in_idx, out_idx)] = cd
return conn_dict
def get_node_dict(self, state, nodes):
def _get_node_dict(self, state, nodes):
nodes = jax.device_get(nodes)
node_dict = {}
for node in nodes:
@@ -177,92 +220,3 @@ class BaseGenome(StatefulBaseClass):
idx = nd["idx"]
node_dict[idx] = nd
return node_dict
def network_dict(self, state, nodes, conns):
    """Return a dict describing the network.

    The ``"nodes"`` entry is the result of ``self.get_node_dict`` and the
    ``"conns"`` entry is the result of ``self.get_conn_dict``.
    """
    node_part = self.get_node_dict(state, nodes)
    conn_part = self.get_conn_dict(state, conns)
    return {"nodes": node_part, "conns": conn_part}
def get_input_idx(self):
    """Return ``self.input_idx`` converted with ``tolist()``."""
    indices = self.input_idx
    return indices.tolist()
def get_output_idx(self):
    """Return ``self.output_idx`` converted with ``tolist()``."""
    indices = self.output_idx
    return indices.tolist()
def sympy_func(self, state, network, sympy_output_transform=None):
    """Abstract hook taking ``state``, ``network`` and an optional
    ``sympy_output_transform``.

    The base class provides no implementation and always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError
def visualize(
    self,
    network,
    rotate=0,
    reverse_node_order=False,
    size=(300, 300, 300),
    color=("blue", "blue", "blue"),
    save_path="network.svg",
    save_dpi=800,
    **kwargs,
):
    """Draw ``network`` as a layered directed graph and save it to a file.

    Args:
        network: mapping with ``"nodes"`` and ``"conns"`` entries; iterating
            each entry yields node keys and (in, out) connection keys.
        rotate: rotation of the layout, in degrees.
        reverse_node_order: if true, nodes are added to the graph in reversed
            topological order.
        size: node size, either one value or an (input, hidden, output) tuple.
        color: node color, either one value or an (input, hidden, output) tuple.
        save_path: destination passed to ``plt.savefig``.
        save_dpi: dpi passed to ``plt.savefig``.
        **kwargs: forwarded to ``networkx.draw``.
    """
    import networkx as nx
    from matplotlib import pyplot as plt

    node_keys = list(network["nodes"])
    conn_keys = list(network["conns"])
    inputs = self.get_input_idx()
    outputs = self.get_output_idx()

    order, layers = topological_sort_python(node_keys, conn_keys)

    # Map every node to the index of the topological layer it sits in.
    layer_of = {}
    for depth, members in enumerate(layers):
        for member in members:
            layer_of[member] = depth

    if reverse_node_order:
        order = order[::-1]

    # A non-tuple value is used for all three node groups.
    if not isinstance(size, tuple):
        size = (size, size, size)
    if not isinstance(color, tuple):
        color = (color, color, color)

    graph = nx.DiGraph()
    for node in order:
        # Slot 0 = input, 1 = hidden, 2 = output; input takes precedence.
        if node in inputs:
            slot = 0
        elif node in outputs:
            slot = 2
        else:
            slot = 1
        graph.add_node(
            node, subset=layer_of[node], size=size[slot], color=color[slot]
        )

    for conn in conn_keys:
        graph.add_edge(conn[0], conn[1])

    layout = nx.multipartite_layout(graph)

    # Rotate every layout position by `rotate` degrees about the origin.
    theta = np.deg2rad(rotate)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotated = {
        node: (cos_t * x - sin_t * y, sin_t * x + cos_t * y)
        for node, (x, y) in layout.items()
    }

    sizes = [attrs["size"] for attrs in graph.nodes.values()]
    colors = [attrs["color"] for attrs in graph.nodes.values()]

    nx.draw(
        graph,
        pos=rotated,
        node_size=sizes,
        node_color=colors,
        **kwargs,
    )
    plt.savefig(save_path, dpi=save_dpi)
def hash(self, nodes, conns):
    """Combine ``nodes`` and ``conns`` into a single hash value.

    ``hash_array`` is mapped over the leading axis of each array, the two
    results are concatenated (nodes first), and the concatenation is hashed
    once more with ``hash_array``.
    """
    hash_rows = jax.vmap(hash_array)
    per_row = [hash_rows(nodes), hash_rows(conns)]
    return hash_array(jnp.concatenate(per_row))