Try to reduce difference between sympy formula and network.

Spend a whole night. Failed;
I'll never try it anymore.
This commit is contained in:
wls2002
2024-06-13 07:07:43 +08:00
parent 69d73aab73
commit aac9f4c3fb
7 changed files with 313 additions and 87 deletions

View File

@@ -82,7 +82,7 @@ class DefaultConnGene(BaseConnGene):
return {
"in": int(conn[0]),
"out": int(conn[1]),
"weight": np.array(conn[2], dtype=np.float32),
"weight": jnp.float32(conn[2]),
}
def sympy_func(self, state, conn_dict, inputs, precision=None):

View File

@@ -163,8 +163,8 @@ class DefaultNodeGene(BaseNodeGene):
idx, bias, res, agg, act = node
idx = int(idx)
bias = np.array(bias, dtype=np.float32)
res = np.array(res, dtype=np.float32)
bias = jnp.float32(bias)
res = jnp.float32(res)
agg = int(agg)
act = int(act)
@@ -186,7 +186,7 @@ class DefaultNodeGene(BaseNodeGene):
res = sp.symbols(f"n_{nd['idx']}_r")
z = convert_to_sympy(nd["agg"])(inputs)
z = bias + z * res
z = bias + res * z
if is_output_node:
pass

View File

@@ -137,7 +137,7 @@ class NodeGeneWithoutResponse(BaseNodeGene):
idx = int(idx)
bias = np.array(bias, dtype=np.float32)
bias = jnp.float32(bias)
agg = int(agg)
act = int(act)

View File

@@ -216,7 +216,6 @@ class DefaultGenome(BaseGenome):
symbols[i] = sp.Symbol(f"h{i}")
nodes_exprs = {}
args_symbols = {}
for i in order: