All function with state will update the state and return it.
Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from algorithm.neat import *
|
||||
from utils import Act, Agg
|
||||
from utils import Act, Agg, State
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
@@ -36,11 +36,14 @@ def test_default():
|
||||
),
|
||||
)
|
||||
|
||||
transformed = genome.transform(nodes, conns)
|
||||
state = genome.setup(State(randkey=jax.random.key(0)))
|
||||
|
||||
state, *transformed = genome.transform(state, nodes, conns)
|
||||
print(*transformed, sep='\n')
|
||||
|
||||
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
||||
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed)
|
||||
state, outputs = jax.jit(jax.vmap(genome.forward,
|
||||
in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed)
|
||||
print(outputs)
|
||||
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
|
||||
# expected: [[0.5], [0.75], [0.75], [1]]
|
||||
@@ -50,11 +53,11 @@ def test_default():
|
||||
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
|
||||
print(conns)
|
||||
|
||||
transformed = genome.transform(nodes, conns)
|
||||
state, *transformed = genome.transform(state, nodes, conns)
|
||||
print(*transformed, sep='\n')
|
||||
|
||||
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
||||
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed)
|
||||
state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed)
|
||||
print(outputs)
|
||||
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
|
||||
# expected: [[0.5], [0.75], [0.5], [0.75]]
|
||||
@@ -93,11 +96,14 @@ def test_recurrent():
|
||||
activate_time=3,
|
||||
)
|
||||
|
||||
transformed = genome.transform(nodes, conns)
|
||||
state = genome.setup(State(randkey=jax.random.key(0)))
|
||||
|
||||
state, *transformed = genome.transform(state, nodes, conns)
|
||||
print(*transformed, sep='\n')
|
||||
|
||||
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
||||
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed)
|
||||
state, outputs = jax.jit(jax.vmap(genome.forward,
|
||||
in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed)
|
||||
print(outputs)
|
||||
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
|
||||
# expected: [[0.5], [0.75], [0.75], [1]]
|
||||
@@ -107,11 +113,11 @@ def test_recurrent():
|
||||
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
|
||||
print(conns)
|
||||
|
||||
transformed = genome.transform(nodes, conns)
|
||||
state, *transformed = genome.transform(state, nodes, conns)
|
||||
print(*transformed, sep='\n')
|
||||
|
||||
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
||||
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed)
|
||||
state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed)
|
||||
print(outputs)
|
||||
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
|
||||
# expected: [[0.5], [0.75], [0.5], [0.75]]
|
||||
Reference in New Issue
Block a user