update functions. Visualize, Interpretable and with evox
This commit is contained in:
@@ -1,16 +1,16 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from algorithm.neat import *
|
||||
from algorithm.neat.genome.dense import DenseInitialize
|
||||
from utils.graph import topological_sort_python
|
||||
from tensorneat.genome import DefaultGenome
|
||||
from tensorneat.common import *
|
||||
from tensorneat.common.functions import SympySigmoid
|
||||
|
||||
if __name__ == "__main__":
|
||||
genome = DenseInitialize(
|
||||
genome = DefaultGenome(
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
max_nodes=50,
|
||||
max_conns=500,
|
||||
output_transform=ACT.sigmoid,
|
||||
)
|
||||
|
||||
state = genome.setup()
|
||||
@@ -22,7 +22,7 @@ if __name__ == "__main__":
|
||||
|
||||
input_idx, output_idx = genome.get_input_idx(), genome.get_output_idx()
|
||||
|
||||
res = genome.sympy_func(state, network, sympy_input_transform=lambda x: 999999999*x, sympy_output_transform=SympyStandardSigmoid)
|
||||
res = genome.sympy_func(state, network, sympy_input_transform=lambda x: 999*x, sympy_output_transform=SympySigmoid)
|
||||
(symbols,
|
||||
args_symbols,
|
||||
input_symbols,
|
||||
@@ -35,3 +35,11 @@ if __name__ == "__main__":
|
||||
|
||||
inputs = jnp.zeros(3)
|
||||
print(forward_func(inputs))
|
||||
|
||||
print(genome.forward(state, genome.transform(state, nodes, conns), inputs))
|
||||
|
||||
print(AGG.sympy_module("jax"))
|
||||
print(AGG.sympy_module("numpy"))
|
||||
|
||||
print(ACT.sympy_module("jax"))
|
||||
print(ACT.sympy_module("numpy"))
|
||||
Reference in New Issue
Block a user