add gene type RNN

This commit is contained in:
wls2002
2023-07-19 15:43:49 +08:00
parent 0a2a9fd1be
commit a684e6584d
18 changed files with 248 additions and 129 deletions

View File

@@ -1,7 +1,5 @@
import pytest
import jax
from algorithm.neat.utils import *
import jax.numpy as jnp
from algorithm.neat.utils import unflatten_connections
def test_unflatten():
@@ -13,7 +11,6 @@ def test_unflatten():
[jnp.nan, jnp.nan, jnp.nan, jnp.nan]
])
conns = jnp.array([
[0, 1, True, 0.1, 0.11],
[0, 2, False, 0.2, 0.22],
@@ -33,4 +30,4 @@ def test_unflatten():
mask = mask.at[:, [0, 0, 1, 1], [1, 2, 2, 3]].set(False)
# Ensure all other places are jnp.nan
assert jnp.all(jnp.isnan(res[mask]))
assert jnp.all(jnp.isnan(res[mask]))