update a lot, take a break
This commit is contained in:
91
src/tensorneat/common/sympy_tools.py
Normal file
91
src/tensorneat/common/sympy_tools.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import re
|
||||
import sympy as sp
|
||||
|
||||
def analysis_nodes_exprs(nodes_exprs):
|
||||
input_cnt, hidden_cnt, output_cnt = 0, 0, 0
|
||||
norm_symbols = {}
|
||||
for key in nodes_exprs.keys():
|
||||
if str(key).startswith('i'):
|
||||
input_cnt += 1
|
||||
elif str(key).startswith('h'):
|
||||
hidden_cnt += 1
|
||||
elif str(key).startswith('o'):
|
||||
output_cnt += 1
|
||||
elif str(key).startswith('norm'):
|
||||
norm_symbols[key] = nodes_exprs[key]
|
||||
return input_cnt, hidden_cnt, output_cnt, norm_symbols
|
||||
|
||||
def round_expr(expr, precision=2):
|
||||
"""
|
||||
Round numerical values in a sympy expression to a given precision.
|
||||
"""
|
||||
return expr.xreplace({n: round(n, precision) for n in expr.atoms(sp.Number)})
|
||||
|
||||
|
||||
def replace_variable_names(expression):
|
||||
"""
|
||||
Transform sympy expression to a string with array index that can be used in python code.
|
||||
For example, `o0` will be transformed to `o[0]`.
|
||||
"""
|
||||
expression_str = str(expression)
|
||||
expression_str = re.sub(r"\bo(\d+)\b", r"o[\1]", expression_str)
|
||||
expression_str = re.sub(r"\bh(\d+)\b", r"h[\1]", expression_str)
|
||||
expression_str = re.sub(r"\bi(\d+)\b", r"i[\1]", expression_str)
|
||||
return expression_str
|
||||
|
||||
|
||||
def to_latex_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, use_hidden_nodes=True):
|
||||
input_cnt, hidden_cnt, output_cnt, norm_symbols = analysis_nodes_exprs(nodes_exprs)
|
||||
res = "\\begin{align}\n"
|
||||
|
||||
if not use_hidden_nodes:
|
||||
for i in range(output_cnt):
|
||||
expr = output_exprs[i].subs(args_symbols)
|
||||
rounded_expr = round_expr(expr, 2)
|
||||
latex_expr = f"o_{{{sp.latex(i)}}} &= {sp.latex(rounded_expr)}\\newline\n"
|
||||
res += latex_expr
|
||||
else:
|
||||
for i in range(hidden_cnt):
|
||||
symbol = sp.symbols(f"h{i}")
|
||||
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
|
||||
rounded_expr = round_expr(expr, 2)
|
||||
latex_expr = f"h_{{{sp.latex(i)}}} &= {sp.latex(rounded_expr)}\\newline\n"
|
||||
res += latex_expr
|
||||
for i in range(output_cnt):
|
||||
symbol = sp.symbols(f"o{i}")
|
||||
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
|
||||
rounded_expr = round_expr(expr, 2)
|
||||
latex_expr = f"o_{{{sp.latex(i)}}} &= {sp.latex(rounded_expr)}\\newline\n"
|
||||
res += latex_expr
|
||||
res += "\\end{align}\n"
|
||||
return res
|
||||
|
||||
|
||||
def to_python_code(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, use_hidden_nodes=True):
|
||||
input_cnt, hidden_cnt, output_cnt, norm_symbols = analysis_nodes_exprs(nodes_exprs)
|
||||
res = ""
|
||||
if not use_hidden_nodes:
|
||||
# pre-allocate space
|
||||
res += f"o = np.zeros({output_cnt})\n"
|
||||
for i in range(output_cnt):
|
||||
expr = output_exprs[i].subs(args_symbols)
|
||||
rounded_expr = round_expr(expr, 6)
|
||||
str_expr = f"o{i} = {rounded_expr}"
|
||||
res += replace_variable_names(str_expr) + "\n"
|
||||
else:
|
||||
# pre-allocate space
|
||||
res += f"h = np.zeros({hidden_cnt})\n"
|
||||
res += f"o = np.zeros({output_cnt})\n"
|
||||
for i in range(hidden_cnt):
|
||||
symbol = sp.symbols(f"h{i}")
|
||||
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
|
||||
rounded_expr = round_expr(expr, 6)
|
||||
str_expr = f"h{i} = {rounded_expr}"
|
||||
res += replace_variable_names(str_expr) + "\n"
|
||||
for i in range(output_cnt):
|
||||
symbol = sp.symbols(f"o{i}")
|
||||
expr = nodes_exprs[symbol].subs(args_symbols).subs(norm_symbols)
|
||||
rounded_expr = round_expr(expr, 6)
|
||||
str_expr = f"o{i} = {rounded_expr}"
|
||||
res += replace_variable_names(str_expr) + "\n"
|
||||
return res
|
||||
@@ -10,7 +10,7 @@ from tensorneat.common import (
|
||||
StatefulBaseClass,
|
||||
hash_array,
|
||||
)
|
||||
from .utils import valid_cnt
|
||||
from .utils import valid_cnt, re_cound_idx
|
||||
|
||||
|
||||
class BaseGenome(StatefulBaseClass):
|
||||
@@ -160,7 +160,11 @@ class BaseGenome(StatefulBaseClass):
|
||||
|
||||
return nodes, conns
|
||||
|
||||
def network_dict(self, state, nodes, conns):
|
||||
def network_dict(self, state, nodes, conns, whether_re_cound_idx=True):
|
||||
if whether_re_cound_idx:
|
||||
nodes, conns = re_cound_idx(
|
||||
nodes, conns, self.get_input_idx(), self.get_output_idx()
|
||||
)
|
||||
return {
|
||||
"nodes": self._get_node_dict(state, nodes),
|
||||
"conns": self._get_conn_dict(state, conns),
|
||||
|
||||
@@ -209,7 +209,6 @@ class DefaultNode(BaseNode):
|
||||
bias = sp.symbols(f"n_{nd['idx']}_b")
|
||||
res = sp.symbols(f"n_{nd['idx']}_r")
|
||||
|
||||
print(nd["agg"])
|
||||
z = AGG.obtain_sympy(nd["agg"])(inputs)
|
||||
z = bias + res * z
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from tensorneat.common import fetch_first, I_INF
|
||||
|
||||
@@ -107,3 +108,33 @@ def delete_conn_by_pos(conns, pos):
|
||||
Delete the connection by its idx.
|
||||
"""
|
||||
return conns.at[pos].set(jnp.nan)
|
||||
|
||||
|
||||
def re_cound_idx(nodes, conns, input_idx, output_idx):
|
||||
"""
|
||||
Make the key of hidden nodes continuous.
|
||||
Also update the index of connections.
|
||||
"""
|
||||
nodes, conns = jax.device_get((nodes, conns))
|
||||
next_key = max(*input_idx, *output_idx) + 1
|
||||
old2new = {}
|
||||
for i, key in enumerate(nodes[:, 0]):
|
||||
if np.isnan(key):
|
||||
continue
|
||||
if np.in1d(key, input_idx + output_idx):
|
||||
continue
|
||||
old2new[int(key)] = next_key
|
||||
next_key += 1
|
||||
|
||||
new_nodes = nodes.copy()
|
||||
for i, key in enumerate(nodes[:, 0]):
|
||||
if (not np.isnan(key)) and int(key) in old2new:
|
||||
new_nodes[i, 0] = old2new[int(key)]
|
||||
|
||||
new_conns = conns.copy()
|
||||
for i, (i_key, o_key) in enumerate(conns[:, :2]):
|
||||
if (not np.isnan(i_key)) and int(i_key) in old2new:
|
||||
new_conns[i, 0] = old2new[int(i_key)]
|
||||
if (not np.isnan(o_key)) and int(o_key) in old2new:
|
||||
new_conns[i, 1] = old2new[int(o_key)]
|
||||
return new_nodes, new_conns
|
||||
|
||||
Reference in New Issue
Block a user