refactor folder locations
This commit is contained in:
27
test/test_genome.py
Normal file
27
test/test_genome.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import jax
|
||||
from algorithm.neat import *
|
||||
|
||||
genome = DefaultGenome(
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
max_nodes=5,
|
||||
max_conns=10,
|
||||
)
|
||||
|
||||
|
||||
def test_output_work():
|
||||
randkey = jax.random.PRNGKey(0)
|
||||
state = genome.setup()
|
||||
nodes, conns = genome.initialize(state, randkey)
|
||||
transformed = genome.transform(state, nodes, conns)
|
||||
inputs = jax.random.normal(randkey, (3,))
|
||||
output = genome.forward(state, transformed, inputs)
|
||||
print(output)
|
||||
|
||||
batch_inputs = jax.random.normal(randkey, (10, 3))
|
||||
batch_output = jax.vmap(genome.forward, in_axes=(None, None, 0))(
|
||||
state, transformed, batch_inputs
|
||||
)
|
||||
print(batch_output)
|
||||
|
||||
assert True
|
||||
Reference in New Issue
Block a user