add gene type RNN

This commit is contained in:
wls2002
2023-07-19 15:43:49 +08:00
parent 0a2a9fd1be
commit a684e6584d
18 changed files with 248 additions and 129 deletions

View File

@@ -9,6 +9,7 @@ from jax import jit, Array, numpy as jnp
from ..utils import fetch_first, I_INT
@jit
def topological_sort(nodes: Array, conns: Array) -> Array:
"""
a jit-able version of topological_sort! that's crazy!
@@ -60,21 +61,11 @@ def topological_sort(nodes: Array, conns: Array) -> Array:
return res
def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array:
@jit
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
"""
Check whether a new connection (from_idx -> to_idx) will cause a cycle.
:param nodes: JAX array
The array of nodes.
:param connections: JAX array
The array of connections.
:param from_idx: int
The index of the starting node.
:param to_idx: int
The index of the ending node.
:return: JAX array
An array indicating if there is a cycle caused by the new connection.
Example:
nodes = jnp.array([
[0],
@@ -83,28 +74,21 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
[3]
])
connections = jnp.array([
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
],
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
]
])
check_cycles(nodes, connections, 3, 2) -> True
check_cycles(nodes, connections, 2, 3) -> False
check_cycles(nodes, connections, 0, 3) -> False
check_cycles(nodes, connections, 1, 0) -> False
check_cycles(nodes, conns, 3, 2) -> True
check_cycles(nodes, conns, 2, 3) -> False
check_cycles(nodes, conns, 0, 3) -> False
check_cycles(nodes, conns, 1, 0) -> False
"""
connections_enable = ~jnp.isnan(connections[0, :, :])
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
conns = conns.at[from_idx, to_idx].set(True)
# conns_enable = ~jnp.isnan(conns[0, :, :])
# conns_enable = conns_enable.at[from_idx, to_idx].set(True)
visited = jnp.full(nodes.shape[0], False)
new_visited = visited.at[to_idx].set(True)
@@ -117,43 +101,42 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
def body_func(carry):
_, visited_ = carry
new_visited_ = jnp.dot(visited_, connections_enable)
new_visited_ = jnp.dot(visited_, conns)
new_visited_ = jnp.logical_or(visited_, new_visited_)
return visited_, new_visited_
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
return visited[from_idx]
if __name__ == '__main__':
nodes = jnp.array([
[0],
[1],
[2],
[3],
[jnp.nan]
])
connections = jnp.array([
[
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
],
[
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
]
]
)
print(topological_sort(nodes, connections))
print(check_cycles(nodes, connections, 3, 2))
print(check_cycles(nodes, connections, 2, 3))
print(check_cycles(nodes, connections, 0, 3))
print(check_cycles(nodes, connections, 1, 0))
# if __name__ == '__main__':
# nodes = jnp.array([
# [0],
# [1],
# [2],
# [3],
# [jnp.nan]
# ])
# connections = jnp.array([
# [
# [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
# [jnp.nan, jnp.nan, 1, 1, jnp.nan],
# [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
# ],
# [
# [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
# [jnp.nan, jnp.nan, 1, 1, jnp.nan],
# [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
# ]
# ]
# )
#
# print(topological_sort(nodes, connections))
#
# print(check_cycles(nodes, connections, 3, 2))
# print(check_cycles(nodes, connections, 2, 3))
# print(check_cycles(nodes, connections, 0, 3))
# print(check_cycles(nodes, connections, 1, 0))