All function with state will update the state and return it.

Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
wls2002
2024-05-25 20:45:57 +08:00
parent 5626fddf41
commit 79d53ea7af
12 changed files with 84 additions and 70 deletions

View File

@@ -1,5 +1,5 @@
from algorithm.neat import *
from utils import Act, Agg
from utils import Act, Agg, State
import jax, jax.numpy as jnp
@@ -36,11 +36,14 @@ def test_default():
),
)
transformed = genome.transform(nodes, conns)
state = genome.setup(State(randkey=jax.random.key(0)))
state, *transformed = genome.transform(state, nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed)
state, outputs = jax.jit(jax.vmap(genome.forward,
in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
# expected: [[0.5], [0.75], [0.75], [1]]
@@ -50,11 +53,11 @@ def test_default():
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
print(conns)
transformed = genome.transform(nodes, conns)
state, *transformed = genome.transform(state, nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed)
state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
# expected: [[0.5], [0.75], [0.5], [0.75]]
@@ -93,11 +96,14 @@ def test_recurrent():
activate_time=3,
)
transformed = genome.transform(nodes, conns)
state = genome.setup(State(randkey=jax.random.key(0)))
state, *transformed = genome.transform(state, nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed)
state, outputs = jax.jit(jax.vmap(genome.forward,
in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
# expected: [[0.5], [0.75], [0.75], [1]]
@@ -107,11 +113,11 @@ def test_recurrent():
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
print(conns)
transformed = genome.transform(nodes, conns)
state, *transformed = genome.transform(state, nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed)
state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
# expected: [[0.5], [0.75], [0.5], [0.75]]