# (Header residue from a file viewer: 45 lines, 1.2 KiB, Python.)

import jax, jax.numpy as jnp
from tensorneat.genome import DefaultGenome
from tensorneat.common import *
from tensorneat.common.functions import SympySigmoid
if __name__ == "__main__":
    # Smoke test: build a randomly initialised DefaultGenome, export it via
    # genome.sympy_func, then print the sympy-generated forward function's
    # output next to the genome's own forward pass on the same input.
    num_inputs = 3  # shared by the genome config and the test input below
    input_scale = 999  # factor applied by the sympy input transform

    genome = DefaultGenome(
        num_inputs=num_inputs,
        num_outputs=1,
        max_nodes=50,
        max_conns=500,
        output_transform=ACT.sigmoid,
    )
    state = genome.setup()
    randkey = jax.random.PRNGKey(42)  # fixed seed -> reproducible genome
    nodes, conns = genome.initialize(state, randkey)
    network = genome.network_dict(state, nodes, conns)

    # NOTE(review): the sympy export scales every input by `input_scale`, but
    # the genome above is constructed without any visible matching input
    # transform. The two forward passes printed below therefore presumably
    # agree only because the test input is all zeros — confirm whether this
    # asymmetry is intentional (DefaultGenome defaults are not visible here).
    res = genome.sympy_func(
        state,
        network,
        sympy_input_transform=lambda x: input_scale * x,
        # presumably the sympy counterpart of output_transform=ACT.sigmoid
        sympy_output_transform=SympySigmoid,
    )
    (
        symbols,
        args_symbols,
        input_symbols,
        nodes_exprs,
        output_exprs,
        forward_func,
    ) = res

    print(symbols)
    # Substitute `args_symbols` into the first output expression
    # (presumably binding the parameter symbols to their concrete values).
    print(output_exprs[0].subs(args_symbols))

    inputs = jnp.zeros(num_inputs)
    print(forward_func(inputs))
    print(genome.forward(state, genome.transform(state, nodes, conns), inputs))

    # Print what the aggregation / activation registries return for each
    # sympy backend name.
    print(AGG.sympy_module("jax"))
    print(AGG.sympy_module("numpy"))
    print(ACT.sympy_module("jax"))
    print(ACT.sympy_module("numpy"))