odify genome for the official release
This commit is contained in:
@@ -1,8 +1,16 @@
|
||||
from typing import Callable, Sequence
|
||||
|
||||
import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
from ..gene import BaseNodeGene, BaseConnGene
|
||||
from ..ga import BaseMutation, BaseCrossover
|
||||
from utils import State, StatefulBaseClass, topological_sort_python, hash_array
|
||||
from .operations import BaseMutation, BaseCrossover, BaseDistance
|
||||
from tensorneat.common import (
|
||||
State,
|
||||
StatefulBaseClass,
|
||||
hash_array,
|
||||
)
|
||||
from .utils import valid_cnt
|
||||
|
||||
|
||||
class BaseGenome(StatefulBaseClass):
|
||||
@@ -18,120 +26,159 @@ class BaseGenome(StatefulBaseClass):
|
||||
conn_gene: BaseConnGene,
|
||||
mutation: BaseMutation,
|
||||
crossover: BaseCrossover,
|
||||
distance: BaseDistance,
|
||||
output_transform: Callable = None,
|
||||
input_transform: Callable = None,
|
||||
init_hidden_layers: Sequence[int] = (),
|
||||
):
|
||||
|
||||
# check transform functions
|
||||
if input_transform is not None:
|
||||
try:
|
||||
_ = input_transform(jnp.zeros(num_inputs))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
|
||||
if output_transform is not None:
|
||||
try:
|
||||
_ = output_transform(jnp.zeros(num_outputs))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
|
||||
# prepare for initialization
|
||||
all_layers = [num_inputs] + list(init_hidden_layers) + [num_outputs]
|
||||
layer_indices = []
|
||||
next_index = 0
|
||||
for layer in all_layers:
|
||||
layer_indices.append(list(range(next_index, next_index + layer)))
|
||||
next_index += layer
|
||||
|
||||
all_init_nodes = []
|
||||
all_init_conns_in_idx = []
|
||||
all_init_conns_out_idx = []
|
||||
for i in range(len(layer_indices) - 1):
|
||||
in_layer = layer_indices[i]
|
||||
out_layer = layer_indices[i + 1]
|
||||
for in_idx in in_layer:
|
||||
for out_idx in out_layer:
|
||||
all_init_conns_in_idx.append(in_idx)
|
||||
all_init_conns_out_idx.append(out_idx)
|
||||
all_init_nodes.extend(in_layer)
|
||||
|
||||
if max_nodes < len(all_init_nodes):
|
||||
raise ValueError(
|
||||
f"max_nodes={max_nodes} must be greater than or equal to the number of initial nodes={len(all_init_nodes)}"
|
||||
)
|
||||
|
||||
if max_conns < len(all_init_conns_in_idx):
|
||||
raise ValueError(
|
||||
f"max_conns={max_conns} must be greater than or equal to the number of initial connections={len(all_init_conns_in_idx)}"
|
||||
)
|
||||
|
||||
self.num_inputs = num_inputs
|
||||
self.num_outputs = num_outputs
|
||||
self.input_idx = np.arange(num_inputs)
|
||||
self.output_idx = np.arange(num_inputs, num_inputs + num_outputs)
|
||||
self.max_nodes = max_nodes
|
||||
self.max_conns = max_conns
|
||||
self.node_gene = node_gene
|
||||
self.conn_gene = conn_gene
|
||||
self.mutation = mutation
|
||||
self.crossover = crossover
|
||||
self.distance = distance
|
||||
self.output_transform = output_transform
|
||||
self.input_transform = input_transform
|
||||
|
||||
self.input_idx = np.array(layer_indices[0])
|
||||
self.output_idx = np.array(layer_indices[-1])
|
||||
self.all_init_nodes = np.array(all_init_nodes)
|
||||
self.all_init_conns = np.c_[all_init_conns_in_idx, all_init_conns_out_idx]
|
||||
|
||||
def setup(self, state=State()):
    """
    Thread ``state`` through the ``setup`` of every component of the genome.

    The node gene and connection gene are set up with the state only; the
    mutation, crossover and distance operators additionally receive this
    genome instance.

    Args:
        state: the state object to extend (defaults to ``State()``).

    Returns:
        The state returned by the last component's ``setup``.
    """
    # genes only need the state
    state = self.node_gene.setup(state)
    state = self.conn_gene.setup(state)
    # operators are handed the genome itself; the legacy one-argument calls
    # that preceded these were redundant and have been dropped
    state = self.mutation.setup(state, self)
    state = self.crossover.setup(state, self)
    state = self.distance.setup(state, self)
    return state
