add method 'create_crossover_function' and 'create_distance_function'
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
from .genome import create_initialize_function, expand, expand_single, pop_analysis
|
from .genome import create_initialize_function, expand, expand_single, pop_analysis
|
||||||
from .distance import distance
|
from .distance import create_distance_function
|
||||||
from .mutate import create_mutate_function
|
from .mutate import create_mutate_function
|
||||||
from .forward import create_forward_function
|
from .forward import create_forward_function
|
||||||
from .crossover import batch_crossover
|
from .crossover import create_crossover_function
|
||||||
|
|||||||
@@ -8,6 +8,13 @@ from jax import numpy as jnp
|
|||||||
from .utils import flatten_connections, unflatten_connections
|
from .utils import flatten_connections, unflatten_connections
|
||||||
|
|
||||||
|
|
||||||
|
def create_crossover_function(batch: bool):
|
||||||
|
if batch:
|
||||||
|
return batch_crossover
|
||||||
|
else:
|
||||||
|
return crossover
|
||||||
|
|
||||||
|
|
||||||
@vmap
|
@vmap
|
||||||
def batch_crossover(randkeys: Array, batch_nodes1: Array, batch_connections1: Array, batch_nodes2: Array,
|
def batch_crossover(randkeys: Array, batch_nodes1: Array, batch_connections1: Array, batch_nodes2: Array,
|
||||||
batch_connections2: Array) -> Tuple[Array, Array]:
|
batch_connections2: Array) -> Tuple[Array, Array]:
|
||||||
@@ -92,4 +99,4 @@ def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
|||||||
only gene with the same key will be crossover, thus don't need to consider change key
|
only gene with the same key will be crossover, thus don't need to consider change key
|
||||||
"""
|
"""
|
||||||
r = jax.random.uniform(rand_key, shape=g1.shape)
|
r = jax.random.uniform(rand_key, shape=g1.shape)
|
||||||
return jnp.where(r > 0.5, g1, g2)
|
return jnp.where(r > 0.5, g1, g2)
|
||||||
|
|||||||
@@ -4,21 +4,41 @@ from jax import numpy as jnp
|
|||||||
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
|
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
|
||||||
|
|
||||||
|
|
||||||
|
def create_distance_function(config, type: str):
|
||||||
|
"""
|
||||||
|
:param config:
|
||||||
|
:param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
|
||||||
|
compatibility_coe = config.neat.genome.compatibility_weight_coefficient
|
||||||
|
if type == 'o2o':
|
||||||
|
return lambda nodes1, connections1, nodes2, connections2: \
|
||||||
|
distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
||||||
|
elif type == 'o2m':
|
||||||
|
func = vmap(distance, in_axes=(None, None, 0, 0, None, None))
|
||||||
|
return lambda nodes1, connections1, batch_nodes2, batch_connections2: \
|
||||||
|
func(nodes1, connections1, batch_nodes2, batch_connections2, disjoint_coe, compatibility_coe)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) -> Array:
|
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array, disjoint_coe: float = 1.,
|
||||||
|
compatibility_coe: float = 0.5) -> Array:
|
||||||
"""
|
"""
|
||||||
Calculate the distance between two genomes.
|
Calculate the distance between two genomes.
|
||||||
nodes are a 2-d array with shape (N, 5), its columns are [key, bias, response, act, agg]
|
nodes are a 2-d array with shape (N, 5), its columns are [key, bias, response, act, agg]
|
||||||
connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
|
connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
nd = node_distance(nodes1, nodes2) # node distance
|
nd = node_distance(nodes1, nodes2, disjoint_coe, compatibility_coe) # node distance
|
||||||
|
|
||||||
# refactor connections
|
# refactor connections
|
||||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||||
cons1 = flatten_connections(keys1, connections1)
|
cons1 = flatten_connections(keys1, connections1)
|
||||||
cons2 = flatten_connections(keys2, connections2)
|
cons2 = flatten_connections(keys2, connections2)
|
||||||
cd = connection_distance(cons1, cons2) # connection distance
|
cd = connection_distance(cons1, cons2, disjoint_coe, compatibility_coe) # connection distance
|
||||||
return nd + cd
|
return nd + cd
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import numpy as np
|
|||||||
|
|
||||||
from .species import SpeciesController
|
from .species import SpeciesController
|
||||||
from .genome import create_initialize_function, create_mutate_function, create_forward_function
|
from .genome import create_initialize_function, create_mutate_function, create_forward_function
|
||||||
from .genome import batch_crossover
|
from .genome import create_crossover_function
|
||||||
from .genome import expand, expand_single
|
from .genome import expand, expand_single
|
||||||
|
|
||||||
|
|
||||||
@@ -27,6 +27,7 @@ class Pipeline:
|
|||||||
self.initialize_func = create_initialize_function(config)
|
self.initialize_func = create_initialize_function(config)
|
||||||
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
|
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
|
||||||
self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True)
|
self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True)
|
||||||
|
self.crossover_func = create_crossover_function(batch=True)
|
||||||
|
|
||||||
self.generation = 0
|
self.generation = 0
|
||||||
|
|
||||||
@@ -102,7 +103,7 @@ class Pipeline:
|
|||||||
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
|
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
|
||||||
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
|
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
|
||||||
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
|
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
|
||||||
npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
||||||
|
|
||||||
# mutate
|
# mutate
|
||||||
mutate_rand_keys = jax.random.split(k2, self.pop_size)
|
mutate_rand_keys = jax.random.split(k2, self.pop_size)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import jax
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from .genome import distance
|
from .genome import create_distance_function
|
||||||
|
|
||||||
|
|
||||||
class Species(object):
|
class Species(object):
|
||||||
@@ -47,8 +47,8 @@ class SpeciesController:
|
|||||||
self.species_idxer = count(0)
|
self.species_idxer = count(0)
|
||||||
self.species: Dict[int, Species] = {} # species_id -> species
|
self.species: Dict[int, Species] = {} # species_id -> species
|
||||||
|
|
||||||
self.distance = distance
|
self.o2o_distance = create_distance_function(self.config, type='o2o')
|
||||||
self.o2m_distance = jax.vmap(distance, in_axes=(None, None, 0, 0))
|
self.o2m_distance = create_distance_function(self.config, type='o2m')
|
||||||
|
|
||||||
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None:
|
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -106,7 +106,7 @@ class SpeciesController:
|
|||||||
# the representatives of new species
|
# the representatives of new species
|
||||||
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
|
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
|
||||||
distances = [
|
distances = [
|
||||||
self.distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
|
self.o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
|
||||||
for r in rid
|
for r in rid
|
||||||
]
|
]
|
||||||
distances = np.array(distances)
|
distances = np.array(distances)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from typing import Callable, List
|
from typing import Callable, List
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
import jax
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from utils import Configer
|
from utils import Configer
|
||||||
@@ -17,12 +18,13 @@ def evaluate(forward_func: Callable) -> List[float]:
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
outs = forward_func(xor_inputs)
|
outs = forward_func(xor_inputs)
|
||||||
|
outs = jax.device_get(outs)
|
||||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||||
return fitnesses.tolist() # returns a list
|
return fitnesses.tolist() # returns a list
|
||||||
|
|
||||||
|
|
||||||
# @using_cprofile
|
@using_cprofile
|
||||||
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
||||||
def main():
|
def main():
|
||||||
config = Configer.load_config()
|
config = Configer.load_config()
|
||||||
pipeline = Pipeline(config, seed=11323)
|
pipeline = Pipeline(config, seed=11323)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
"population": {
|
"population": {
|
||||||
"fitness_criterion": "max",
|
"fitness_criterion": "max",
|
||||||
"fitness_threshold": 76,
|
"fitness_threshold": 76,
|
||||||
"generation_limit": 1000,
|
"generation_limit": 100,
|
||||||
"pop_size": 100,
|
"pop_size": 100,
|
||||||
"reset_on_extinction": "False"
|
"reset_on_extinction": "False"
|
||||||
},
|
},
|
||||||
|
|||||||
Reference in New Issue
Block a user