This commit is contained in:
wls2002
2024-06-20 16:32:52 +08:00
parent 9f72813c35
commit 075460f896
17 changed files with 224 additions and 140 deletions

View File

@@ -210,7 +210,14 @@ class DefaultGenome(BaseGenome):
new_transformed,
)
def sympy_func(self, state, network, sympy_input_transform=None, sympy_output_transform=None, backend="jax"):
def sympy_func(
self,
state,
network,
sympy_input_transform=None,
sympy_output_transform=None,
backend="jax",
):
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
@@ -219,6 +226,10 @@ class DefaultGenome(BaseGenome):
warnings.warn(
"genome.input_transform is not None but sympy_input_transform is None!"
)
if sympy_input_transform is None:
sympy_input_transform = lambda x: x
if sympy_input_transform is not None:
if not isinstance(sympy_input_transform, list):
sympy_input_transform = [sympy_input_transform] * self.num_inputs
@@ -231,11 +242,14 @@ class DefaultGenome(BaseGenome):
input_idx = self.get_input_idx()
output_idx = self.get_output_idx()
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
hidden_idx = [i for i in network["nodes"] if i not in input_idx and i not in output_idx]
hidden_idx = [
i for i in network["nodes"] if i not in input_idx and i not in output_idx
]
symbols = {}
for i in network["nodes"]:
if i in input_idx:
symbols[i] = sp.Symbol(f"i{i - min(input_idx)}")
symbols[-i - 1] = sp.Symbol(f"i{i - min(input_idx)}") # origin_i
symbols[i] = sp.Symbol(f"norm{i - min(input_idx)}")
elif i in output_idx:
symbols[i] = sp.Symbol(f"o{i - min(output_idx)}")
else: # hidden
@@ -246,10 +260,9 @@ class DefaultGenome(BaseGenome):
for i in order:
if i in input_idx:
if sympy_input_transform is not None:
nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[i])
else:
nodes_exprs[symbols[i]] = symbols[i]
nodes_exprs[symbols[-i - 1]] = symbols[-i - 1] # origin equal to its symbol
nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[-i - 1]) # normed i
else:
in_conns = [c for c in network["conns"] if c[1] == i]
node_inputs = []
@@ -270,12 +283,13 @@ class DefaultGenome(BaseGenome):
is_output_node=(i in output_idx),
)
args_symbols.update(a_s)
if i in output_idx and sympy_output_transform is not None:
nodes_exprs[symbols[i]] = sympy_output_transform(
nodes_exprs[symbols[i]]
)
input_symbols = [v for k, v in symbols.items() if k in input_idx]
input_symbols = [symbols[-i - 1] for i in input_idx]
reduced_exprs = nodes_exprs.copy()
for i in order:
reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs)
@@ -299,7 +313,9 @@ class DefaultGenome(BaseGenome):
fixed_args_output_funcs.append(f)
forward_func = lambda inputs: jnp.array([f(inputs) for f in fixed_args_output_funcs])
forward_func = lambda inputs: jnp.array(
[f(inputs) for f in fixed_args_output_funcs]
)
return (
symbols,