|
||||
|
||||
def transform(self, state, nodes, conns):
    """
    Abstract hook: not implemented in the base genome.

    Raises:
        NotImplementedError: always; subclasses are expected to override.
    """
    raise NotImplementedError
|
||||
|
||||
def restore(self, state, transformed):
    """
    Abstract hook: not implemented in the base genome.

    Raises:
        NotImplementedError: always; subclasses are expected to override.
    """
    raise NotImplementedError
|
||||
|
||||
def forward(self, state, transformed, inputs):
    """
    Abstract hook: not implemented in the base genome.

    Raises:
        NotImplementedError: always; subclasses are expected to override.
    """
    raise NotImplementedError
|
||||
|
||||
def sympy_func(self):
    """
    Abstract hook: not implemented in the base genome.

    Raises:
        NotImplementedError: always; subclasses are expected to override.
    """
    raise NotImplementedError
|
||||
|
||||
def visualize(self):
    """
    Abstract hook: not implemented in the base genome.

    Raises:
        NotImplementedError: always; subclasses are expected to override.
    """
    raise NotImplementedError
|
||||
|
||||
def execute_mutation(self, state, randkey, nodes, conns, new_node_key):
    """
    Delegate to the configured mutation operator.

    The genome is not forwarded here: ``setup`` already passes it to
    ``self.mutation.setup``.

    Returns:
        Whatever ``self.mutation(state, randkey, nodes, conns, new_node_key)``
        returns.
    """
    return self.mutation(state, randkey, nodes, conns, new_node_key)
|
||||
|
||||
def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2):
    """
    Delegate to the configured crossover operator.

    The genome is not forwarded here: ``setup`` already passes it to
    ``self.crossover.setup``.

    Returns:
        Whatever ``self.crossover(state, randkey, nodes1, conns1, nodes2,
        conns2)`` returns.
    """
    return self.crossover(state, randkey, nodes1, conns1, nodes2, conns2)
|
||||
|
||||
def execute_distance(self, state, nodes1, conns1, nodes2, conns2):
    """
    Delegate to the configured distance operator.

    Returns:
        Whatever ``self.distance(state, nodes1, conns1, nodes2, conns2)``
        returns.
    """
    return self.distance(state, nodes1, conns1, nodes2, conns2)
|
||||
|
||||
def initialize(self, state, randkey):
    """
    Build the initial ``nodes`` and ``conns`` arrays of a genome.

    Both arrays are allocated at full capacity and filled with NaN:
    ``nodes`` has shape ``(max_nodes, node_gene.length)`` and ``conns`` has
    shape ``(max_conns, conn_gene.length)``. The leading rows are then
    populated from the initial topology stored on the instance:

    * the first ``len(self.all_init_nodes)`` node rows get their index in
      column 0 and random attributes (``node_gene.new_random_attrs``) in
      the remaining columns;
    * the first ``len(self.all_init_conns)`` connection rows get their
      (in, out) indices in columns 0-1 and random attributes
      (``conn_gene.new_random_attrs``) in the remaining columns.

    Rows beyond those counts stay NaN.

    Args:
        state: state object forwarded to the genes' ``new_random_attrs``.
        randkey: JAX PRNG key; split once so nodes and conns use
            independent keys.

    Returns:
        Tuple ``(nodes, conns)``.
    """
    k1, k2 = jax.random.split(randkey)  # k1 for nodes, k2 for conns

    all_nodes_cnt = len(self.all_init_nodes)
    all_conns_cnt = len(self.all_init_conns)

    # initialize nodes
    nodes = jnp.full((self.max_nodes, self.node_gene.length), jnp.nan)
    # create node indices
    node_indices = self.all_init_nodes
    # create node attrs: one PRNG key per initial node
    rand_keys_n = jax.random.split(k1, num=all_nodes_cnt)
    node_attr_func = vmap(self.node_gene.new_random_attrs, in_axes=(None, 0))
    node_attrs = node_attr_func(state, rand_keys_n)

    nodes = nodes.at[:all_nodes_cnt, 0].set(node_indices)  # set node indices
    nodes = nodes.at[:all_nodes_cnt, 1:].set(node_attrs)  # set node attrs

    # initialize conns
    conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan)
    # create input and output indices
    conn_indices = self.all_init_conns
    # create conn attrs: one PRNG key per initial connection
    rand_keys_c = jax.random.split(k2, num=all_conns_cnt)
    conns_attr_func = jax.vmap(
        self.conn_gene.new_random_attrs,
        in_axes=(None, 0),
    )
    conns_attrs = conns_attr_func(state, rand_keys_c)

    conns = conns.at[:all_conns_cnt, :2].set(conn_indices)  # set conn indices
    conns = conns.at[:all_conns_cnt, 2:].set(conns_attrs)  # set conn attrs

    return nodes, conns
