add method 'create_crossover_function' and 'create_distance_function'

This commit is contained in:
wls2002
2023-05-07 22:16:27 +08:00
parent cec40b254f
commit 47bb593a53
7 changed files with 45 additions and 15 deletions

View File

@@ -1,5 +1,5 @@
from .genome import create_initialize_function, expand, expand_single, pop_analysis
from .distance import distance
from .distance import create_distance_function
from .mutate import create_mutate_function
from .forward import create_forward_function
from .crossover import batch_crossover
from .crossover import create_crossover_function

View File

@@ -8,6 +8,13 @@ from jax import numpy as jnp
from .utils import flatten_connections, unflatten_connections
def create_crossover_function(batch: bool):
if batch:
return batch_crossover
else:
return crossover
@vmap
def batch_crossover(randkeys: Array, batch_nodes1: Array, batch_connections1: Array, batch_nodes2: Array,
batch_connections2: Array) -> Tuple[Array, Array]:
@@ -92,4 +99,4 @@ def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
only gene with the same key will be crossover, thus don't need to consider change key
"""
r = jax.random.uniform(rand_key, shape=g1.shape)
return jnp.where(r > 0.5, g1, g2)
return jnp.where(r > 0.5, g1, g2)

View File

@@ -4,21 +4,41 @@ from jax import numpy as jnp
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
def create_distance_function(config, type: str):
"""
:param config:
:param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation
:return:
"""
disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
compatibility_coe = config.neat.genome.compatibility_weight_coefficient
if type == 'o2o':
return lambda nodes1, connections1, nodes2, connections2: \
distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
elif type == 'o2m':
func = vmap(distance, in_axes=(None, None, 0, 0, None, None))
return lambda nodes1, connections1, batch_nodes2, batch_connections2: \
func(nodes1, connections1, batch_nodes2, batch_connections2, disjoint_coe, compatibility_coe)
else:
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')
@jit
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) -> Array:
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array, disjoint_coe: float = 1.,
compatibility_coe: float = 0.5) -> Array:
"""
Calculate the distance between two genomes.
nodes are a 2-d array with shape (N, 5), its columns are [key, bias, response, act, agg]
connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
"""
nd = node_distance(nodes1, nodes2) # node distance
nd = node_distance(nodes1, nodes2, disjoint_coe, compatibility_coe) # node distance
# refactor connections
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
cons1 = flatten_connections(keys1, connections1)
cons2 = flatten_connections(keys2, connections2)
cd = connection_distance(cons1, cons2) # connection distance
cd = connection_distance(cons1, cons2, disjoint_coe, compatibility_coe) # connection distance
return nd + cd