bug down! Here it can solve xor successfully!
This commit is contained in:
@@ -5,8 +5,7 @@ import jax
|
||||
from jax import jit, vmap, Array
|
||||
from jax import numpy as jnp
|
||||
|
||||
# from .utils import flatten_connections, unflatten_connections
|
||||
from algorithms.neat.genome.utils import flatten_connections, unflatten_connections
|
||||
from .utils import flatten_connections, unflatten_connections
|
||||
|
||||
|
||||
@vmap
|
||||
@@ -93,59 +92,4 @@ def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
||||
only gene with the same key will be crossover, thus don't need to consider change key
|
||||
"""
|
||||
r = jax.random.uniform(rand_key, shape=g1.shape)
|
||||
return jnp.where(r > 0.5, g1, g2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import numpy as np
|
||||
|
||||
randkey = jax.random.PRNGKey(40)
|
||||
nodes1 = np.array([
|
||||
[4, 1, 1, 1, 1],
|
||||
[6, 2, 2, 2, 2],
|
||||
[1, 3, 3, 3, 3],
|
||||
[5, 4, 4, 4, 4],
|
||||
[np.nan, np.nan, np.nan, np.nan, np.nan]
|
||||
])
|
||||
nodes2 = np.array([
|
||||
[4, 1.5, 1.5, 1.5, 1.5],
|
||||
[7, 3.5, 3.5, 3.5, 3.5],
|
||||
[5, 4.5, 4.5, 4.5, 4.5],
|
||||
[np.nan, np.nan, np.nan, np.nan, np.nan],
|
||||
[np.nan, np.nan, np.nan, np.nan, np.nan],
|
||||
])
|
||||
weights1 = np.array([
|
||||
[
|
||||
[1, 2, 3, 4., np.nan],
|
||||
[5, np.nan, 7, 8, np.nan],
|
||||
[9, 10, 11, 12, np.nan],
|
||||
[13, 14, 15, 16, np.nan],
|
||||
[np.nan, np.nan, np.nan, np.nan, np.nan]
|
||||
],
|
||||
[
|
||||
[0, 1, 0, 1, np.nan],
|
||||
[0, np.nan, 0, 1, np.nan],
|
||||
[0, 1, 0, 1, np.nan],
|
||||
[0, 1, 0, 1, np.nan],
|
||||
[np.nan, np.nan, np.nan, np.nan, np.nan]
|
||||
]
|
||||
])
|
||||
weights2 = np.array([
|
||||
[
|
||||
[1.5, 2.5, 3.5, np.nan, np.nan],
|
||||
[3.5, 4.5, 5.5, np.nan, np.nan],
|
||||
[6.5, 7.5, 8.5, np.nan, np.nan],
|
||||
[np.nan, np.nan, np.nan, np.nan, np.nan],
|
||||
[np.nan, np.nan, np.nan, np.nan, np.nan]
|
||||
],
|
||||
[
|
||||
[1, 0, 1, np.nan, np.nan],
|
||||
[1, 0, 1, np.nan, np.nan],
|
||||
[1, 0, 1, np.nan, np.nan],
|
||||
[np.nan, np.nan, np.nan, np.nan, np.nan],
|
||||
[np.nan, np.nan, np.nan, np.nan, np.nan]
|
||||
]
|
||||
])
|
||||
|
||||
res = crossover(randkey, nodes1, weights1, nodes2, weights2)
|
||||
print(*res, sep='\n')
|
||||
return jnp.where(r > 0.5, g1, g2)
|
||||
Reference in New Issue
Block a user