bug down! Here it can solve xor successfully!

This commit is contained in:
wls2002
2023-05-07 16:03:52 +08:00
parent d1f54022bd
commit a3b9bca866
12 changed files with 120 additions and 254 deletions

View File

@@ -5,8 +5,7 @@ import jax
from jax import jit, vmap, Array
from jax import numpy as jnp
# from .utils import flatten_connections, unflatten_connections
from algorithms.neat.genome.utils import flatten_connections, unflatten_connections
from .utils import flatten_connections, unflatten_connections
@vmap
@@ -93,59 +92,4 @@ def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
only gene with the same key will be crossover, thus don't need to consider change key
"""
r = jax.random.uniform(rand_key, shape=g1.shape)
return jnp.where(r > 0.5, g1, g2)
if __name__ == '__main__':
import numpy as np
randkey = jax.random.PRNGKey(40)
nodes1 = np.array([
[4, 1, 1, 1, 1],
[6, 2, 2, 2, 2],
[1, 3, 3, 3, 3],
[5, 4, 4, 4, 4],
[np.nan, np.nan, np.nan, np.nan, np.nan]
])
nodes2 = np.array([
[4, 1.5, 1.5, 1.5, 1.5],
[7, 3.5, 3.5, 3.5, 3.5],
[5, 4.5, 4.5, 4.5, 4.5],
[np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan],
])
weights1 = np.array([
[
[1, 2, 3, 4., np.nan],
[5, np.nan, 7, 8, np.nan],
[9, 10, 11, 12, np.nan],
[13, 14, 15, 16, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan]
],
[
[0, 1, 0, 1, np.nan],
[0, np.nan, 0, 1, np.nan],
[0, 1, 0, 1, np.nan],
[0, 1, 0, 1, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan]
]
])
weights2 = np.array([
[
[1.5, 2.5, 3.5, np.nan, np.nan],
[3.5, 4.5, 5.5, np.nan, np.nan],
[6.5, 7.5, 8.5, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan]
],
[
[1, 0, 1, np.nan, np.nan],
[1, 0, 1, np.nan, np.nan],
[1, 0, 1, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan]
]
])
res = crossover(randkey, nodes1, weights1, nodes2, weights2)
print(*res, sep='\n')
return jnp.where(r > 0.5, g1, g2)