Files
tensorneat-mend/test/nan_fitness.py
2024-07-12 02:25:57 +08:00

43 lines
1005 B
Python

import jax, jax.numpy as jnp
from tensorneat.common import ACT
from algorithm.neat import *
import numpy as np
def main(
    node_path="../examples/brax/nan_node.npy",
    conn_path="../examples/brax/nan_conn.npy",
    *,
    run_forward=False,
):
    """Load a saved genome, transform it, and print the resulting ``seq``.

    The genome's node and connection arrays are read from ``.npy`` files and
    moved to the default JAX device. A ``DefaultGenome`` is then built with
    8 inputs, 2 outputs, at most 20 nodes / 20 connections, and tanh as the
    only activation option. The genome is transformed and the first element
    of the transform result (``seq``) is printed.

    Args:
        node_path: Path to the saved node array. The default is relative to
            the current working directory (presumably the directory holding
            this script — confirm before running from elsewhere).
        conn_path: Path to the saved connection array (same caveat).
        run_forward: Keyword-only. When true, also run ``genome.forward`` on an
            all-zero input of length 8 and print the output. Defaults to
            ``False``, i.e. stop right after printing ``seq``.
    """
    nodes = np.load(node_path)
    conns = np.load(conn_path)
    nodes, conns = jax.device_put([nodes, conns])

    genome = DefaultGenome(
        num_inputs=8,
        num_outputs=2,
        max_nodes=20,
        max_conns=20,
        node_gene=DefaultNodeGene(
            activation_options=(ACT.tanh,),
            activation_default=ACT.tanh,
        ),
    )

    transformed = genome.transform(nodes, conns)
    seq, _, _ = transformed
    print(seq)

    # Return instead of calling the interactive-only ``exit()`` helper: it
    # raises SystemExit, which would also terminate any importing caller.
    if not run_forward:
        return

    # Zero input sized to match num_inputs=8 above.
    dummy_input = jnp.zeros((8,))
    output = genome.forward(dummy_input, transformed)
    print(output)
if __name__ == "__main__":
    # Scratch check of jnp.isin: for each entry of `elements`, report whether
    # it appears anywhere in `candidates`.
    elements = jnp.array([1, 3, 5, 6, 8])
    candidates = jnp.array([1, 2, 3])
    membership = jnp.isin(elements, candidates)
    print(membership)
    # main()