update a lot, take a break
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from tensorneat.common import fetch_first, I_INF
|
||||
|
||||
@@ -107,3 +108,33 @@ def delete_conn_by_pos(conns, pos):
|
||||
Delete the connection by its idx.
|
||||
"""
|
||||
return conns.at[pos].set(jnp.nan)
|
||||
|
||||
|
||||
def re_cound_idx(nodes, conns, input_idx, output_idx):
|
||||
"""
|
||||
Make the key of hidden nodes continuous.
|
||||
Also update the index of connections.
|
||||
"""
|
||||
nodes, conns = jax.device_get((nodes, conns))
|
||||
next_key = max(*input_idx, *output_idx) + 1
|
||||
old2new = {}
|
||||
for i, key in enumerate(nodes[:, 0]):
|
||||
if np.isnan(key):
|
||||
continue
|
||||
if np.in1d(key, input_idx + output_idx):
|
||||
continue
|
||||
old2new[int(key)] = next_key
|
||||
next_key += 1
|
||||
|
||||
new_nodes = nodes.copy()
|
||||
for i, key in enumerate(nodes[:, 0]):
|
||||
if (not np.isnan(key)) and int(key) in old2new:
|
||||
new_nodes[i, 0] = old2new[int(key)]
|
||||
|
||||
new_conns = conns.copy()
|
||||
for i, (i_key, o_key) in enumerate(conns[:, :2]):
|
||||
if (not np.isnan(i_key)) and int(i_key) in old2new:
|
||||
new_conns[i, 0] = old2new[int(i_key)]
|
||||
if (not np.isnan(o_key)) and int(o_key) in old2new:
|
||||
new_conns[i, 1] = old2new[int(o_key)]
|
||||
return new_nodes, new_conns
|
||||
|
||||
Reference in New Issue
Block a user