debuging
This commit is contained in:
@@ -42,14 +42,14 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
|
||||
# crossover nodes
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
nodes2 = align_array(keys1, keys2, nodes2, 'node')
|
||||
new_nodes = jnp.where(jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
|
||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
|
||||
|
||||
# crossover connections
|
||||
cons1 = flatten_connections(keys1, connections1)
|
||||
cons2 = flatten_connections(keys2, connections2)
|
||||
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
|
||||
cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
|
||||
new_cons = jnp.where(jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
|
||||
new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
|
||||
new_cons = unflatten_connections(len(keys1), new_cons)
|
||||
|
||||
return new_nodes, new_cons
|
||||
|
||||
@@ -35,6 +35,15 @@ def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
|
||||
n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2)
|
||||
nodes = jnp.concatenate((ar1, ar2), axis=0)
|
||||
sorted_nodes = nodes[n_sorted_indices]
|
||||
|
||||
if gene_type == 'node':
|
||||
node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 1:]), axis=1)
|
||||
else:
|
||||
node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 2:]), axis=1)
|
||||
|
||||
n_intersect_mask = n_intersect_mask & node_exist_mask
|
||||
n_union_mask = n_union_mask & node_exist_mask
|
||||
|
||||
fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:]
|
||||
|
||||
non_homologous_cnt = jnp.sum(n_union_mask) - jnp.sum(n_intersect_mask)
|
||||
@@ -48,9 +57,11 @@ def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
|
||||
|
||||
gene_cnt1 = jnp.sum(jnp.all(~jnp.isnan(ar1), axis=1))
|
||||
gene_cnt2 = jnp.sum(jnp.all(~jnp.isnan(ar2), axis=1))
|
||||
max_cnt = jnp.maximum(gene_cnt1, gene_cnt2)
|
||||
|
||||
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
|
||||
return val / jnp.where(gene_cnt1 > gene_cnt2, gene_cnt1, gene_cnt2)
|
||||
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt) # consider the case that both genome has no gene
|
||||
|
||||
|
||||
@partial(vmap, in_axes=(0, 0))
|
||||
|
||||
@@ -27,7 +27,7 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
|
||||
"""
|
||||
|
||||
if debug:
|
||||
cal_seqs = topological_sort(nodes, connections)
|
||||
cal_seqs = topological_sort_debug(nodes, connections)
|
||||
return lambda inputs: forward_single_debug(inputs, N, input_idx, output_idx,
|
||||
cal_seqs, nodes, connections)
|
||||
|
||||
|
||||
@@ -12,9 +12,10 @@ status.
|
||||
Empty nodes or connections are represented using np.nan.
|
||||
|
||||
"""
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Dict
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
from jax import numpy as jnp
|
||||
@@ -131,6 +132,83 @@ def expand_single(nodes: NDArray, connections: NDArray, new_N: int) -> Tuple[NDA
|
||||
|
||||
return new_nodes, new_connections
|
||||
|
||||
|
||||
def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \
|
||||
Tuple[Dict[int, Tuple[float, float, int, int]], Dict[Tuple[int, int], Tuple[float, bool]]]:
|
||||
"""
|
||||
Convert a genome from array to dict.
|
||||
:param nodes: (N, 5)
|
||||
:param connections: (2, N, N)
|
||||
:param output_keys:
|
||||
:param input_keys:
|
||||
:return: nodes_dict[key: (bias, response, act, agg)], connections_dict[(f_key, t_key): (weight, enabled)]
|
||||
"""
|
||||
# update nodes_dict
|
||||
try:
|
||||
nodes_dict = {}
|
||||
idx2key = {}
|
||||
for i, node in enumerate(nodes):
|
||||
if np.isnan(node[0]):
|
||||
continue
|
||||
key = int(node[0])
|
||||
assert key not in nodes_dict, f"Duplicate node key: {key}!"
|
||||
|
||||
bias = node[1] if not np.isnan(node[1]) else None
|
||||
response = node[2] if not np.isnan(node[2]) else None
|
||||
act = node[3] if not np.isnan(node[3]) else None
|
||||
agg = node[4] if not np.isnan(node[4]) else None
|
||||
nodes_dict[key] = (bias, response, act, agg)
|
||||
idx2key[i] = key
|
||||
|
||||
# check nodes_dict
|
||||
for i in input_keys:
|
||||
assert i in nodes_dict, f"Input node {i} not found in nodes_dict!"
|
||||
bias, response, act, agg = nodes_dict[i]
|
||||
assert bias is None and response is None and act is None and agg is None, \
|
||||
f"Input node {i} must has None bias, response, act, or agg!"
|
||||
|
||||
for o in output_keys:
|
||||
assert o in nodes_dict, f"Output node {o} not found in nodes_dict!"
|
||||
|
||||
for k, v in nodes_dict.items():
|
||||
if k not in input_keys:
|
||||
bias, response, act, agg = v
|
||||
assert bias is not None and response is not None and act is not None and agg is not None, \
|
||||
f"Normal node {k} must has non-None bias, response, act, or agg!"
|
||||
|
||||
# update connections
|
||||
connections_dict = {}
|
||||
for i in range(connections.shape[1]):
|
||||
for j in range(connections.shape[2]):
|
||||
if np.isnan(connections[0, i, j]) and np.isnan(connections[1, i, j]):
|
||||
continue
|
||||
assert i in idx2key, f"Node index {i} not found in idx2key:{idx2key}!"
|
||||
assert j in idx2key, f"Node index {j} not found in idx2key:{idx2key}!"
|
||||
key = (idx2key[i], idx2key[j])
|
||||
|
||||
weight = connections[0, i, j] if not np.isnan(connections[0, i, j]) else None
|
||||
enabled = (connections[1, i, j] == 1) if not np.isnan(connections[1, i, j]) else None
|
||||
|
||||
assert weight is not None, f"Connection {key} must has non-None weight!"
|
||||
assert enabled is not None, f"Connection {key} must has non-None enabled!"
|
||||
connections_dict[key] = (weight, enabled)
|
||||
|
||||
return nodes_dict, connections_dict
|
||||
except AssertionError:
|
||||
print(nodes)
|
||||
print(connections)
|
||||
raise AssertionError
|
||||
|
||||
|
||||
def pop_analysis(pop_nodes, pop_connections, input_keys, output_keys):
|
||||
pop_nodes, pop_connections = jax.device_get((pop_nodes, pop_connections))
|
||||
res = []
|
||||
for nodes, connections in zip(pop_nodes, pop_connections):
|
||||
res.append(analysis(nodes, connections, input_keys, output_keys))
|
||||
return res
|
||||
|
||||
|
||||
|
||||
@jit
|
||||
def add_node(new_node_key: int, nodes: Array, connections: Array,
|
||||
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]:
|
||||
@@ -158,11 +236,12 @@ def delete_node_by_idx(idx: int, nodes: Array, connections: Array) -> Tuple[Arra
|
||||
"""
|
||||
delete a node from the genome. only delete the node, regardless of connections.
|
||||
"""
|
||||
node_keys = nodes[:, 0]
|
||||
# node_keys = nodes[:, 0]
|
||||
nodes = nodes.at[idx].set(EMPTY_NODE)
|
||||
# move the last node to the deleted node's position
|
||||
last_real_idx = fetch_last(~jnp.isnan(node_keys))
|
||||
nodes = nodes.at[idx].set(nodes[last_real_idx])
|
||||
nodes = nodes.at[last_real_idx].set(EMPTY_NODE)
|
||||
# last_real_idx = fetch_last(~jnp.isnan(node_keys))
|
||||
# nodes = nodes.at[idx].set(nodes[last_real_idx])
|
||||
# nodes = nodes.at[last_real_idx].set(EMPTY_NODE)
|
||||
return nodes, connections
|
||||
|
||||
|
||||
@@ -206,7 +285,3 @@ def delete_connection_by_idx(from_idx: int, to_idx: int, nodes: Array, connectio
|
||||
"""
|
||||
connections = connections.at[:, from_idx, to_idx].set(np.nan)
|
||||
return nodes, connections
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# pop_nodes, pop_connections, input_keys, output_keys = initialize_genomes(100, 5, 2, 1)
|
||||
# print(pop_nodes, pop_connections)
|
||||
|
||||
@@ -148,7 +148,9 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
|
||||
check_cycles(nodes, connections, 0, 3) -> False
|
||||
check_cycles(nodes, connections, 1, 0) -> False
|
||||
"""
|
||||
connections_enable = connections[1, :, :] == 1
|
||||
# connections_enable = connections[0, :, :] == 1
|
||||
connections_enable = ~jnp.isnan(connections[0, :, :])
|
||||
|
||||
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
|
||||
nodes_visited = jnp.full(nodes.shape[0], False)
|
||||
nodes_visited = nodes_visited.at[to_idx].set(True)
|
||||
|
||||
Reference in New Issue
Block a user