fix bugs
This commit is contained in:
@@ -2,7 +2,7 @@ import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
from ..gene import BaseNodeGene, BaseConnGene
|
||||
from ..ga import BaseMutation, BaseCrossover
|
||||
from utils import State, StatefulBaseClass, topological_sort_python
|
||||
from utils import State, StatefulBaseClass, topological_sort_python, hash_array
|
||||
|
||||
|
||||
class BaseGenome(StatefulBaseClass):
|
||||
@@ -255,10 +255,14 @@ class BaseGenome(StatefulBaseClass):
|
||||
|
||||
nx.draw(
|
||||
G,
|
||||
with_labels=True,
|
||||
pos=rotated_pos,
|
||||
node_size=node_sizes,
|
||||
node_color=node_colors,
|
||||
**kwargs,
|
||||
)
|
||||
plt.savefig(save_path, dpi=save_dpi)
|
||||
|
||||
def hash(self, nodes, conns):
|
||||
nodes_hashs = jax.vmap(hash_array)(nodes)
|
||||
conns_hashs = jax.vmap(hash_array)(conns)
|
||||
return hash_array(jnp.concatenate([nodes_hashs, conns_hashs]))
|
||||
|
||||
@@ -210,7 +210,14 @@ class DefaultGenome(BaseGenome):
|
||||
new_transformed,
|
||||
)
|
||||
|
||||
def sympy_func(self, state, network, sympy_input_transform=None, sympy_output_transform=None, backend="jax"):
|
||||
def sympy_func(
|
||||
self,
|
||||
state,
|
||||
network,
|
||||
sympy_input_transform=None,
|
||||
sympy_output_transform=None,
|
||||
backend="jax",
|
||||
):
|
||||
|
||||
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
|
||||
module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
|
||||
@@ -219,6 +226,10 @@ class DefaultGenome(BaseGenome):
|
||||
warnings.warn(
|
||||
"genome.input_transform is not None but sympy_input_transform is None!"
|
||||
)
|
||||
|
||||
if sympy_input_transform is None:
|
||||
sympy_input_transform = lambda x: x
|
||||
|
||||
if sympy_input_transform is not None:
|
||||
if not isinstance(sympy_input_transform, list):
|
||||
sympy_input_transform = [sympy_input_transform] * self.num_inputs
|
||||
@@ -231,11 +242,14 @@ class DefaultGenome(BaseGenome):
|
||||
input_idx = self.get_input_idx()
|
||||
output_idx = self.get_output_idx()
|
||||
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
|
||||
hidden_idx = [i for i in network["nodes"] if i not in input_idx and i not in output_idx]
|
||||
hidden_idx = [
|
||||
i for i in network["nodes"] if i not in input_idx and i not in output_idx
|
||||
]
|
||||
symbols = {}
|
||||
for i in network["nodes"]:
|
||||
if i in input_idx:
|
||||
symbols[i] = sp.Symbol(f"i{i - min(input_idx)}")
|
||||
symbols[-i - 1] = sp.Symbol(f"i{i - min(input_idx)}") # origin_i
|
||||
symbols[i] = sp.Symbol(f"norm{i - min(input_idx)}")
|
||||
elif i in output_idx:
|
||||
symbols[i] = sp.Symbol(f"o{i - min(output_idx)}")
|
||||
else: # hidden
|
||||
@@ -246,10 +260,9 @@ class DefaultGenome(BaseGenome):
|
||||
for i in order:
|
||||
|
||||
if i in input_idx:
|
||||
if sympy_input_transform is not None:
|
||||
nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[i])
|
||||
else:
|
||||
nodes_exprs[symbols[i]] = symbols[i]
|
||||
nodes_exprs[symbols[-i - 1]] = symbols[-i - 1] # origin equal to its symbol
|
||||
nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[-i - 1]) # normed i
|
||||
|
||||
else:
|
||||
in_conns = [c for c in network["conns"] if c[1] == i]
|
||||
node_inputs = []
|
||||
@@ -270,12 +283,13 @@ class DefaultGenome(BaseGenome):
|
||||
is_output_node=(i in output_idx),
|
||||
)
|
||||
args_symbols.update(a_s)
|
||||
|
||||
if i in output_idx and sympy_output_transform is not None:
|
||||
nodes_exprs[symbols[i]] = sympy_output_transform(
|
||||
nodes_exprs[symbols[i]]
|
||||
)
|
||||
|
||||
input_symbols = [v for k, v in symbols.items() if k in input_idx]
|
||||
input_symbols = [symbols[-i - 1] for i in input_idx]
|
||||
reduced_exprs = nodes_exprs.copy()
|
||||
for i in order:
|
||||
reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs)
|
||||
@@ -299,7 +313,9 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
fixed_args_output_funcs.append(f)
|
||||
|
||||
forward_func = lambda inputs: jnp.array([f(inputs) for f in fixed_args_output_funcs])
|
||||
forward_func = lambda inputs: jnp.array(
|
||||
[f(inputs) for f in fixed_args_output_funcs]
|
||||
)
|
||||
|
||||
return (
|
||||
symbols,
|
||||
|
||||
Reference in New Issue
Block a user