55 lines
1.5 KiB
Python
55 lines
1.5 KiB
Python
import jax, jax.numpy as jnp
|
|
from tensorneat.common import ACT
|
|
from algorithm.neat import *
|
|
import numpy as np
|
|
|
|
|
|
def main(pop_size=1000, num_generations=50, num_elites=5, seed=0):
    """Evolve a NEAT population and report generations whose outputs contain NaN.

    Builds a NEAT algorithm (one output, one input per entry of the fixed
    probe vector ``[1, 0, 1]``). Then, for each generation, it:

    - draws random winner/loser index pairs,
    - creates the next generation with the first ``num_elites`` population
      slots flagged as elites,
    - runs every genome's forward pass on the probe vector,
    - prints a message whenever any batched output is NaN.

    Args:
        pop_size: Population size. Also the number of winner/loser pairs
            drawn per generation and the length of the elite mask.
        num_generations: Number of generations to evolve.
        num_elites: Number of leading population slots flagged as elites.
        seed: Seed for the JAX PRNG key driving setup and evolution.

    Raises:
        ValueError: If ``pop_size`` is less than 1 (no indices to draw).
    """
    if pop_size < 1:
        raise ValueError(f"pop_size must be >= 1, got {pop_size}")

    # Fixed probe input; num_inputs is derived from it so the two cannot drift.
    inputs = jnp.array([1.0, 0.0, 1.0])

    algorithm = NEAT(
        species=DefaultSpecies(
            genome=DefaultGenome(
                num_inputs=inputs.shape[0],
                num_outputs=1,
                max_nodes=100,
                max_conns=100,
            ),
            pop_size=pop_size,
            species_size=10,
            compatibility_threshold=3.5,
        ),
        mutation=DefaultMutation(
            conn_add=0.4,
            conn_delete=0,
            node_add=0.9,
            node_delete=0,
        ),
    )

    # Split once so the setup key and the evolution key stream never overlap.
    setup_key, key = jax.random.split(jax.random.key(seed))
    state = algorithm.setup(setup_key)

    batch_transform = jax.vmap(algorithm.genome.transform)
    # Broadcast the single input vector across the batch of transformed genomes.
    batch_forward = jax.vmap(algorithm.forward, in_axes=(None, 0))

    # The elite mask does not change between generations, so build it once.
    elite_mask = jnp.zeros((pop_size,), dtype=jnp.bool_).at[:num_elites].set(True)

    for generation in range(num_generations):
        # Fresh subkeys every generation. Reusing one key would make the
        # pair draw and the offspring creation repeat the same randomness.
        key, pair_key, next_gen_key = jax.random.split(key, 3)
        winner, loser = jax.random.randint(pair_key, (2, pop_size), 0, pop_size)

        state = algorithm.create_next_generation(
            next_gen_key, state, winner, loser, elite_mask
        )
        pop_nodes, pop_conns = algorithm.species.ask(state.species)

        transforms = batch_transform(pop_nodes, pop_conns)
        outputs = batch_forward(inputs, transforms)

        # Explicit check instead of assert + bare except: survives `python -O`
        # and does not swallow unrelated exceptions such as KeyboardInterrupt.
        if jnp.any(jnp.isnan(outputs)):
            print(f"NaN detected in network outputs at generation {generation}")


if __name__ == "__main__":
    main()