use black format all files;
remove "return state" for functions which will be executed in vmap; recover randkey as args in mutation methods
This commit is contained in:
@@ -7,21 +7,25 @@ import jax, jax.numpy as jnp
|
||||
def test_default():
|
||||
|
||||
# index, bias, response, activation, aggregation
|
||||
nodes = jnp.array([
|
||||
[0, 0, 1, 0, 0], # in[0]
|
||||
[1, 0, 1, 0, 0], # in[1]
|
||||
[2, 0.5, 1, 0, 0], # out[0],
|
||||
[3, 1, 1, 0, 0], # hidden[0],
|
||||
[4, -1, 1, 0, 0], # hidden[1],
|
||||
])
|
||||
nodes = jnp.array(
|
||||
[
|
||||
[0, 0, 1, 0, 0], # in[0]
|
||||
[1, 0, 1, 0, 0], # in[1]
|
||||
[2, 0.5, 1, 0, 0], # out[0],
|
||||
[3, 1, 1, 0, 0], # hidden[0],
|
||||
[4, -1, 1, 0, 0], # hidden[1],
|
||||
]
|
||||
)
|
||||
|
||||
# in_node, out_node, enable, weight
|
||||
conns = jnp.array([
|
||||
[0, 3, 1, 0.5], # in[0] -> hidden[0]
|
||||
[1, 4, 1, 0.5], # in[1] -> hidden[1]
|
||||
[3, 2, 1, 0.5], # hidden[0] -> out[0]
|
||||
[4, 2, 1, 0.5], # hidden[1] -> out[0]
|
||||
])
|
||||
conns = jnp.array(
|
||||
[
|
||||
[0, 3, 1, 0.5], # in[0] -> hidden[0]
|
||||
[1, 4, 1, 0.5], # in[1] -> hidden[1]
|
||||
[3, 2, 1, 0.5], # hidden[0] -> out[0]
|
||||
[4, 2, 1, 0.5], # hidden[1] -> out[0]
|
||||
]
|
||||
)
|
||||
|
||||
genome = DefaultGenome(
|
||||
num_inputs=2,
|
||||
@@ -30,34 +34,37 @@ def test_default():
|
||||
max_conns=4,
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_default=Act.identity,
|
||||
activation_options=(Act.identity, ),
|
||||
activation_options=(Act.identity,),
|
||||
aggregation_default=Agg.sum,
|
||||
aggregation_options=(Agg.sum, ),
|
||||
aggregation_options=(Agg.sum,),
|
||||
),
|
||||
)
|
||||
|
||||
state = genome.setup(State(randkey=jax.random.key(0)))
|
||||
|
||||
state, *transformed = genome.transform(state, nodes, conns)
|
||||
print(*transformed, sep='\n')
|
||||
transformed = genome.transform(state, nodes, conns)
|
||||
print(*transformed, sep="\n")
|
||||
|
||||
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
||||
state, outputs = jax.jit(jax.vmap(genome.forward,
|
||||
in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed)
|
||||
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
||||
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))(
|
||||
state, inputs, transformed
|
||||
)
|
||||
print(outputs)
|
||||
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
|
||||
# expected: [[0.5], [0.75], [0.75], [1]]
|
||||
|
||||
print('\n-------------------------------------------------------\n')
|
||||
print("\n-------------------------------------------------------\n")
|
||||
|
||||
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
|
||||
print(conns)
|
||||
|
||||
state, *transformed = genome.transform(state, nodes, conns)
|
||||
print(*transformed, sep='\n')
|
||||
transformed = genome.transform(state, nodes, conns)
|
||||
print(*transformed, sep="\n")
|
||||
|
||||
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
||||
state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed)
|
||||
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
||||
outputs = jax.vmap(genome.forward, in_axes=(None, 0, None))(
|
||||
state, inputs, transformed
|
||||
)
|
||||
print(outputs)
|
||||
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
|
||||
# expected: [[0.5], [0.75], [0.5], [0.75]]
|
||||
@@ -66,21 +73,25 @@ def test_default():
|
||||
def test_recurrent():
|
||||
|
||||
# index, bias, response, activation, aggregation
|
||||
nodes = jnp.array([
|
||||
[0, 0, 1, 0, 0], # in[0]
|
||||
[1, 0, 1, 0, 0], # in[1]
|
||||
[2, 0.5, 1, 0, 0], # out[0],
|
||||
[3, 1, 1, 0, 0], # hidden[0],
|
||||
[4, -1, 1, 0, 0], # hidden[1],
|
||||
])
|
||||
nodes = jnp.array(
|
||||
[
|
||||
[0, 0, 1, 0, 0], # in[0]
|
||||
[1, 0, 1, 0, 0], # in[1]
|
||||
[2, 0.5, 1, 0, 0], # out[0],
|
||||
[3, 1, 1, 0, 0], # hidden[0],
|
||||
[4, -1, 1, 0, 0], # hidden[1],
|
||||
]
|
||||
)
|
||||
|
||||
# in_node, out_node, enable, weight
|
||||
conns = jnp.array([
|
||||
[0, 3, 1, 0.5], # in[0] -> hidden[0]
|
||||
[1, 4, 1, 0.5], # in[1] -> hidden[1]
|
||||
[3, 2, 1, 0.5], # hidden[0] -> out[0]
|
||||
[4, 2, 1, 0.5], # hidden[1] -> out[0]
|
||||
])
|
||||
conns = jnp.array(
|
||||
[
|
||||
[0, 3, 1, 0.5], # in[0] -> hidden[0]
|
||||
[1, 4, 1, 0.5], # in[1] -> hidden[1]
|
||||
[3, 2, 1, 0.5], # hidden[0] -> out[0]
|
||||
[4, 2, 1, 0.5], # hidden[1] -> out[0]
|
||||
]
|
||||
)
|
||||
|
||||
genome = RecurrentGenome(
|
||||
num_inputs=2,
|
||||
@@ -89,35 +100,38 @@ def test_recurrent():
|
||||
max_conns=4,
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_default=Act.identity,
|
||||
activation_options=(Act.identity, ),
|
||||
activation_options=(Act.identity,),
|
||||
aggregation_default=Agg.sum,
|
||||
aggregation_options=(Agg.sum, ),
|
||||
aggregation_options=(Agg.sum,),
|
||||
),
|
||||
activate_time=3,
|
||||
)
|
||||
|
||||
state = genome.setup(State(randkey=jax.random.key(0)))
|
||||
|
||||
state, *transformed = genome.transform(state, nodes, conns)
|
||||
print(*transformed, sep='\n')
|
||||
transformed = genome.transform(state, nodes, conns)
|
||||
print(*transformed, sep="\n")
|
||||
|
||||
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
||||
state, outputs = jax.jit(jax.vmap(genome.forward,
|
||||
in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed)
|
||||
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
||||
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))(
|
||||
state, inputs, transformed
|
||||
)
|
||||
print(outputs)
|
||||
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
|
||||
# expected: [[0.5], [0.75], [0.75], [1]]
|
||||
|
||||
print('\n-------------------------------------------------------\n')
|
||||
print("\n-------------------------------------------------------\n")
|
||||
|
||||
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
|
||||
print(conns)
|
||||
|
||||
state, *transformed = genome.transform(state, nodes, conns)
|
||||
print(*transformed, sep='\n')
|
||||
transformed = genome.transform(state, nodes, conns)
|
||||
print(*transformed, sep="\n")
|
||||
|
||||
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
||||
state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed)
|
||||
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
||||
outputs = jax.vmap(genome.forward, in_axes=(None, 0, None))(
|
||||
state, inputs, transformed
|
||||
)
|
||||
print(outputs)
|
||||
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
|
||||
# expected: [[0.5], [0.75], [0.5], [0.75]]
|
||||
# expected: [[0.5], [0.75], [0.5], [0.75]]
|
||||
|
||||
Reference in New Issue
Block a user