diff --git a/algorithms/neat/genome/__init__.py b/algorithms/neat/genome/__init__.py index 232ac8c..a588d97 100644 --- a/algorithms/neat/genome/__init__.py +++ b/algorithms/neat/genome/__init__.py @@ -1,5 +1,5 @@ from .genome import create_initialize_function, expand, expand_single, pop_analysis -from .distance import distance +from .distance import create_distance_function from .mutate import create_mutate_function from .forward import create_forward_function -from .crossover import batch_crossover +from .crossover import create_crossover_function diff --git a/algorithms/neat/genome/crossover.py b/algorithms/neat/genome/crossover.py index f147c58..7dca803 100644 --- a/algorithms/neat/genome/crossover.py +++ b/algorithms/neat/genome/crossover.py @@ -8,6 +8,13 @@ from jax import numpy as jnp from .utils import flatten_connections, unflatten_connections +def create_crossover_function(batch: bool): + if batch: + return batch_crossover + else: + return crossover + + @vmap def batch_crossover(randkeys: Array, batch_nodes1: Array, batch_connections1: Array, batch_nodes2: Array, batch_connections2: Array) -> Tuple[Array, Array]: @@ -92,4 +99,4 @@ def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array: only gene with the same key will be crossover, thus don't need to consider change key """ r = jax.random.uniform(rand_key, shape=g1.shape) - return jnp.where(r > 0.5, g1, g2) \ No newline at end of file + return jnp.where(r > 0.5, g1, g2) diff --git a/algorithms/neat/genome/distance.py b/algorithms/neat/genome/distance.py index 33a4aa6..c78cd7e 100644 --- a/algorithms/neat/genome/distance.py +++ b/algorithms/neat/genome/distance.py @@ -4,21 +4,41 @@ from jax import numpy as jnp from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON +def create_distance_function(config, type: str): + """ + :param config: + :param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation + :return: + """ + disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient + compatibility_coe = config.neat.genome.compatibility_weight_coefficient + if type == 'o2o': + return lambda nodes1, connections1, nodes2, connections2: \ + distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe) + elif type == 'o2m': + func = vmap(distance, in_axes=(None, None, 0, 0, None, None)) + return lambda nodes1, connections1, batch_nodes2, batch_connections2: \ + func(nodes1, connections1, batch_nodes2, batch_connections2, disjoint_coe, compatibility_coe) + else: + raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]') + + @jit -def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) -> Array: +def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array, disjoint_coe: float = 1., + compatibility_coe: float = 0.5) -> Array: """ Calculate the distance between two genomes. nodes are a 2-d array with shape (N, 5), its columns are [key, bias, response, act, agg] connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable] """ - nd = node_distance(nodes1, nodes2) # node distance + nd = node_distance(nodes1, nodes2, disjoint_coe, compatibility_coe) # node distance # refactor connections keys1, keys2 = nodes1[:, 0], nodes2[:, 0] cons1 = flatten_connections(keys1, connections1) cons2 = flatten_connections(keys2, connections2) - cd = connection_distance(cons1, cons2) # connection distance + cd = connection_distance(cons1, cons2, disjoint_coe, compatibility_coe) # connection distance return nd + cd diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index 4497bad..7223f12 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -6,7 +6,7 @@ import numpy as np from .species import SpeciesController from .genome import create_initialize_function, create_mutate_function, create_forward_function -from .genome import batch_crossover +from .genome import create_crossover_function from .genome import expand, expand_single @@ -27,6 +27,7 @@ class Pipeline: self.initialize_func = create_initialize_function(config) self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func() self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True) + self.crossover_func = create_crossover_function(batch=True) self.generation = 0 @@ -102,7 +103,7 @@ class Pipeline: wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections - npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections + npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections # mutate mutate_rand_keys = jax.random.split(k2, self.pop_size) diff --git a/algorithms/neat/species.py b/algorithms/neat/species.py index fe48a11..75d0d0d 100644 --- a/algorithms/neat/species.py +++ b/algorithms/neat/species.py @@ -5,7 +5,7 @@ import jax import numpy as np from numpy.typing import NDArray -from .genome import distance +from .genome import create_distance_function class Species(object): @@ -47,8 +47,8 @@ class SpeciesController: self.species_idxer = count(0) self.species: Dict[int, Species] = {} # species_id -> species - self.distance = distance - self.o2m_distance = jax.vmap(distance, in_axes=(None, None, 0, 0)) + self.o2o_distance = create_distance_function(self.config, type='o2o') + self.o2m_distance = create_distance_function(self.config, type='o2m') def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None: """ @@ -106,7 +106,7 @@ class SpeciesController: # the representatives of new species sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()])) distances = [ - self.distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]) + self.o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]) for r in rid ] distances = np.array(distances) diff --git a/examples/xor.py b/examples/xor.py index 72212a8..e8ac80c 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -1,6 +1,7 @@ from typing import Callable, List from functools import partial +import jax import numpy as np from utils import Configer @@ -17,12 +18,13 @@ def evaluate(forward_func: Callable) -> List[float]: :return: """ outs = forward_func(xor_inputs) + outs = jax.device_get(outs) fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) return fitnesses.tolist() # returns a list -# @using_cprofile -@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") +@using_cprofile +# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") def main(): config = Configer.load_config() pipeline = Pipeline(config, seed=11323) diff --git a/utils/default_config.json b/utils/default_config.json index 16cbc08..7b14361 100644 --- a/utils/default_config.json +++ b/utils/default_config.json @@ -9,7 +9,7 @@ "population": { "fitness_criterion": "max", "fitness_threshold": 76, - "generation_limit": 1000, + "generation_limit": 100, "pop_size": 100, "reset_on_extinction": "False" },