add gene type RNN
This commit is contained in:
@@ -11,7 +11,7 @@ from ..utils import fetch_first
|
||||
|
||||
def initialize_genomes(state: State, gene_type: Type[BaseGene]):
|
||||
o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes
|
||||
o_conns = np.full((state.N, state.CL), np.nan, dtype=np.float32) # original connections
|
||||
o_conns = np.full((state.C, state.CL), np.nan, dtype=np.float32) # original connections
|
||||
|
||||
input_idx = state.input_idx
|
||||
output_idx = state.output_idx
|
||||
|
||||
@@ -9,6 +9,7 @@ from jax import jit, Array, numpy as jnp
|
||||
from ..utils import fetch_first, I_INT
|
||||
|
||||
|
||||
@jit
|
||||
def topological_sort(nodes: Array, conns: Array) -> Array:
|
||||
"""
|
||||
a jit-able version of topological_sort! that's crazy!
|
||||
@@ -60,21 +61,11 @@ def topological_sort(nodes: Array, conns: Array) -> Array:
|
||||
return res
|
||||
|
||||
|
||||
def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array:
|
||||
@jit
|
||||
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
|
||||
"""
|
||||
Check whether a new connection (from_idx -> to_idx) will cause a cycle.
|
||||
|
||||
:param nodes: JAX array
|
||||
The array of nodes.
|
||||
:param connections: JAX array
|
||||
The array of connections.
|
||||
:param from_idx: int
|
||||
The index of the starting node.
|
||||
:param to_idx: int
|
||||
The index of the ending node.
|
||||
:return: JAX array
|
||||
An array indicating if there is a cycle caused by the new connection.
|
||||
|
||||
Example:
|
||||
nodes = jnp.array([
|
||||
[0],
|
||||
@@ -83,28 +74,21 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
|
||||
[3]
|
||||
])
|
||||
connections = jnp.array([
|
||||
[
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 1, 1],
|
||||
[0, 0, 0, 1],
|
||||
[0, 0, 0, 0]
|
||||
],
|
||||
[
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 1, 1],
|
||||
[0, 0, 0, 1],
|
||||
[0, 0, 0, 0]
|
||||
]
|
||||
])
|
||||
|
||||
check_cycles(nodes, connections, 3, 2) -> True
|
||||
check_cycles(nodes, connections, 2, 3) -> False
|
||||
check_cycles(nodes, connections, 0, 3) -> False
|
||||
check_cycles(nodes, connections, 1, 0) -> False
|
||||
check_cycles(nodes, conns, 3, 2) -> True
|
||||
check_cycles(nodes, conns, 2, 3) -> False
|
||||
check_cycles(nodes, conns, 0, 3) -> False
|
||||
check_cycles(nodes, conns, 1, 0) -> False
|
||||
"""
|
||||
|
||||
connections_enable = ~jnp.isnan(connections[0, :, :])
|
||||
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
|
||||
conns = conns.at[from_idx, to_idx].set(True)
|
||||
# conns_enable = ~jnp.isnan(conns[0, :, :])
|
||||
# conns_enable = conns_enable.at[from_idx, to_idx].set(True)
|
||||
|
||||
visited = jnp.full(nodes.shape[0], False)
|
||||
new_visited = visited.at[to_idx].set(True)
|
||||
@@ -117,43 +101,42 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
|
||||
|
||||
def body_func(carry):
|
||||
_, visited_ = carry
|
||||
new_visited_ = jnp.dot(visited_, connections_enable)
|
||||
new_visited_ = jnp.dot(visited_, conns)
|
||||
new_visited_ = jnp.logical_or(visited_, new_visited_)
|
||||
return visited_, new_visited_
|
||||
|
||||
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
|
||||
return visited[from_idx]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
nodes = jnp.array([
|
||||
[0],
|
||||
[1],
|
||||
[2],
|
||||
[3],
|
||||
[jnp.nan]
|
||||
])
|
||||
connections = jnp.array([
|
||||
[
|
||||
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
|
||||
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
||||
],
|
||||
[
|
||||
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
|
||||
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
print(topological_sort(nodes, connections))
|
||||
|
||||
print(check_cycles(nodes, connections, 3, 2))
|
||||
print(check_cycles(nodes, connections, 2, 3))
|
||||
print(check_cycles(nodes, connections, 0, 3))
|
||||
print(check_cycles(nodes, connections, 1, 0))
|
||||
# if __name__ == '__main__':
|
||||
# nodes = jnp.array([
|
||||
# [0],
|
||||
# [1],
|
||||
# [2],
|
||||
# [3],
|
||||
# [jnp.nan]
|
||||
# ])
|
||||
# connections = jnp.array([
|
||||
# [
|
||||
# [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
|
||||
# [jnp.nan, jnp.nan, 1, 1, jnp.nan],
|
||||
# [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
|
||||
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
|
||||
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
||||
# ],
|
||||
# [
|
||||
# [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
|
||||
# [jnp.nan, jnp.nan, 1, 1, jnp.nan],
|
||||
# [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
|
||||
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
|
||||
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
||||
# ]
|
||||
# ]
|
||||
# )
|
||||
#
|
||||
# print(topological_sort(nodes, connections))
|
||||
#
|
||||
# print(check_cycles(nodes, connections, 3, 2))
|
||||
# print(check_cycles(nodes, connections, 2, 3))
|
||||
# print(check_cycles(nodes, connections, 0, 3))
|
||||
# print(check_cycles(nodes, connections, 1, 0))
|
||||
|
||||
@@ -91,7 +91,8 @@ def create_mutate(config: Dict, gene_type: Type[BaseGene]):
|
||||
|
||||
if config['network_type'] == 'feedforward':
|
||||
u_cons = unflatten_connections(nodes_, conns_)
|
||||
is_cycle = check_cycles(nodes_, u_cons, from_idx, to_idx)
|
||||
cons_exist = jnp.where(~jnp.isnan(u_cons[0, :, :]), True, False)
|
||||
is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx)
|
||||
|
||||
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
|
||||
return jax.lax.switch(choice, [already_exist, nothing, successful])
|
||||
|
||||
Reference in New Issue
Block a user