add method 'create_crossover_function' and 'create_distance_function'

This commit is contained in:
wls2002
2023-05-07 22:16:27 +08:00
parent cec40b254f
commit 47bb593a53
7 changed files with 45 additions and 15 deletions

View File

@@ -1,5 +1,5 @@
from .genome import create_initialize_function, expand, expand_single, pop_analysis from .genome import create_initialize_function, expand, expand_single, pop_analysis
from .distance import distance from .distance import create_distance_function
from .mutate import create_mutate_function from .mutate import create_mutate_function
from .forward import create_forward_function from .forward import create_forward_function
from .crossover import batch_crossover from .crossover import create_crossover_function

View File

@@ -8,6 +8,13 @@ from jax import numpy as jnp
from .utils import flatten_connections, unflatten_connections from .utils import flatten_connections, unflatten_connections
def create_crossover_function(batch: bool):
if batch:
return batch_crossover
else:
return crossover
@vmap @vmap
def batch_crossover(randkeys: Array, batch_nodes1: Array, batch_connections1: Array, batch_nodes2: Array, def batch_crossover(randkeys: Array, batch_nodes1: Array, batch_connections1: Array, batch_nodes2: Array,
batch_connections2: Array) -> Tuple[Array, Array]: batch_connections2: Array) -> Tuple[Array, Array]:
@@ -92,4 +99,4 @@ def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
only gene with the same key will be crossover, thus don't need to consider change key only gene with the same key will be crossover, thus don't need to consider change key
""" """
r = jax.random.uniform(rand_key, shape=g1.shape) r = jax.random.uniform(rand_key, shape=g1.shape)
return jnp.where(r > 0.5, g1, g2) return jnp.where(r > 0.5, g1, g2)

View File

@@ -4,21 +4,41 @@ from jax import numpy as jnp
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
def create_distance_function(config, type: str):
"""
:param config:
:param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation
:return:
"""
disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
compatibility_coe = config.neat.genome.compatibility_weight_coefficient
if type == 'o2o':
return lambda nodes1, connections1, nodes2, connections2: \
distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
elif type == 'o2m':
func = vmap(distance, in_axes=(None, None, 0, 0, None, None))
return lambda nodes1, connections1, batch_nodes2, batch_connections2: \
func(nodes1, connections1, batch_nodes2, batch_connections2, disjoint_coe, compatibility_coe)
else:
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')
@jit @jit
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) -> Array: def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array, disjoint_coe: float = 1.,
compatibility_coe: float = 0.5) -> Array:
""" """
Calculate the distance between two genomes. Calculate the distance between two genomes.
nodes are a 2-d array with shape (N, 5), its columns are [key, bias, response, act, agg] nodes are a 2-d array with shape (N, 5), its columns are [key, bias, response, act, agg]
connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable] connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
""" """
nd = node_distance(nodes1, nodes2) # node distance nd = node_distance(nodes1, nodes2, disjoint_coe, compatibility_coe) # node distance
# refactor connections # refactor connections
keys1, keys2 = nodes1[:, 0], nodes2[:, 0] keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
cons1 = flatten_connections(keys1, connections1) cons1 = flatten_connections(keys1, connections1)
cons2 = flatten_connections(keys2, connections2) cons2 = flatten_connections(keys2, connections2)
cd = connection_distance(cons1, cons2) # connection distance cd = connection_distance(cons1, cons2, disjoint_coe, compatibility_coe) # connection distance
return nd + cd return nd + cd

View File

@@ -6,7 +6,7 @@ import numpy as np
from .species import SpeciesController from .species import SpeciesController
from .genome import create_initialize_function, create_mutate_function, create_forward_function from .genome import create_initialize_function, create_mutate_function, create_forward_function
from .genome import batch_crossover from .genome import create_crossover_function
from .genome import expand, expand_single from .genome import expand, expand_single
@@ -27,6 +27,7 @@ class Pipeline:
self.initialize_func = create_initialize_function(config) self.initialize_func = create_initialize_function(config)
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func() self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True) self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True)
self.crossover_func = create_crossover_function(batch=True)
self.generation = 0 self.generation = 0
@@ -102,7 +103,7 @@ class Pipeline:
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
# mutate # mutate
mutate_rand_keys = jax.random.split(k2, self.pop_size) mutate_rand_keys = jax.random.split(k2, self.pop_size)

View File

@@ -5,7 +5,7 @@ import jax
import numpy as np import numpy as np
from numpy.typing import NDArray from numpy.typing import NDArray
from .genome import distance from .genome import create_distance_function
class Species(object): class Species(object):
@@ -47,8 +47,8 @@ class SpeciesController:
self.species_idxer = count(0) self.species_idxer = count(0)
self.species: Dict[int, Species] = {} # species_id -> species self.species: Dict[int, Species] = {} # species_id -> species
self.distance = distance self.o2o_distance = create_distance_function(self.config, type='o2o')
self.o2m_distance = jax.vmap(distance, in_axes=(None, None, 0, 0)) self.o2m_distance = create_distance_function(self.config, type='o2m')
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None: def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None:
""" """
@@ -106,7 +106,7 @@ class SpeciesController:
# the representatives of new species # the representatives of new species
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()])) sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
distances = [ distances = [
self.distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]) self.o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
for r in rid for r in rid
] ]
distances = np.array(distances) distances = np.array(distances)

View File

@@ -1,6 +1,7 @@
from typing import Callable, List from typing import Callable, List
from functools import partial from functools import partial
import jax
import numpy as np import numpy as np
from utils import Configer from utils import Configer
@@ -17,12 +18,13 @@ def evaluate(forward_func: Callable) -> List[float]:
:return: :return:
""" """
outs = forward_func(xor_inputs) outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses.tolist() # returns a list return fitnesses.tolist() # returns a list
# @using_cprofile @using_cprofile
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") # @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
def main(): def main():
config = Configer.load_config() config = Configer.load_config()
pipeline = Pipeline(config, seed=11323) pipeline = Pipeline(config, seed=11323)

View File

@@ -9,7 +9,7 @@
"population": { "population": {
"fitness_criterion": "max", "fitness_criterion": "max",
"fitness_threshold": 76, "fitness_threshold": 76,
"generation_limit": 1000, "generation_limit": 100,
"pop_size": 100, "pop_size": 100,
"reset_on_extinction": "False" "reset_on_extinction": "False"
}, },