update a lot, take a break

This commit is contained in:
root
2024-07-12 07:47:33 +08:00
parent 58c56ab2ab
commit 99b8f7fd90
11 changed files with 2161 additions and 2418 deletions

View File

@@ -10,7 +10,7 @@ from tensorneat.common import (
StatefulBaseClass,
hash_array,
)
from .utils import valid_cnt
from .utils import valid_cnt, re_cound_idx
class BaseGenome(StatefulBaseClass):
@@ -160,7 +160,11 @@ class BaseGenome(StatefulBaseClass):
return nodes, conns
def network_dict(self, state, nodes, conns):
def network_dict(self, state, nodes, conns, whether_re_cound_idx=True):
if whether_re_cound_idx:
nodes, conns = re_cound_idx(
nodes, conns, self.get_input_idx(), self.get_output_idx()
)
return {
"nodes": self._get_node_dict(state, nodes),
"conns": self._get_conn_dict(state, conns),

View File

@@ -209,7 +209,6 @@ class DefaultNode(BaseNode):
bias = sp.symbols(f"n_{nd['idx']}_b")
res = sp.symbols(f"n_{nd['idx']}_r")
print(nd["agg"])
z = AGG.obtain_sympy(nd["agg"])(inputs)
z = bias + res * z

View File

@@ -1,5 +1,6 @@
import jax
from jax import vmap, numpy as jnp
import numpy as np
from tensorneat.common import fetch_first, I_INF
@@ -107,3 +108,33 @@ def delete_conn_by_pos(conns, pos):
Delete the connection by its idx.
"""
return conns.at[pos].set(jnp.nan)
def re_cound_idx(nodes, conns, input_idx, output_idx):
"""
Make the key of hidden nodes continuous.
Also update the index of connections.
"""
nodes, conns = jax.device_get((nodes, conns))
next_key = max(*input_idx, *output_idx) + 1
old2new = {}
for i, key in enumerate(nodes[:, 0]):
if np.isnan(key):
continue
if np.in1d(key, input_idx + output_idx):
continue
old2new[int(key)] = next_key
next_key += 1
new_nodes = nodes.copy()
for i, key in enumerate(nodes[:, 0]):
if (not np.isnan(key)) and int(key) in old2new:
new_nodes[i, 0] = old2new[int(key)]
new_conns = conns.copy()
for i, (i_key, o_key) in enumerate(conns[:, :2]):
if (not np.isnan(i_key)) and int(i_key) in old2new:
new_conns[i, 0] = old2new[int(i_key)]
if (not np.isnan(o_key)) and int(o_key) in old2new:
new_conns[i, 1] = old2new[int(o_key)]
return new_nodes, new_conns