remove attr enable for conn

This commit is contained in:
wls2002
2024-05-31 22:06:25 +08:00
parent d6e9ff5d9a
commit 4ad9f0a85a
9 changed files with 43 additions and 108 deletions

View File

@@ -21,10 +21,10 @@ def test_default():
# in_node, out_node, enable, weight
conns = jnp.array(
[
[0, 3, 1, 0.5], # in[0] -> hidden[0]
[1, 4, 1, 0.5], # in[1] -> hidden[1]
[3, 2, 1, 0.5], # hidden[0] -> out[0]
[4, 2, 1, 0.5], # hidden[1] -> out[0]
[0, 3, 0.5], # in[0] -> hidden[0]
[1, 4, 0.5], # in[1] -> hidden[1]
[3, 2, 0.5], # hidden[0] -> out[0]
[4, 2, 0.5], # hidden[1] -> out[0]
]
)
@@ -54,22 +54,6 @@ def test_default():
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
# expected: [[0.5], [0.75], [0.75], [1]]
print("\n-------------------------------------------------------\n")
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
print(conns)
transformed = genome.transform(state, nodes, conns)
print(*transformed, sep="\n")
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
outputs = jax.vmap(genome.forward, in_axes=(None, 0, None))(
state, inputs, transformed
)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
# expected: [[0.5], [0.75], [0.5], [0.75]]
def test_recurrent():
@@ -87,10 +71,10 @@ def test_recurrent():
# in_node, out_node, enable, weight
conns = jnp.array(
[
[0, 3, 1, 0.5], # in[0] -> hidden[0]
[1, 4, 1, 0.5], # in[1] -> hidden[1]
[3, 2, 1, 0.5], # hidden[0] -> out[0]
[4, 2, 1, 0.5], # hidden[1] -> out[0]
[0, 3, 0.5], # in[0] -> hidden[0]
[1, 4, 0.5], # in[1] -> hidden[1]
[3, 2, 0.5], # hidden[0] -> out[0]
[4, 2, 0.5], # hidden[1] -> out[0]
]
)
@@ -121,22 +105,6 @@ def test_recurrent():
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
# expected: [[0.5], [0.75], [0.75], [1]]
print("\n-------------------------------------------------------\n")
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
print(conns)
transformed = genome.transform(state, nodes, conns)
print(*transformed, sep="\n")
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
outputs = jax.vmap(genome.forward, in_axes=(None, 0, None))(
state, inputs, transformed
)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
# expected: [[0.5], [0.75], [0.5], [0.75]]
def test_random_initialize():
genome = DefaultGenome(