add method 'create_crossover_function' and 'create_distance_function'

This commit is contained in:
wls2002
2023-05-07 22:16:27 +08:00
parent cec40b254f
commit 47bb593a53
7 changed files with 45 additions and 15 deletions

View File

@@ -1,5 +1,5 @@
from .genome import create_initialize_function, expand, expand_single, pop_analysis
from .distance import distance
from .distance import create_distance_function
from .mutate import create_mutate_function
from .forward import create_forward_function
from .crossover import batch_crossover
from .crossover import create_crossover_function

View File

@@ -8,6 +8,13 @@ from jax import numpy as jnp
from .utils import flatten_connections, unflatten_connections
def create_crossover_function(batch: bool):
if batch:
return batch_crossover
else:
return crossover
@vmap
def batch_crossover(randkeys: Array, batch_nodes1: Array, batch_connections1: Array, batch_nodes2: Array,
batch_connections2: Array) -> Tuple[Array, Array]:
@@ -92,4 +99,4 @@ def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
only gene with the same key will be crossover, thus don't need to consider change key
"""
r = jax.random.uniform(rand_key, shape=g1.shape)
return jnp.where(r > 0.5, g1, g2)
return jnp.where(r > 0.5, g1, g2)

View File

@@ -4,21 +4,41 @@ from jax import numpy as jnp
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
def create_distance_function(config, type: str):
"""
:param config:
:param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation
:return:
"""
disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
compatibility_coe = config.neat.genome.compatibility_weight_coefficient
if type == 'o2o':
return lambda nodes1, connections1, nodes2, connections2: \
distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
elif type == 'o2m':
func = vmap(distance, in_axes=(None, None, 0, 0, None, None))
return lambda nodes1, connections1, batch_nodes2, batch_connections2: \
func(nodes1, connections1, batch_nodes2, batch_connections2, disjoint_coe, compatibility_coe)
else:
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')
@jit
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) -> Array:
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array, disjoint_coe: float = 1.,
compatibility_coe: float = 0.5) -> Array:
"""
Calculate the distance between two genomes.
nodes are a 2-d array with shape (N, 5), its columns are [key, bias, response, act, agg]
connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
"""
nd = node_distance(nodes1, nodes2) # node distance
nd = node_distance(nodes1, nodes2, disjoint_coe, compatibility_coe) # node distance
# refactor connections
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
cons1 = flatten_connections(keys1, connections1)
cons2 = flatten_connections(keys2, connections2)
cd = connection_distance(cons1, cons2) # connection distance
cd = connection_distance(cons1, cons2, disjoint_coe, compatibility_coe) # connection distance
return nd + cd

View File

@@ -6,7 +6,7 @@ import numpy as np
from .species import SpeciesController
from .genome import create_initialize_function, create_mutate_function, create_forward_function
from .genome import batch_crossover
from .genome import create_crossover_function
from .genome import expand, expand_single
@@ -27,6 +27,7 @@ class Pipeline:
self.initialize_func = create_initialize_function(config)
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True)
self.crossover_func = create_crossover_function(batch=True)
self.generation = 0
@@ -102,7 +103,7 @@ class Pipeline:
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
# mutate
mutate_rand_keys = jax.random.split(k2, self.pop_size)

View File

@@ -5,7 +5,7 @@ import jax
import numpy as np
from numpy.typing import NDArray
from .genome import distance
from .genome import create_distance_function
class Species(object):
@@ -47,8 +47,8 @@ class SpeciesController:
self.species_idxer = count(0)
self.species: Dict[int, Species] = {} # species_id -> species
self.distance = distance
self.o2m_distance = jax.vmap(distance, in_axes=(None, None, 0, 0))
self.o2o_distance = create_distance_function(self.config, type='o2o')
self.o2m_distance = create_distance_function(self.config, type='o2m')
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None:
"""
@@ -106,7 +106,7 @@ class SpeciesController:
# the representatives of new species
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
distances = [
self.distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
self.o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
for r in rid
]
distances = np.array(distances)

View File

@@ -1,6 +1,7 @@
from typing import Callable, List
from functools import partial
import jax
import numpy as np
from utils import Configer
@@ -17,12 +18,13 @@ def evaluate(forward_func: Callable) -> List[float]:
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses.tolist() # returns a list
# @using_cprofile
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
@using_cprofile
# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
def main():
config = Configer.load_config()
pipeline = Pipeline(config, seed=11323)

View File

@@ -9,7 +9,7 @@
"population": {
"fitness_criterion": "max",
"fitness_threshold": 76,
"generation_limit": 1000,
"generation_limit": 100,
"pop_size": 100,
"reset_on_extinction": "False"
},