Try to reduce difference between sympy formula and network.
Spend a whole night. Failed; I'll never try it anymore.
This commit is contained in:
@@ -82,7 +82,7 @@ class DefaultConnGene(BaseConnGene):
|
|||||||
return {
|
return {
|
||||||
"in": int(conn[0]),
|
"in": int(conn[0]),
|
||||||
"out": int(conn[1]),
|
"out": int(conn[1]),
|
||||||
"weight": np.array(conn[2], dtype=np.float32),
|
"weight": jnp.float32(conn[2]),
|
||||||
}
|
}
|
||||||
|
|
||||||
def sympy_func(self, state, conn_dict, inputs, precision=None):
|
def sympy_func(self, state, conn_dict, inputs, precision=None):
|
||||||
|
|||||||
@@ -163,8 +163,8 @@ class DefaultNodeGene(BaseNodeGene):
|
|||||||
idx, bias, res, agg, act = node
|
idx, bias, res, agg, act = node
|
||||||
|
|
||||||
idx = int(idx)
|
idx = int(idx)
|
||||||
bias = np.array(bias, dtype=np.float32)
|
bias = jnp.float32(bias)
|
||||||
res = np.array(res, dtype=np.float32)
|
res = jnp.float32(res)
|
||||||
agg = int(agg)
|
agg = int(agg)
|
||||||
act = int(act)
|
act = int(act)
|
||||||
|
|
||||||
@@ -186,7 +186,7 @@ class DefaultNodeGene(BaseNodeGene):
|
|||||||
res = sp.symbols(f"n_{nd['idx']}_r")
|
res = sp.symbols(f"n_{nd['idx']}_r")
|
||||||
|
|
||||||
z = convert_to_sympy(nd["agg"])(inputs)
|
z = convert_to_sympy(nd["agg"])(inputs)
|
||||||
z = bias + z * res
|
z = bias + res * z
|
||||||
|
|
||||||
if is_output_node:
|
if is_output_node:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -137,7 +137,7 @@ class NodeGeneWithoutResponse(BaseNodeGene):
|
|||||||
|
|
||||||
idx = int(idx)
|
idx = int(idx)
|
||||||
|
|
||||||
bias = np.array(bias, dtype=np.float32)
|
bias = jnp.float32(bias)
|
||||||
agg = int(agg)
|
agg = int(agg)
|
||||||
act = int(act)
|
act = int(act)
|
||||||
|
|
||||||
|
|||||||
@@ -216,7 +216,6 @@ class DefaultGenome(BaseGenome):
|
|||||||
symbols[i] = sp.Symbol(f"h{i}")
|
symbols[i] = sp.Symbol(f"h{i}")
|
||||||
|
|
||||||
nodes_exprs = {}
|
nodes_exprs = {}
|
||||||
|
|
||||||
args_symbols = {}
|
args_symbols = {}
|
||||||
for i in order:
|
for i in order:
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -9,23 +9,19 @@ class Agg:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sum(z):
|
def sum(z):
|
||||||
z = jnp.where(jnp.isnan(z), 0, z)
|
return jnp.sum(z, axis=0, where=~jnp.isnan(z))
|
||||||
return jnp.sum(z, axis=0)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def product(z):
|
def product(z):
|
||||||
z = jnp.where(jnp.isnan(z), 1, z)
|
return jnp.prod(z, axis=0, where=~jnp.isnan(z))
|
||||||
return jnp.prod(z, axis=0)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def max(z):
|
def max(z):
|
||||||
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
|
return jnp.max(z, axis=0, where=~jnp.isnan(z))
|
||||||
return jnp.max(z, axis=0)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def min(z):
|
def min(z):
|
||||||
z = jnp.where(jnp.isnan(z), jnp.inf, z)
|
return jnp.min(z, axis=0, where=~jnp.isnan(z))
|
||||||
return jnp.min(z, axis=0)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def maxabs(z):
|
def maxabs(z):
|
||||||
|
|||||||
@@ -7,6 +7,10 @@ class SympySum(sp.Function):
|
|||||||
def eval(cls, z):
|
def eval(cls, z):
|
||||||
return sp.Add(*z)
|
return sp.Add(*z)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def numerical_eval(cls, z, backend=np):
|
||||||
|
return backend.sum(z)
|
||||||
|
|
||||||
|
|
||||||
class SympyProduct(sp.Function):
|
class SympyProduct(sp.Function):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user