164 lines
4.5 KiB
Python
164 lines
4.5 KiB
Python
"""
|
|
Some graph algorithms implemented in jax.
|
|
Only used in feed-forward networks.
|
|
"""
|
|
|
|
import numpy as np
|
|
from numpy.typing import NDArray
|
|
|
|
# from .utils import fetch_first, I_INT
|
|
from algorithms.neat.genome.utils import fetch_first, I_INT
|
|
|
|
|
|
def topological_sort(nodes: NDArray, connections: NDArray) -> NDArray:
|
|
"""
|
|
a jit-able version of topological_sort! that's crazy!
|
|
:param nodes: nodes array
|
|
:param connections: connections array
|
|
:return: topological sorted sequence
|
|
|
|
Example:
|
|
nodes = np.array([
|
|
[0],
|
|
[1],
|
|
[2],
|
|
[3]
|
|
])
|
|
connections = np.array([
|
|
[
|
|
[0, 0, 1, 0],
|
|
[0, 0, 1, 1],
|
|
[0, 0, 0, 1],
|
|
[0, 0, 0, 0]
|
|
],
|
|
[
|
|
[0, 0, 1, 0],
|
|
[0, 0, 1, 1],
|
|
[0, 0, 0, 1],
|
|
[0, 0, 0, 0]
|
|
]
|
|
])
|
|
|
|
topological_sort(nodes, connections) -> [0, 1, 2, 3]
|
|
"""
|
|
connections_enable = connections[1, :, :] == 1
|
|
in_degree = np.where(np.isnan(nodes[:, 0]), np.nan, np.sum(connections_enable, axis=0))
|
|
res = np.full(in_degree.shape, I_INT)
|
|
idx = 0
|
|
|
|
for _ in range(in_degree.shape[0]):
|
|
i = fetch_first(in_degree == 0.)
|
|
if i == I_INT:
|
|
break
|
|
res[idx] = i
|
|
idx += 1
|
|
in_degree[i] = -1
|
|
children = connections_enable[i, :]
|
|
in_degree = np.where(children, in_degree - 1, in_degree)
|
|
|
|
return res
|
|
|
|
|
|
def batch_topological_sort(pop_nodes: NDArray, pop_connections: NDArray) -> NDArray:
|
|
"""
|
|
batch version of topological_sort
|
|
:param pop_nodes:
|
|
:param pop_connections:
|
|
:return:
|
|
"""
|
|
res = []
|
|
for nodes, connections in zip(pop_nodes, pop_connections):
|
|
seq = topological_sort(nodes, connections)
|
|
res.append(seq)
|
|
return np.stack(res, axis=0)
|
|
|
|
|
|
def check_cycles(nodes: NDArray, connections: NDArray, from_idx: NDArray, to_idx: NDArray) -> NDArray:
|
|
"""
|
|
Check whether a new connection (from_idx -> to_idx) will cause a cycle.
|
|
|
|
:param nodes: JAX array
|
|
The array of nodes.
|
|
:param connections: JAX array
|
|
The array of connections.
|
|
:param from_idx: int
|
|
The index of the starting node.
|
|
:param to_idx: int
|
|
The index of the ending node.
|
|
:return: JAX array
|
|
An array indicating if there is a cycle caused by the new connection.
|
|
|
|
Example:
|
|
nodes = np.array([
|
|
[0],
|
|
[1],
|
|
[2],
|
|
[3]
|
|
])
|
|
connections = np.array([
|
|
[
|
|
[0, 0, 1, 0],
|
|
[0, 0, 1, 1],
|
|
[0, 0, 0, 1],
|
|
[0, 0, 0, 0]
|
|
],
|
|
[
|
|
[0, 0, 1, 0],
|
|
[0, 0, 1, 1],
|
|
[0, 0, 0, 1],
|
|
[0, 0, 0, 0]
|
|
]
|
|
])
|
|
|
|
check_cycles(nodes, connections, 3, 2) -> True
|
|
check_cycles(nodes, connections, 2, 3) -> False
|
|
check_cycles(nodes, connections, 0, 3) -> False
|
|
check_cycles(nodes, connections, 1, 0) -> False
|
|
"""
|
|
connections_enable = ~np.isnan(connections[0, :, :])
|
|
|
|
connections_enable[from_idx, to_idx] = True
|
|
nodes_visited = np.full(nodes.shape[0], False)
|
|
nodes_visited[to_idx] = True
|
|
|
|
for _ in range(nodes_visited.shape[0]):
|
|
new_visited = np.dot(nodes_visited, connections_enable)
|
|
nodes_visited = np.logical_or(nodes_visited, new_visited)
|
|
|
|
return nodes_visited[from_idx]
|
|
|
|
|
|
if __name__ == '__main__':
|
|
nodes = np.array([
|
|
[0],
|
|
[1],
|
|
[2],
|
|
[3],
|
|
[np.nan]
|
|
])
|
|
connections = np.array([
|
|
[
|
|
[np.nan, np.nan, 1, np.nan, np.nan],
|
|
[np.nan, np.nan, 1, 1, np.nan],
|
|
[np.nan, np.nan, np.nan, 1, np.nan],
|
|
[np.nan, np.nan, np.nan, np.nan, np.nan],
|
|
[np.nan, np.nan, np.nan, np.nan, np.nan]
|
|
],
|
|
[
|
|
[np.nan, np.nan, 1, np.nan, np.nan],
|
|
[np.nan, np.nan, 1, 1, np.nan],
|
|
[np.nan, np.nan, np.nan, 1, np.nan],
|
|
[np.nan, np.nan, np.nan, np.nan, np.nan],
|
|
[np.nan, np.nan, np.nan, np.nan, np.nan]
|
|
]
|
|
]
|
|
)
|
|
|
|
print(topological_sort(nodes, connections))
|
|
print(topological_sort(nodes, connections))
|
|
|
|
print(check_cycles(nodes, connections, 3, 2))
|
|
print(check_cycles(nodes, connections, 2, 3))
|
|
print(check_cycles(nodes, connections, 0, 3))
|
|
print(check_cycles(nodes, connections, 1, 0))
|