|
||||
|
||||
def update_by_batch(self, state, batch_input, transformed):
    """
    Update the genome by a batch of data.

    Not implemented in the base genome.

    Raises:
        NotImplementedError: always; subclasses are expected to override.
    """
    raise NotImplementedError
|
||||
def network_dict(self, state, nodes, conns):
    """
    Bundle the node and connection dictionaries of a network.

    Returns:
        A dict with key ``"nodes"`` holding ``self._get_node_dict(state,
        nodes)`` and key ``"conns"`` holding ``self._get_conn_dict(state,
        conns)``.
    """
    return {
        "nodes": self._get_node_dict(state, nodes),
        "conns": self._get_conn_dict(state, conns),
    }
|
||||
|
||||
def get_input_idx(self):
    """Return ``self.input_idx`` converted with ``.tolist()``."""
    return self.input_idx.tolist()
|
||||
|
||||
def get_output_idx(self):
    """Return ``self.output_idx`` converted with ``.tolist()``."""
    return self.output_idx.tolist()
|
||||
|
||||
def hash(self, nodes, conns):
    """
    Compute a single hash for a genome.

    ``hash_array`` is mapped over the leading axis of ``nodes`` and of
    ``conns``; the two resulting arrays are concatenated (nodes first) and
    hashed once more with ``hash_array``.
    """
    nodes_hashs = vmap(hash_array)(nodes)
    conns_hashs = vmap(hash_array)(conns)
    return hash_array(jnp.concatenate([nodes_hashs, conns_hashs]))
|
||||
|
||||
def repr(self, state, nodes, conns, precision=2):
|
||||
nodes, conns = jax.device_get([nodes, conns])
|
||||
nodes_cnt, conns_cnt = self.valid_cnt(nodes), self.valid_cnt(conns)
|
||||
nodes_cnt, conns_cnt = valid_cnt(nodes), valid_cnt(conns)
|
||||
s = f"{self.__class__.__name__}(nodes={nodes_cnt}, conns={conns_cnt}):\n"
|
||||
s += f"\tNodes:\n"
|
||||
for node in nodes:
|
||||
@@ -152,11 +199,7 @@ class BaseGenome(StatefulBaseClass):
|
||||
s += f"\t\t{self.conn_gene.repr(state, conn, precision=precision)}\n"
|
||||
return s
|
||||
|
||||
@classmethod
def valid_cnt(cls, arr):
    """
    Count the rows of ``arr`` whose first column is not NaN.

    Args:
        arr: 2-D array; only column 0 is inspected.

    Returns:
        The count as the scalar result of ``jnp.sum``.
    """
    return jnp.sum(~jnp.isnan(arr[:, 0]))
|
||||
|
||||
def get_conn_dict(self, state, conns):
|
||||
def _get_conn_dict(self, state, conns):
|
||||
conns = jax.device_get(conns)
|
||||
conn_dict = {}
|
||||
for conn in conns:
|
||||
@@ -167,7 +210,7 @@ class BaseGenome(StatefulBaseClass):
|
||||
conn_dict[(in_idx, out_idx)] = cd
|
||||
return conn_dict
|
||||
|
||||
def get_node_dict(self, state, nodes):
|
||||
def _get_node_dict(self, state, nodes):
|
||||
nodes = jax.device_get(nodes)
|
||||
node_dict = {}
|
||||
for node in nodes:
|
||||
@@ -177,92 +220,3 @@ class BaseGenome(StatefulBaseClass):
|
||||
idx = nd["idx"]
|
||||
node_dict[idx] = nd
|
||||
return node_dict
|
||||
|
||||
def network_dict(self, state, nodes, conns):
    """
    Bundle the node and connection dictionaries of a network.

    Returns:
        A dict with key ``"nodes"`` holding ``self.get_node_dict(state,
        nodes)`` and key ``"conns"`` holding ``self.get_conn_dict(state,
        conns)``.
    """
    return {
        "nodes": self.get_node_dict(state, nodes),
        "conns": self.get_conn_dict(state, conns),
    }
