add sympy support; which can transfer your network into sympy expression;

add visualize in genome;
add related tests.
This commit is contained in:
wls2002
2024-06-12 21:36:35 +08:00
parent dfc8f9198e
commit b3e442c688
29 changed files with 6196 additions and 168 deletions

View File

@@ -1,17 +1,19 @@
from typing import Callable
import jax, jax.numpy as jnp
import sympy as sp
from utils import (
unflatten_conns,
topological_sort,
topological_sort_python,
I_INF,
extract_node_attrs,
extract_conn_attrs,
set_node_attrs,
set_conn_attrs,
attach_with_inf,
FUNCS_MODULE,
)
from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover
@@ -188,3 +190,56 @@ class DefaultGenome(BaseGenome):
jax.vmap(self.output_transform)(batch_vals[:, self.output_idx]),
new_transformed,
)
def sympy_func(self, state, network, precision=3):
input_idx = self.get_input_idx()
output_idx = self.get_output_idx()
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
symbols = {}
for i in network["nodes"]:
if i in input_idx:
symbols[i] = sp.Symbol(f"i{i}")
elif i in output_idx:
symbols[i] = sp.Symbol(f"o{i}")
else: # hidden
symbols[i] = sp.Symbol(f"h{i}")
nodes_exprs = {}
for i in order:
if i in input_idx:
nodes_exprs[symbols[i]] = symbols[i]
else:
in_conns = [c for c in network["conns"] if c[1] == i]
node_inputs = []
for conn in in_conns:
val_represent = symbols[conn[0]]
val = self.conn_gene.sympy_func(
state,
network["conns"][conn],
val_represent,
precision=precision,
)
node_inputs.append(val)
nodes_exprs[symbols[i]] = self.node_gene.sympy_func(
state,
network["nodes"][i],
node_inputs,
is_output_node=(i in output_idx),
precision=precision,
)
input_symbols = [v for k, v in symbols.items() if k in input_idx]
reduced_exprs = nodes_exprs.copy()
for i in order:
reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs)
output_exprs = [reduced_exprs[symbols[i]] for i in output_idx]
lambdify_output_funcs = [
sp.lambdify(input_symbols, exprs, modules=["numpy", FUNCS_MODULE])
for exprs in output_exprs
]
forward_func = lambda inputs: [f(*inputs) for f in lambdify_output_funcs]
return symbols, input_symbols, nodes_exprs, output_exprs, forward_func