This commit is contained in:
wls2002
2024-06-20 16:32:52 +08:00
parent 9f72813c35
commit 075460f896
17 changed files with 224 additions and 140 deletions

View File

@@ -1,14 +1,14 @@
import jax, jax.numpy as jnp
from algorithm.neat import *
from algorithm.neat.genome.hidden import AdvanceInitialize
from algorithm.neat.genome.dense import DenseInitialize
from utils.graph import topological_sort_python
from utils import *
if __name__ == '__main__':
genome = AdvanceInitialize(
num_inputs=17,
num_outputs=6,
hidden_cnt=8,
if __name__ == "__main__":
genome = DenseInitialize(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=500,
)
@@ -19,16 +19,19 @@ if __name__ == '__main__':
nodes, conns = genome.initialize(state, randkey)
network = genome.network_dict(state, nodes, conns)
print(set(network["nodes"]), set(network["conns"]))
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
print(order)
input_idx, output_idx = genome.get_input_idx(), genome.get_output_idx()
print(input_idx, output_idx)
print(genome.repr(state, nodes, conns))
print(network)
res = genome.sympy_func(state, network, sympy_input_transform=lambda x: 999999999*x, sympy_output_transform=SympyStandardSigmoid)
(symbols,
args_symbols,
input_symbols,
nodes_exprs,
output_exprs,
forward_func,) = res
res = genome.sympy_func(state, network, precision=3)
print(res)
print(symbols)
print(output_exprs[0].subs(args_symbols))
inputs = jnp.zeros(3)
print(forward_func(inputs))