finish ask part of the algorithm;

use jax.lax.while_loop in graph algorithms and forward function;
fix "enabled not care" bug in forward
This commit is contained in:
wls2002
2023-06-25 00:26:52 +08:00
parent 86820db5a6
commit 0cb2f9473d
24 changed files with 485 additions and 1623 deletions

View File

@@ -8,8 +8,7 @@ from jax import jit, vmap, Array
from jax import numpy as jnp
# from .configs import fetch_first, I_INT
from neat.genome.utils import fetch_first, I_INT
from .utils import unflatten_connections
from neat.genome.utils import fetch_first, I_INT, unflatten_connections
@jit
@@ -44,49 +43,32 @@ def topological_sort(nodes: Array, connections: Array) -> Array:
topological_sort(nodes, connections) -> [0, 1, 2, 3]
"""
connections_enable = connections[1, :, :] == 1
connections_enable = connections[1, :, :] == 1 # forward function. thus use enable
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0))
res = jnp.full(in_degree.shape, I_INT)
idx = 0
def scan_body(carry, _):
def cond_fun(carry):
res_, idx_, in_degree_ = carry
i = fetch_first(in_degree_ == 0.)
return i != I_INT
def body_func(carry):
res_, idx_, in_degree_ = carry
i = fetch_first(in_degree_ == 0.)
def hit():
# add to res and flag it is already in it
new_res = res_.at[idx_].set(i)
new_idx = idx_ + 1
new_in_degree = in_degree_.at[i].set(-1)
# add to res and flag it is already in it
res_ = res_.at[idx_].set(i)
in_degree_ = in_degree_.at[i].set(-1)
# decrease in_degree of all its children
children = connections_enable[i, :]
new_in_degree = jnp.where(children, new_in_degree - 1, new_in_degree)
return new_res, new_idx, new_in_degree
def miss():
return res_, idx_, in_degree_
return jax.lax.cond(i == I_INT, miss, hit), None
scan_res, _ = jax.lax.scan(scan_body, (res, idx, in_degree), None, length=in_degree.shape[0])
res, _, _ = scan_res
# decrease in_degree of all its children
children = connections_enable[i, :]
in_degree_ = jnp.where(children, in_degree_ - 1, in_degree_)
return res_, idx_ + 1, in_degree_
res, _, _ = jax.lax.while_loop(cond_fun, body_func, (res, 0, in_degree))
return res
@jit
@vmap
def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
"""
batch version of topological_sort
:param pop_nodes:
:param pop_connections:
:return:
"""
return topological_sort(pop_nodes, pop_connections)
@jit
def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array:
"""
@@ -131,22 +113,26 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
check_cycles(nodes, connections, 1, 0) -> False
"""
connections = unflatten_connections(nodes, connections)
connections_enable = ~jnp.isnan(connections[0, :, :])
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
nodes_visited = jnp.full(nodes.shape[0], False)
nodes_visited = nodes_visited.at[to_idx].set(True)
def scan_body(visited, _):
new_visited = jnp.dot(visited, connections_enable)
new_visited = jnp.logical_or(visited, new_visited)
return new_visited, None
visited = jnp.full(nodes.shape[0], False)
new_visited = visited.at[to_idx].set(True)
nodes_visited, _ = jax.lax.scan(scan_body, nodes_visited, None, length=nodes_visited.shape[0])
def cond_func(carry):
visited_, new_visited_ = carry
end_cond1 = jnp.all(visited_ == new_visited_) # no new nodes been visited
end_cond2 = new_visited_[from_idx] # the starting node has been visited
return jnp.logical_not(end_cond1 | end_cond2)
return nodes_visited[from_idx]
def body_func(carry):
_, visited_ = carry
new_visited_ = jnp.dot(visited_, connections_enable)
new_visited_ = jnp.logical_or(visited_, new_visited_)
return visited_, new_visited_
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
return visited[from_idx]
if __name__ == '__main__':