add gene type RNN
This commit is contained in:
@@ -1,7 +1,5 @@
|
||||
import pytest
|
||||
import jax
|
||||
|
||||
from algorithm.neat.utils import *
|
||||
import jax.numpy as jnp
|
||||
from algorithm.neat.utils import unflatten_connections
|
||||
|
||||
|
||||
def test_unflatten():
|
||||
@@ -13,7 +11,6 @@ def test_unflatten():
|
||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
||||
])
|
||||
|
||||
|
||||
conns = jnp.array([
|
||||
[0, 1, True, 0.1, 0.11],
|
||||
[0, 2, False, 0.2, 0.22],
|
||||
@@ -33,4 +30,4 @@ def test_unflatten():
|
||||
mask = mask.at[:, [0, 0, 1, 1], [1, 2, 2, 3]].set(False)
|
||||
|
||||
# Ensure all other places are jnp.nan
|
||||
assert jnp.all(jnp.isnan(res[mask]))
|
||||
assert jnp.all(jnp.isnan(res[mask]))
|
||||
|
||||
Reference in New Issue
Block a user