change a lot a lot a lot!!!!!!!

This commit is contained in:
wls2002
2023-07-24 02:16:02 +08:00
parent 48f90c7eef
commit ac295c1921
49 changed files with 1138 additions and 1460 deletions

View File

View File

View File

@@ -1,56 +0,0 @@
import numpy as np
from algorithm.hyperneat.substrate.tools import cartesian_product
def test01():
keys1 = np.array([1, 2, 3])
keys2 = np.array([4, 5, 6, 7])
coors1 = np.array([
[1, 1, 1],
[2, 2, 2],
[3, 3, 3]
])
coors2 = np.array([
[4, 4, 4],
[5, 5, 5],
[6, 6, 6],
[7, 7, 7]
])
target_coors = np.array([
[1, 1, 1, 4, 4, 4],
[1, 1, 1, 5, 5, 5],
[1, 1, 1, 6, 6, 6],
[1, 1, 1, 7, 7, 7],
[2, 2, 2, 4, 4, 4],
[2, 2, 2, 5, 5, 5],
[2, 2, 2, 6, 6, 6],
[2, 2, 2, 7, 7, 7],
[3, 3, 3, 4, 4, 4],
[3, 3, 3, 5, 5, 5],
[3, 3, 3, 6, 6, 6],
[3, 3, 3, 7, 7, 7]
])
target_keys = np.array([
[1, 4],
[1, 5],
[1, 6],
[1, 7],
[2, 4],
[2, 5],
[2, 6],
[2, 7],
[3, 4],
[3, 5],
[3, 6],
[3, 7]
])
new_coors, correspond_keys = cartesian_product(keys1, keys2, coors1, coors2)
assert np.array_equal(new_coors, target_coors)
assert np.array_equal(correspond_keys, target_keys)

View File

@@ -1,32 +0,0 @@
import jax.numpy as jnp
from algorithm.neat.genome.graph import topological_sort, check_cycles
from algorithm.utils import I_INT
nodes = jnp.array([
[0],
[1],
[2],
[3],
[jnp.nan]
])
# {(0, 2), (1, 2), (1, 3), (2, 3)}
conns = jnp.array([
[0, 0, 1, 0, 0],
[0, 0, 1, 1, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]
])
def test_topological_sort():
assert jnp.all(topological_sort(nodes, conns) == jnp.array([0, 1, 2, 3, I_INT]))
def test_check_cycles():
assert check_cycles(nodes, conns, 3, 2)
assert ~check_cycles(nodes, conns, 2, 3)
assert ~check_cycles(nodes, conns, 0, 3)
assert ~check_cycles(nodes, conns, 1, 0)

View File

@@ -1,33 +0,0 @@
import jax.numpy as jnp
from algorithm.utils import unflatten_connections
def test_unflatten():
nodes = jnp.array([
[0, 0, 0, 0],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan]
])
conns = jnp.array([
[0, 1, True, 0.1, 0.11],
[0, 2, False, 0.2, 0.22],
[1, 2, True, 0.3, 0.33],
[1, 3, False, 0.4, 0.44],
])
res = unflatten_connections(nodes, conns)
assert jnp.all(res[:, 0, 1] == jnp.array([True, 0.1, 0.11]))
assert jnp.all(res[:, 0, 2] == jnp.array([False, 0.2, 0.22]))
assert jnp.all(res[:, 1, 2] == jnp.array([True, 0.3, 0.33]))
assert jnp.all(res[:, 1, 3] == jnp.array([False, 0.4, 0.44]))
# Create a mask that excludes the indices we've already checked
mask = jnp.ones(res.shape, dtype=bool)
mask = mask.at[:, [0, 0, 1, 1], [1, 2, 2, 3]].set(False)
# Ensure all other places are jnp.nan
assert jnp.all(jnp.isnan(res[mask]))