Files
tensorneat-mend/algorithms/neat/genome/numpy/graph.py
2023-05-06 21:04:28 +08:00

164 lines
4.5 KiB
Python

"""
Some graph algorithms implemented in NumPy.
Only used in feed-forward networks.
"""
import numpy as np
from numpy.typing import NDArray
# from .utils import fetch_first, I_INT
from algorithms.neat.genome.utils import fetch_first, I_INT
def topological_sort(nodes: NDArray, connections: NDArray) -> NDArray:
    """
    Order the nodes of a feed-forward graph so each node comes after its inputs.

    Nodes whose remaining in-degree is zero are emitted one at a time; after a
    node is emitted, the in-degree of each of its successors is reduced by one.

    :param nodes: nodes array; rows whose first column is NaN get a NaN
        in-degree and are therefore never emitted
    :param connections: connections array; ``connections[1, i, j] == 1`` marks
        an enabled edge from node ``i`` to node ``j``
    :return: array with one entry per node row, holding the emitted node
        indices in order, with ``I_INT`` left in every slot that was not filled

    Example:
        nodes = np.array([
            [0],
            [1],
            [2],
            [3]
        ])
        connections = np.array([
            [
                [0, 0, 1, 0],
                [0, 0, 1, 1],
                [0, 0, 0, 1],
                [0, 0, 0, 0]
            ],
            [
                [0, 0, 1, 0],
                [0, 0, 1, 1],
                [0, 0, 0, 1],
                [0, 0, 0, 0]
            ]
        ])

        topological_sort(nodes, connections) -> [0, 1, 2, 3]
    """
    enabled = connections[1, :, :] == 1
    empty_slot = np.isnan(nodes[:, 0])
    # NaN never compares equal to zero, so empty slots are skipped by the search.
    pending_inputs = np.where(empty_slot, np.nan, np.sum(enabled, axis=0))
    order = np.full(pending_inputs.shape, I_INT)

    # The loop index doubles as the write position: one node is emitted per
    # iteration until the loop stops.
    for slot in range(pending_inputs.shape[0]):
        # presumably fetch_first returns the first True index, or I_INT when
        # there is none -- confirm against utils
        ready = fetch_first(pending_inputs == 0.)
        if ready == I_INT:
            break
        order[slot] = ready
        # A negative count keeps an emitted node from being selected again.
        pending_inputs[ready] = -1
        # Discount the emitted node's outgoing edges from its successors.
        pending_inputs[enabled[ready, :]] -= 1
    return order
def batch_topological_sort(pop_nodes: NDArray, pop_connections: NDArray) -> NDArray:
    """
    Apply topological_sort to every (nodes, connections) pair of a population.

    :param pop_nodes: iterable of nodes arrays, one per population member
    :param pop_connections: iterable of connections arrays, paired
        element-wise with ``pop_nodes``
    :return: the individual orderings stacked along a new leading axis
    """
    orderings = [
        topological_sort(member_nodes, member_connections)
        for member_nodes, member_connections in zip(pop_nodes, pop_connections)
    ]
    return np.stack(orderings, axis=0)
def check_cycles(nodes: NDArray, connections: NDArray, from_idx: NDArray, to_idx: NDArray) -> NDArray:
    """
    Check whether a new connection (from_idx -> to_idx) will cause a cycle.

    The candidate edge is added to a copy of the adjacency information, then
    the set of nodes reachable from ``to_idx`` is grown until it stops
    changing. The new edge closes a cycle exactly when ``from_idx`` ends up
    in that set.

    :param nodes: nodes array; only its number of rows is used
    :param connections: connections array; an edge ``i -> j`` exists wherever
        ``connections[0, i, j]`` is not NaN (any non-NaN value, including 0,
        counts as an edge)
    :param from_idx: index of the starting node of the new connection
    :param to_idx: index of the ending node of the new connection
    :return: boolean telling whether ``from_idx`` is reachable from ``to_idx``
        once the new connection is present, i.e. whether a cycle appears

    Example:
        nan = np.nan
        nodes = np.array([
            [0],
            [1],
            [2],
            [3]
        ])
        connections = np.array([
            [
                [nan, nan, 1, nan],
                [nan, nan, 1, 1],
                [nan, nan, nan, 1],
                [nan, nan, nan, nan]
            ],
            [
                [nan, nan, 1, nan],
                [nan, nan, 1, 1],
                [nan, nan, nan, 1],
                [nan, nan, nan, nan]
            ]
        ])

        check_cycles(nodes, connections, 3, 2) -> True
        check_cycles(nodes, connections, 2, 3) -> False
        check_cycles(nodes, connections, 0, 3) -> False
        check_cycles(nodes, connections, 1, 0) -> False
    """
    # np.isnan / ~ build a fresh boolean array, so adding the candidate edge
    # below does not modify the caller's connections.
    has_edge = ~np.isnan(connections[0, :, :])
    has_edge[from_idx, to_idx] = True

    reached = np.full(nodes.shape[0], False)
    reached[to_idx] = True

    # Each pass extends the reached set by one more hop; a path never needs
    # more hops than there are nodes, so this bound is sufficient.
    for _ in range(reached.shape[0]):
        expanded = np.logical_or(reached, np.dot(reached, has_edge))
        if np.array_equal(expanded, reached):
            # Fixpoint: further passes cannot add anything.
            break
        reached = expanded
    return reached[from_idx]
if __name__ == '__main__':
    # Manual smoke run: a five-slot graph whose last slot is empty (NaN node),
    # with edges 0->2, 1->2, 1->3 and 2->3.
    nan = np.nan
    nodes = np.array([[0], [1], [2], [3], [nan]])

    # Both connection layers hold the same values in this demo.
    layer = [
        [nan, nan, 1, nan, nan],
        [nan, nan, 1, 1, nan],
        [nan, nan, nan, 1, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
    ]
    connections = np.array([layer, layer])

    # The sort is printed twice on the same inputs.
    for _ in range(2):
        print(topological_sort(nodes, connections))

    for src, dst in ((3, 2), (2, 3), (0, 3), (1, 0)):
        print(check_cycles(nodes, connections, src, dst))