complete fully stateful!
use black to format all files!
This commit is contained in:
@@ -2,6 +2,7 @@ from algorithm.neat import *
|
||||
from utils import Act, Agg, State
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
|
||||
|
||||
|
||||
def test_default():
|
||||
@@ -135,3 +136,29 @@ def test_recurrent():
|
||||
print(outputs)
|
||||
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
|
||||
# expected: [[0.5], [0.75], [0.5], [0.75]]
|
||||
|
||||
|
||||
def test_random_initialize():
|
||||
genome = DefaultGenome(
|
||||
num_inputs=2,
|
||||
num_outputs=1,
|
||||
max_nodes=5,
|
||||
max_conns=4,
|
||||
node_gene=NodeGeneWithoutResponse(
|
||||
activation_default=Act.identity,
|
||||
activation_options=(Act.identity,),
|
||||
aggregation_default=Agg.sum,
|
||||
aggregation_options=(Agg.sum,),
|
||||
),
|
||||
)
|
||||
state = genome.setup()
|
||||
key = jax.random.PRNGKey(0)
|
||||
nodes, conns = genome.initialize(state, key)
|
||||
transformed = genome.transform(state, nodes, conns)
|
||||
print(*transformed, sep="\n")
|
||||
|
||||
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
||||
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))(
|
||||
state, inputs, transformed
|
||||
)
|
||||
print(outputs)
|
||||
|
||||
Reference in New Issue
Block a user