refactor folder locations
This commit is contained in:
35
test/test_nan_fitness.py
Normal file
35
test/test_nan_fitness.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from tensorneat.common import Act
|
||||
from algorithm.neat import *
|
||||
import numpy as np
|
||||
|
||||
|
||||
def main():
|
||||
node_path = "../examples/brax/nan_node.npy"
|
||||
conn_path = "../examples/brax/nan_conn.npy"
|
||||
nodes = np.load(node_path)
|
||||
conns = np.load(conn_path)
|
||||
nodes, conns = jax.device_put([nodes, conns])
|
||||
|
||||
genome = DefaultGenome(
|
||||
num_inputs=8,
|
||||
num_outputs=2,
|
||||
max_nodes=20,
|
||||
max_conns=20,
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_options=(Act.tanh,),
|
||||
activation_default=Act.tanh,
|
||||
),
|
||||
)
|
||||
|
||||
transformed = genome.transform(nodes, conns)
|
||||
print(*transformed, sep="\n")
|
||||
|
||||
key = jax.random.key(0)
|
||||
dummy_input = jnp.zeros((8,))
|
||||
output = genome.forward(dummy_input, transformed)
|
||||
print(output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user