|
||||
|
||||
def get_input_idx(self):
    """Return ``self.input_idx`` converted with ``.tolist()``."""
    return self.input_idx.tolist()
|
||||
|
||||
def get_output_idx(self):
    """Return ``self.output_idx`` converted with ``.tolist()``."""
    return self.output_idx.tolist()
|
||||
|
||||
def sympy_func(self, state, network, sympy_output_transform=None):
    """
    Abstract hook: not implemented in the base genome.

    Raises:
        NotImplementedError: always; subclasses are expected to override.
    """
    raise NotImplementedError
|
||||
|
||||
def visualize(
    self,
    network,
    rotate=0,
    reverse_node_order=False,
    size=(300, 300, 300),
    color=("blue", "blue", "blue"),
    save_path="network.svg",
    save_dpi=800,
    **kwargs,
):
    """
    Draw a network with networkx / matplotlib and save it to ``save_path``.

    Nodes are placed with ``nx.multipartite_layout`` using the layer of each
    node in the topological sort as its ``subset``; the layout is then
    rotated by ``rotate`` degrees.

    Args:
        network: mapping with ``"nodes"`` and ``"conns"`` entries; iterating
            ``network["conns"]`` must yield items whose elements 0 and 1 are
            the source and target node.
        rotate: rotation angle in degrees applied to the layout.
        reverse_node_order: if true, nodes are added to the graph in
            reversed topological order.
        size: node size, either one value or a tuple indexed as
            (input, hidden, output).
        color: node colour, either one value or a tuple indexed as
            (input, hidden, output).
        save_path: path passed to ``plt.savefig``.
        save_dpi: dpi passed to ``plt.savefig``.
        **kwargs: forwarded to ``nx.draw``.
    """
    # imported lazily so the plotting stack is only needed when drawing
    import networkx as nx
    from matplotlib import pyplot as plt

    nodes_list = list(network["nodes"])
    conns_list = list(network["conns"])
    # sets give O(1) membership tests in the node loop below
    input_nodes = set(self.get_input_idx())
    output_nodes = set(self.get_output_idx())

    topo_order, topo_layers = topological_sort_python(nodes_list, conns_list)
    node2layer = {
        node: layer for layer, nodes in enumerate(topo_layers) for node in nodes
    }
    if reverse_node_order:
        topo_order = topo_order[::-1]

    G = nx.DiGraph()

    # a single value applies to all three node roles
    if not isinstance(size, tuple):
        size = (size, size, size)
    if not isinstance(color, tuple):
        color = (color, color, color)

    for node in topo_order:
        # role index into size/color: 0 = input, 1 = hidden, 2 = output
        if node in input_nodes:
            role = 0
        elif node in output_nodes:
            role = 2
        else:
            role = 1
        G.add_node(node, subset=node2layer[node], size=size[role], color=color[role])

    for conn in conns_list:
        G.add_edge(conn[0], conn[1])
    pos = nx.multipartite_layout(G)

    def rotate_layout(pos, angle):
        """Rotate every (x, y) position by ``angle`` degrees about the origin."""
        angle_rad = np.deg2rad(angle)
        cos_angle, sin_angle = np.cos(angle_rad), np.sin(angle_rad)
        return {
            node: (
                cos_angle * x - sin_angle * y,
                sin_angle * x + cos_angle * y,
            )
            for node, (x, y) in pos.items()
        }

    rotated_pos = rotate_layout(pos, rotate)

    node_sizes = [n["size"] for n in G.nodes.values()]
    node_colors = [n["color"] for n in G.nodes.values()]

    nx.draw(
        G,
        pos=rotated_pos,
        node_size=node_sizes,
        node_color=node_colors,
        **kwargs,
    )
    plt.savefig(save_path, dpi=save_dpi)
|
||||
|
||||
def hash(self, nodes, conns):
    """
    Compute a single hash for a genome.

    ``hash_array`` is mapped over the leading axis of ``nodes`` and of
    ``conns`` with ``jax.vmap``; the two resulting arrays are concatenated
    (nodes first) and hashed once more with ``hash_array``.
    """
    nodes_hashs = jax.vmap(hash_array)(nodes)
    conns_hashs = jax.vmap(hash_array)(conns)
    return hash_array(jnp.concatenate([nodes_hashs, conns_hashs]))
|
||||
|
||||
Reference in New Issue
Block a user