add method 'create_crossover_function' and 'create_distance_function'
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from .genome import create_initialize_function, expand, expand_single, pop_analysis
|
||||
from .distance import distance
|
||||
from .distance import create_distance_function
|
||||
from .mutate import create_mutate_function
|
||||
from .forward import create_forward_function
|
||||
from .crossover import batch_crossover
|
||||
from .crossover import create_crossover_function
|
||||
|
||||
@@ -8,6 +8,13 @@ from jax import numpy as jnp
|
||||
from .utils import flatten_connections, unflatten_connections
|
||||
|
||||
|
||||
def create_crossover_function(batch: bool):
|
||||
if batch:
|
||||
return batch_crossover
|
||||
else:
|
||||
return crossover
|
||||
|
||||
|
||||
@vmap
|
||||
def batch_crossover(randkeys: Array, batch_nodes1: Array, batch_connections1: Array, batch_nodes2: Array,
|
||||
batch_connections2: Array) -> Tuple[Array, Array]:
|
||||
|
||||
@@ -4,21 +4,41 @@ from jax import numpy as jnp
|
||||
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
|
||||
|
||||
|
||||
def create_distance_function(config, type: str):
|
||||
"""
|
||||
:param config:
|
||||
:param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation
|
||||
:return:
|
||||
"""
|
||||
disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
|
||||
compatibility_coe = config.neat.genome.compatibility_weight_coefficient
|
||||
if type == 'o2o':
|
||||
return lambda nodes1, connections1, nodes2, connections2: \
|
||||
distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
||||
elif type == 'o2m':
|
||||
func = vmap(distance, in_axes=(None, None, 0, 0, None, None))
|
||||
return lambda nodes1, connections1, batch_nodes2, batch_connections2: \
|
||||
func(nodes1, connections1, batch_nodes2, batch_connections2, disjoint_coe, compatibility_coe)
|
||||
else:
|
||||
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')
|
||||
|
||||
|
||||
@jit
|
||||
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) -> Array:
|
||||
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array, disjoint_coe: float = 1.,
|
||||
compatibility_coe: float = 0.5) -> Array:
|
||||
"""
|
||||
Calculate the distance between two genomes.
|
||||
nodes are a 2-d array with shape (N, 5), its columns are [key, bias, response, act, agg]
|
||||
connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
|
||||
"""
|
||||
|
||||
nd = node_distance(nodes1, nodes2) # node distance
|
||||
nd = node_distance(nodes1, nodes2, disjoint_coe, compatibility_coe) # node distance
|
||||
|
||||
# refactor connections
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
cons1 = flatten_connections(keys1, connections1)
|
||||
cons2 = flatten_connections(keys2, connections2)
|
||||
cd = connection_distance(cons1, cons2) # connection distance
|
||||
cd = connection_distance(cons1, cons2, disjoint_coe, compatibility_coe) # connection distance
|
||||
return nd + cd
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import numpy as np
|
||||
|
||||
from .species import SpeciesController
|
||||
from .genome import create_initialize_function, create_mutate_function, create_forward_function
|
||||
from .genome import batch_crossover
|
||||
from .genome import create_crossover_function
|
||||
from .genome import expand, expand_single
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ class Pipeline:
|
||||
self.initialize_func = create_initialize_function(config)
|
||||
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
|
||||
self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True)
|
||||
self.crossover_func = create_crossover_function(batch=True)
|
||||
|
||||
self.generation = 0
|
||||
|
||||
@@ -102,7 +103,7 @@ class Pipeline:
|
||||
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
|
||||
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
|
||||
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
|
||||
npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
||||
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
||||
|
||||
# mutate
|
||||
mutate_rand_keys = jax.random.split(k2, self.pop_size)
|
||||
|
||||
@@ -5,7 +5,7 @@ import jax
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from .genome import distance
|
||||
from .genome import create_distance_function
|
||||
|
||||
|
||||
class Species(object):
|
||||
@@ -47,8 +47,8 @@ class SpeciesController:
|
||||
self.species_idxer = count(0)
|
||||
self.species: Dict[int, Species] = {} # species_id -> species
|
||||
|
||||
self.distance = distance
|
||||
self.o2m_distance = jax.vmap(distance, in_axes=(None, None, 0, 0))
|
||||
self.o2o_distance = create_distance_function(self.config, type='o2o')
|
||||
self.o2m_distance = create_distance_function(self.config, type='o2m')
|
||||
|
||||
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None:
|
||||
"""
|
||||
@@ -106,7 +106,7 @@ class SpeciesController:
|
||||
# the representatives of new species
|
||||
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
|
||||
distances = [
|
||||
self.distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
|
||||
self.o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
|
||||
for r in rid
|
||||
]
|
||||
distances = np.array(distances)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Callable, List
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from utils import Configer
|
||||
@@ -17,12 +18,13 @@ def evaluate(forward_func: Callable) -> List[float]:
|
||||
:return:
|
||||
"""
|
||||
outs = forward_func(xor_inputs)
|
||||
outs = jax.device_get(outs)
|
||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||
return fitnesses.tolist() # returns a list
|
||||
|
||||
|
||||
# @using_cprofile
|
||||
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
||||
@using_cprofile
|
||||
# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
||||
def main():
|
||||
config = Configer.load_config()
|
||||
pipeline = Pipeline(config, seed=11323)
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
"population": {
|
||||
"fitness_criterion": "max",
|
||||
"fitness_threshold": 76,
|
||||
"generation_limit": 1000,
|
||||
"generation_limit": 100,
|
||||
"pop_size": 100,
|
||||
"reset_on_extinction": "False"
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user