complete normal neat algorithm
This commit is contained in:
0
test/__init__.py
Normal file
0
test/__init__.py
Normal file
0
test/unit/__init__.py
Normal file
0
test/unit/__init__.py
Normal file
36
test/unit/test_utils.py
Normal file
36
test/unit/test_utils.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import pytest
|
||||
import jax
|
||||
|
||||
from algorithm.neat.utils import *
|
||||
|
||||
|
||||
def test_unflatten():
|
||||
nodes = jnp.array([
|
||||
[0, 0, 0, 0],
|
||||
[1, 1, 1, 1],
|
||||
[2, 2, 2, 2],
|
||||
[3, 3, 3, 3],
|
||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
||||
])
|
||||
|
||||
|
||||
conns = jnp.array([
|
||||
[0, 1, True, 0.1, 0.11],
|
||||
[0, 2, False, 0.2, 0.22],
|
||||
[1, 2, True, 0.3, 0.33],
|
||||
[1, 3, False, 0.4, 0.44],
|
||||
])
|
||||
|
||||
res = unflatten_connections(nodes, conns)
|
||||
|
||||
assert jnp.all(res[:, 0, 1] == jnp.array([True, 0.1, 0.11]))
|
||||
assert jnp.all(res[:, 0, 2] == jnp.array([False, 0.2, 0.22]))
|
||||
assert jnp.all(res[:, 1, 2] == jnp.array([True, 0.3, 0.33]))
|
||||
assert jnp.all(res[:, 1, 3] == jnp.array([False, 0.4, 0.44]))
|
||||
|
||||
# Create a mask that excludes the indices we've already checked
|
||||
mask = jnp.ones(res.shape, dtype=bool)
|
||||
mask = mask.at[:, [0, 0, 1, 1], [1, 2, 2, 3]].set(False)
|
||||
|
||||
# Ensure all other places are jnp.nan
|
||||
assert jnp.all(jnp.isnan(res[mask]))
|
||||
Reference in New Issue
Block a user