Files
tensorneat-mend/tensorneat/test/test_nan_fitness.py
2024-07-10 11:24:11 +08:00

36 lines
847 B
Python

import jax, jax.numpy as jnp
from tensorneat.common import Act
from algorithm.neat import *
import numpy as np
def main():
    """Reproduce a NaN-fitness case by running a saved genome forward once.

    Loads the node/connection arrays stored in ``examples/brax`` as
    ``nan_node.npy`` / ``nan_conn.npy`` (presumably dumped from a run that
    produced NaN fitness -- confirm with whoever saved them). It then
    interprets them with a ``DefaultGenome`` configured as below and prints
    the transformed network. Finally it runs a forward pass on an all-zero
    input, prints the output, and reports whether that output contains NaN.
    """
    from pathlib import Path

    # Resolve the dumps relative to this file rather than the current working
    # directory, so the script works no matter where it is launched from.
    data_dir = Path(__file__).resolve().parent.parent / "examples" / "brax"
    node_path = data_dir / "nan_node.npy"
    conn_path = data_dir / "nan_conn.npy"

    nodes = np.load(node_path)
    conns = np.load(conn_path)
    nodes, conns = jax.device_put([nodes, conns])

    # Single source of truth for the input width, shared by the genome
    # definition and the dummy input below so the two cannot drift apart.
    num_inputs = 8
    genome = DefaultGenome(
        num_inputs=num_inputs,
        num_outputs=2,
        max_nodes=20,
        max_conns=20,
        node_gene=DefaultNodeGene(
            activation_options=(Act.tanh,),
            activation_default=Act.tanh,
        ),
    )

    transformed = genome.transform(nodes, conns)
    print(*transformed, sep="\n")

    dummy_input = jnp.zeros((num_inputs,))
    output = genome.forward(dummy_input, transformed)
    print(output)
    # The point of this script is diagnosing NaNs, so say so explicitly
    # instead of relying on someone eyeballing the printed array.
    print("output contains NaN:", bool(jnp.isnan(output).any()))
# Run the reproduction only when executed as a script, not when imported
# (e.g. by a test collector scanning this module).
if __name__ == "__main__":
    main()