add input_transform and update_input_transform;

change the args for genome.forward.
Origin: (state, inputs, transformed)
New: (state, transformed, inputs)
This commit is contained in:
wls2002
2024-06-03 10:53:15 +08:00
parent a07a3b1cb2
commit edfb0596e7
16 changed files with 185 additions and 221 deletions

View File

@@ -1,132 +1,27 @@
import jax
from algorithm.neat import *
from utils import Act, Agg, State
import jax, jax.numpy as jnp
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
genome = DefaultGenome(
num_inputs=3,
num_outputs=1,
max_nodes=5,
max_conns=10,
)
def test_default():
# index, bias, response, activation, aggregation
nodes = jnp.array(
[
[0, 0, 1, 0, 0], # in[0]
[1, 0, 1, 0, 0], # in[1]
[2, 0.5, 1, 0, 0], # out[0],
[3, 1, 1, 0, 0], # hidden[0],
[4, -1, 1, 0, 0], # hidden[1],
]
)
# in_node, out_node, enable, weight
conns = jnp.array(
[
[0, 3, 0.5], # in[0] -> hidden[0]
[1, 4, 0.5], # in[1] -> hidden[1]
[3, 2, 0.5], # hidden[0] -> out[0]
[4, 2, 0.5], # hidden[1] -> out[0]
]
)
genome = DefaultGenome(
num_inputs=2,
num_outputs=1,
max_nodes=5,
max_conns=4,
node_gene=DefaultNodeGene(
activation_default=Act.identity,
activation_options=(Act.identity,),
aggregation_default=Agg.sum,
aggregation_options=(Agg.sum,),
),
)
state = genome.setup(State(randkey=jax.random.key(0)))
transformed = genome.transform(state, nodes, conns)
print(*transformed, sep="\n")
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))(
state, inputs, transformed
)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
# expected: [[0.5], [0.75], [0.75], [1]]
def test_recurrent():
# index, bias, response, activation, aggregation
nodes = jnp.array(
[
[0, 0, 1, 0, 0], # in[0]
[1, 0, 1, 0, 0], # in[1]
[2, 0.5, 1, 0, 0], # out[0],
[3, 1, 1, 0, 0], # hidden[0],
[4, -1, 1, 0, 0], # hidden[1],
]
)
# in_node, out_node, enable, weight
conns = jnp.array(
[
[0, 3, 0.5], # in[0] -> hidden[0]
[1, 4, 0.5], # in[1] -> hidden[1]
[3, 2, 0.5], # hidden[0] -> out[0]
[4, 2, 0.5], # hidden[1] -> out[0]
]
)
genome = RecurrentGenome(
num_inputs=2,
num_outputs=1,
max_nodes=5,
max_conns=4,
node_gene=DefaultNodeGene(
activation_default=Act.identity,
activation_options=(Act.identity,),
aggregation_default=Agg.sum,
aggregation_options=(Agg.sum,),
),
activate_time=3,
)
state = genome.setup(State(randkey=jax.random.key(0)))
transformed = genome.transform(state, nodes, conns)
print(*transformed, sep="\n")
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))(
state, inputs, transformed
)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
# expected: [[0.5], [0.75], [0.75], [1]]
def test_random_initialize():
genome = DefaultGenome(
num_inputs=2,
num_outputs=1,
max_nodes=5,
max_conns=4,
node_gene=NodeGeneWithoutResponse(
activation_default=Act.identity,
activation_options=(Act.identity,),
aggregation_default=Agg.sum,
aggregation_options=(Agg.sum,),
),
)
def test_output_work():
randkey = jax.random.PRNGKey(0)
state = genome.setup()
key = jax.random.PRNGKey(0)
nodes, conns = genome.initialize(state, key)
nodes, conns = genome.initialize(state, randkey)
transformed = genome.transform(state, nodes, conns)
print(*transformed, sep="\n")
inputs = jax.random.normal(randkey, (3,))
output = genome.forward(state, transformed, inputs)
print(output)
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))(
state, inputs, transformed
batch_inputs = jax.random.normal(randkey, (10, 3))
batch_output = jax.vmap(genome.forward, in_axes=(None, None, 0))(
state, transformed, batch_inputs
)
print(outputs)
print(batch_output)
assert True