add method 'create_crossover_function' and 'create_distance_function'
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from .genome import create_initialize_function, expand, expand_single, pop_analysis
|
||||
from .distance import distance
|
||||
from .distance import create_distance_function
|
||||
from .mutate import create_mutate_function
|
||||
from .forward import create_forward_function
|
||||
from .crossover import batch_crossover
|
||||
from .crossover import create_crossover_function
|
||||
|
||||
@@ -8,6 +8,13 @@ from jax import numpy as jnp
|
||||
from .utils import flatten_connections, unflatten_connections
|
||||
|
||||
|
||||
def create_crossover_function(batch: bool):
|
||||
if batch:
|
||||
return batch_crossover
|
||||
else:
|
||||
return crossover
|
||||
|
||||
|
||||
@vmap
|
||||
def batch_crossover(randkeys: Array, batch_nodes1: Array, batch_connections1: Array, batch_nodes2: Array,
|
||||
batch_connections2: Array) -> Tuple[Array, Array]:
|
||||
@@ -92,4 +99,4 @@ def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
||||
only gene with the same key will be crossover, thus don't need to consider change key
|
||||
"""
|
||||
r = jax.random.uniform(rand_key, shape=g1.shape)
|
||||
return jnp.where(r > 0.5, g1, g2)
|
||||
return jnp.where(r > 0.5, g1, g2)
|
||||
|
||||
@@ -4,21 +4,41 @@ from jax import numpy as jnp
|
||||
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
|
||||
|
||||
|
||||
def create_distance_function(config, type: str):
|
||||
"""
|
||||
:param config:
|
||||
:param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation
|
||||
:return:
|
||||
"""
|
||||
disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
|
||||
compatibility_coe = config.neat.genome.compatibility_weight_coefficient
|
||||
if type == 'o2o':
|
||||
return lambda nodes1, connections1, nodes2, connections2: \
|
||||
distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
||||
elif type == 'o2m':
|
||||
func = vmap(distance, in_axes=(None, None, 0, 0, None, None))
|
||||
return lambda nodes1, connections1, batch_nodes2, batch_connections2: \
|
||||
func(nodes1, connections1, batch_nodes2, batch_connections2, disjoint_coe, compatibility_coe)
|
||||
else:
|
||||
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')
|
||||
|
||||
|
||||
@jit
|
||||
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) -> Array:
|
||||
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array, disjoint_coe: float = 1.,
|
||||
compatibility_coe: float = 0.5) -> Array:
|
||||
"""
|
||||
Calculate the distance between two genomes.
|
||||
nodes are a 2-d array with shape (N, 5), its columns are [key, bias, response, act, agg]
|
||||
connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
|
||||
"""
|
||||
|
||||
nd = node_distance(nodes1, nodes2) # node distance
|
||||
nd = node_distance(nodes1, nodes2, disjoint_coe, compatibility_coe) # node distance
|
||||
|
||||
# refactor connections
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
cons1 = flatten_connections(keys1, connections1)
|
||||
cons2 = flatten_connections(keys2, connections2)
|
||||
cd = connection_distance(cons1, cons2) # connection distance
|
||||
cd = connection_distance(cons1, cons2, disjoint_coe, compatibility_coe) # connection distance
|
||||
return nd + cd
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user