complete fully stateful!

use black to format all files!
This commit is contained in:
wls2002
2024-05-26 18:08:43 +08:00
parent cf69b916af
commit 18c3d44c79
41 changed files with 620 additions and 495 deletions

View File

@@ -36,7 +36,9 @@ def main():
elite_mask = jnp.zeros((1000,), dtype=jnp.bool_)
elite_mask = elite_mask.at[:5].set(1)
state = algorithm.create_next_generation(jax.random.key(0), state, winner, losser, elite_mask)
state = algorithm.create_next_generation(
jax.random.key(0), state, winner, losser, elite_mask
)
pop_nodes, pop_conns = algorithm.species.ask(state.species)
transforms = batch_transform(pop_nodes, pop_conns)
@@ -48,5 +50,5 @@ def main():
print(_)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@@ -19,7 +19,7 @@ def main():
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
)
transformed = genome.transform(nodes, conns)
@@ -35,7 +35,7 @@ def main():
print(output)
if __name__ == '__main__':
if __name__ == "__main__":
a = jnp.array([1, 3, 5, 6, 8])
b = jnp.array([1, 2, 3])
print(jnp.isin(a, b))

View File

@@ -2,6 +2,7 @@ from algorithm.neat import *
from utils import Act, Agg, State
import jax, jax.numpy as jnp
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
def test_default():
@@ -135,3 +136,29 @@ def test_recurrent():
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
# expected: [[0.5], [0.75], [0.5], [0.75]]
def test_random_initialize():
genome = DefaultGenome(
num_inputs=2,
num_outputs=1,
max_nodes=5,
max_conns=4,
node_gene=NodeGeneWithoutResponse(
activation_default=Act.identity,
activation_options=(Act.identity,),
aggregation_default=Agg.sum,
aggregation_options=(Agg.sum,),
),
)
state = genome.setup()
key = jax.random.PRNGKey(0)
nodes, conns = genome.initialize(state, key)
transformed = genome.transform(state, nodes, conns)
print(*transformed, sep="\n")
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))(
state, inputs, transformed
)
print(outputs)

View File

@@ -19,11 +19,11 @@ def main():
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
)
transformed = genome.transform(nodes, conns)
print(*transformed, sep='\n')
print(*transformed, sep="\n")
key = jax.random.key(0)
dummy_input = jnp.zeros((8,))
@@ -31,5 +31,5 @@ def main():
print(output)
if __name__ == '__main__':
if __name__ == "__main__":
main()