Implement the NEAT (NeuroEvolution of Augmenting Topologies) algorithm in JAX.
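A complete NEAT implementation (structural mutation, crossover, speciation, fitness sharing) is too large for a single snippet. As a minimal sketch of one building block, the code below computes the compatibility distance used for speciation, δ = c1·E/N + c2·D/N + c3·W̄. Here E is the number of excess genes, D the number of disjoint genes, W̄ the mean absolute weight difference of matching genes, and N the gene count of the larger genome. It assumes a hypothetical padded encoding in which each genome is a boolean mask plus a weight array indexed by innovation number. Fixed shapes let the function be jit-compiled and vmapped across a population. The helper names and the coefficient defaults (c1 = 1.0, c2 = 1.0, c3 = 0.4) are illustrative assumptions, not a specific library's API.

```python
import jax
import jax.numpy as jnp


def last_innovation(present):
    """Highest innovation number present in a genome, or -1 if it has no genes."""
    idx = jnp.arange(present.shape[0])
    return jnp.max(jnp.where(present, idx, -1))


@jax.jit
def compatibility_distance(present_a, weight_a, present_b, weight_b,
                           c1=1.0, c2=1.0, c3=0.4):
    """NEAT-style compatibility distance between two padded genomes (sketch).

    present_*: bool[max_innov], True where the genome has that connection gene.
    weight_*:  float[max_innov], connection weights (ignored where absent).
    """
    idx = jnp.arange(present_a.shape[0])
    matching = present_a & present_b
    non_matching = present_a ^ present_b

    # Non-matching genes beyond the other genome's last innovation are excess;
    # the remaining non-matching genes are disjoint.
    cutoff = jnp.minimum(last_innovation(present_a), last_innovation(present_b))
    excess = non_matching & (idx > cutoff)
    disjoint = non_matching & ~excess

    # N = gene count of the larger genome, clamped to 1 to avoid dividing by 0.
    n = jnp.maximum(jnp.maximum(present_a.sum(), present_b.sum()), 1)
    # Mean absolute weight difference over matching genes only.
    w_bar = (jnp.sum(jnp.where(matching, jnp.abs(weight_a - weight_b), 0.0))
             / jnp.maximum(matching.sum(), 1))

    return c1 * excess.sum() / n + c2 * disjoint.sum() / n + c3 * w_bar


def pairwise_distances(present, weight):
    """All-pairs distances for a population stored as [pop, max_innov] arrays."""
    row = jax.vmap(compatibility_distance, in_axes=(None, None, 0, 0))
    return jax.vmap(row, in_axes=(0, 0, None, None))(present, weight, present, weight)
```

In a speciation step, each row of `pairwise_distances` could be compared against a threshold δ_t to assign genomes to species. Padding every genome to the same `max_innov` length trades memory for JAX-friendly static shapes; a real implementation would also need to grow that bound as new innovations appear.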