28 lines
671 B
Python
28 lines
671 B
Python
import jax
|
|
from algorithm.neat import *
|
|
|
|
genome = DefaultGenome(
|
|
num_inputs=3,
|
|
num_outputs=1,
|
|
max_nodes=5,
|
|
max_conns=10,
|
|
)
|
|
|
|
|
|
def test_output_work():
|
|
randkey = jax.random.PRNGKey(0)
|
|
state = genome.setup()
|
|
nodes, conns = genome.initialize(state, randkey)
|
|
transformed = genome.transform(state, nodes, conns)
|
|
inputs = jax.random.normal(randkey, (3,))
|
|
output = genome.forward(state, transformed, inputs)
|
|
print(output)
|
|
|
|
batch_inputs = jax.random.normal(randkey, (10, 3))
|
|
batch_output = jax.vmap(genome.forward, in_axes=(None, None, 0))(
|
|
state, transformed, batch_inputs
|
|
)
|
|
print(batch_output)
|
|
|
|
assert True
|