add method 'create_crossover_function' and 'create_distance_function'
This commit is contained in:
@@ -8,6 +8,13 @@ from jax import numpy as jnp
|
||||
from .utils import flatten_connections, unflatten_connections
|
||||
|
||||
|
||||
def create_crossover_function(batch: bool):
|
||||
if batch:
|
||||
return batch_crossover
|
||||
else:
|
||||
return crossover
|
||||
|
||||
|
||||
@vmap
|
||||
def batch_crossover(randkeys: Array, batch_nodes1: Array, batch_connections1: Array, batch_nodes2: Array,
|
||||
batch_connections2: Array) -> Tuple[Array, Array]:
|
||||
@@ -92,4 +99,4 @@ def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
||||
only gene with the same key will be crossover, thus don't need to consider change key
|
||||
"""
|
||||
r = jax.random.uniform(rand_key, shape=g1.shape)
|
||||
return jnp.where(r > 0.5, g1, g2)
|
||||
return jnp.where(r > 0.5, g1, g2)
|
||||
|
||||
Reference in New Issue
Block a user