Generally complete, but does not work well yet; debugging in progress.

This commit is contained in:
wls2002
2023-05-06 11:35:44 +08:00
parent 8f780b63d6
commit 73ac1bcfe0
8 changed files with 233 additions and 84 deletions

View File

@@ -1,4 +1,5 @@
from .genome import create_initialize_function from .genome import create_initialize_function, expand, expand_single
from .distance import distance from .distance import distance
from .mutate import create_mutate_function from .mutate import create_mutate_function
from .forward import create_forward_function from .forward import create_forward_function
from .crossover import batch_crossover

View File

@@ -99,8 +99,8 @@ def initialize_genomes(pop_size: int,
def expand(pop_nodes: NDArray, pop_connections: NDArray, new_N: int) -> Tuple[NDArray, NDArray]: def expand(pop_nodes: NDArray, pop_connections: NDArray, new_N: int) -> Tuple[NDArray, NDArray]:
""" """
Expand the genome to accommodate more nodes. Expand the genome to accommodate more nodes.
:param pop_nodes: :param pop_nodes: (pop_size, N, 5)
:param pop_connections: :param pop_connections: (pop_size, 2, N, N)
:param new_N: :param new_N:
:return: :return:
""" """
@@ -114,6 +114,23 @@ def expand(pop_nodes: NDArray, pop_connections: NDArray, new_N: int) -> Tuple[ND
return new_pop_nodes, new_pop_connections return new_pop_nodes, new_pop_connections
def expand_single(nodes: NDArray, connections: NDArray, new_N: int) -> Tuple[NDArray, NDArray]:
    """
    Expand a single genome to accommodate more nodes.

    The existing entries are copied into the leading corner of freshly
    allocated arrays; every newly created slot is filled with NaN.

    :param nodes: (N, 5); the attribute width is taken from the input itself
    :param connections: (2, N, N); the leading dimension is taken from the input itself
    :param new_N: node capacity after the expansion, must be >= N
    :return: (new_nodes, new_connections) with shapes (new_N, 5) and (2, new_N, new_N)
    :raises ValueError: if new_N is smaller than the current capacity N
    """
    old_N, num_attrs = nodes.shape[0], nodes.shape[1]
    if new_N < old_N:
        # The slice assignments below would fail with an opaque broadcasting
        # error; report the actual problem instead (same exception type).
        raise ValueError(f"new_N ({new_N}) must not be smaller than the current size ({old_N})")

    new_nodes = np.full((new_N, num_attrs), np.nan)
    new_nodes[:old_N, :] = nodes

    new_connections = np.full((connections.shape[0], new_N, new_N), np.nan)
    new_connections[:, :old_N, :old_N] = connections
    return new_nodes, new_connections
@jit @jit
def add_node(new_node_key: int, nodes: Array, connections: Array, def add_node(new_node_key: int, nodes: Array, connections: Array,
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]: bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]:

View File

@@ -1,7 +1,13 @@
from typing import List, Union, Tuple, Callable
import time
import jax import jax
import numpy as np
from .species import SpeciesController from .species import SpeciesController
from .genome import create_initialize_function, create_mutate_function, create_forward_function from .genome import create_initialize_function, create_mutate_function, create_forward_function
from .genome import batch_crossover
from .genome import expand, expand_single
class Pipeline: class Pipeline:
@@ -9,9 +15,13 @@ class Pipeline:
Neat algorithm pipeline. Neat algorithm pipeline.
""" """
def __init__(self, config): def __init__(self, config, seed=42):
self.randkey = jax.random.PRNGKey(seed)
self.config = config self.config = config
self.N = config.basic.init_maximum_nodes self.N = config.basic.init_maximum_nodes
self.expand_coe = config.basic.expands_coe
self.pop_size = config.neat.population.pop_size
self.species_controller = SpeciesController(config) self.species_controller = SpeciesController(config)
self.initialize_func = create_initialize_function(config) self.initialize_func = create_initialize_function(config)
@@ -22,6 +32,11 @@ class Pipeline:
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation) self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)
self.new_node_keys_pool: List[int] = [max(self.output_idx) + 1]
self.generation_timestamp = time.time()
self.best_fitness = float('-inf')
def ask(self, batch: bool): def ask(self, batch: bool):
""" """
Create a forward function for the population. Create a forward function for the population.
@@ -35,7 +50,120 @@ class Pipeline:
def tell(self, fitnesses):
    """
    Feed the fitnesses of the current population back and build the next generation.

    Steps, in order: advance the generation counter, update the species
    fitnesses, ask the species controller for the crossover pairs, create the
    next generation from them, re-speciate the new population and finally
    expand the genome arrays if needed.

    :param fitnesses: fitnesses of the current population, passed to the species controller
    """
    self.generation += 1

    self.species_controller.update_species_fitnesses(fitnesses)

    crossover_pair = self.species_controller.reproduce(self.generation)
    self.update_next_generation(crossover_pair)

    self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)

    self.expand()
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
    """
    Run the ask / evaluate / tell loop until the configured generation limit is reached.

    :param fitness_func: called once per generation with the forward function
        returned by ``self.ask(batch=True)``; its result is handed to the
        analysis hook and to ``self.tell``
    :param analysis: ``"default"`` uses ``self.default_analysis``, ``None``
        disables the per-generation analysis, any other value must be a
        callable that receives the fitnesses
    :raises TypeError: if ``analysis`` is neither ``"default"``, ``None`` nor callable
    """
    # Resolve and validate the hook once, before any generation is evaluated,
    # with a real exception so the check also holds under ``python -O``.
    if analysis == "default":
        analysis = self.default_analysis
    elif analysis is not None and not callable(analysis):
        raise TypeError(
            f"analysis must be 'default', None or a callable, got {analysis!r}"
        )

    for _ in range(self.config.neat.population.generation_limit):
        forward_func = self.ask(batch=True)
        fitnesses = fitness_func(forward_func)

        if analysis is not None:
            analysis(fitnesses)

        self.tell(fitnesses)
    print("Generation limit reached!")
def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None:
    """
    Create the next generation by crossover followed by mutation.

    Entries of ``crossover_pair`` that are not tuples are treated as elites:
    they are crossed with themselves and their mutation result is discarded.
    The caller's list is left untouched.

    :param crossover_pair: created from self.reproduce(); one entry per genome
        of the next generation, either an index (elitism) or a
        (winner, loser) index tuple
    """
    assert self.pop_nodes.shape[0] == self.pop_size
    k1, k2, self.randkey = jax.random.split(self.randkey, 3)

    # crossover
    # prepare elitism mask and crossover pair; work on a private list so the
    # argument passed in by the caller is not modified
    elitism_mask = np.full(self.pop_size, False)
    pairs = []
    for i, pair in enumerate(crossover_pair):
        if isinstance(pair, tuple):
            pairs.append(pair)
        else:  # elitism
            elitism_mask[i] = True
            pairs.append((pair, pair))
    pair_idx = np.array(pairs)

    # batch crossover
    wpn = self.pop_nodes[pair_idx[:, 0]]  # winner pop nodes
    wpc = self.pop_connections[pair_idx[:, 0]]  # winner pop connections
    lpn = self.pop_nodes[pair_idx[:, 1]]  # loser pop nodes
    lpc = self.pop_connections[pair_idx[:, 1]]  # loser pop connections
    crossover_rand_keys = jax.random.split(k1, self.pop_size)
    npn, npc = crossover_wrapper(crossover_rand_keys, wpn, wpc, lpn, lpc)  # new pop nodes, new pop connections

    # mutate
    new_node_keys = np.array(self.fetch_new_node_keys())
    mutate_rand_keys = jax.random.split(k2, self.pop_size)
    m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys)
    m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc)

    # elitism don't mutate
    # (pop_size, ) to (pop_size, 1, 1)
    self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
    # (pop_size, ) to (pop_size, 1, 1, 1)
    self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)

    # recycle unused node keys: a key handed to a genome that does not show up
    # among that genome's node keys goes back to the front of the pool
    unused = [
        key
        for nodes, key in zip(self.pop_nodes, new_node_keys)
        if not np.isin(key, nodes[:, 0])
    ]
    self.new_node_keys_pool = unused + self.new_node_keys_pool
def expand(self):
    """
    Expand the population if needed.

    When the maximum node number of the population >= N, the capacity N is
    multiplied by ``self.expand_coe`` and the population arrays as well as the
    representative genome of every species are expanded to the new capacity.

    :return: None
    """
    pop_node_keys = self.pop_nodes[:, :, 0]
    # a node slot is occupied when its key is not NaN
    pop_node_sizes = np.sum(~np.isnan(pop_node_keys), axis=1)
    max_node_size = np.max(pop_node_sizes)
    if max_node_size < self.N:
        return

    # update the capacity first so the message reports the new size
    self.N = int(self.N * self.expand_coe)
    print(f"expand to {self.N}!")

    # inside a method body the bare name `expand` resolves to the module-level
    # genome function, not to this method
    self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N)

    # don't forget to expand representation genome in species;
    # `species` is a dict (species_id -> species), so iterate its values, and a
    # representative is a single genome (nodes, connections), so the
    # single-genome variant is required
    for s in self.species_controller.species.values():
        s.representative = expand_single(*s.representative, self.N)
def fetch_new_node_keys(self):
    """
    Take ``self.pop_size`` node keys out of the pool of unused keys.

    If the pool holds fewer than ``pop_size`` keys it is refilled with
    ``10 * pop_size`` fresh keys first. Fresh keys start above both the largest
    pooled key and the largest key present in ``self.pop_nodes``, so a key that
    is still used by some genome is never handed out again.

    :return: a list with ``pop_size`` node keys
    """
    # if remain unused keys are not enough, create new keys
    if len(self.new_node_keys_pool) < self.pop_size:
        candidates = list(self.new_node_keys_pool)
        keys_in_use = np.asarray(self.pop_nodes)[:, :, 0]
        keys_in_use = keys_in_use[~np.isnan(keys_in_use)]  # NaN marks an empty node slot
        if keys_in_use.size > 0:
            candidates.append(keys_in_use.max())
        highest_key = int(max(candidates)) if candidates else -1

        first_new_key = highest_key + 1
        self.new_node_keys_pool.extend(range(first_new_key, first_new_key + 10 * self.pop_size))

    # fetch keys from pool
    res = self.new_node_keys_pool[:self.pop_size]
    self.new_node_keys_pool = self.new_node_keys_pool[self.pop_size:]
    return res
def default_analysis(self, fitnesses):
    """
    Print a one-line report for the current generation.

    The report contains the generation counter, the maximum, minimum, mean and
    standard deviation of ``fitnesses``, the member count of every species and
    the wall-clock time elapsed since the previous report. The stored
    timestamp is refreshed afterwards.

    :param fitnesses: fitnesses of the current population
    """
    best = max(fitnesses)
    worst = min(fitnesses)
    average = np.mean(fitnesses)
    spread = np.std(fitnesses)

    species_sizes = []
    for species in self.species_controller.species.values():
        species_sizes.append(len(species.members))

    # time spent since the last report, then restart the clock
    now = time.time()
    elapsed = now - self.generation_timestamp
    self.generation_timestamp = now

    print(f"Generation: {self.generation}",
          f"fitness: {best}, {worst}, {average}, {spread}, Species sizes: {species_sizes}, Cost time: {elapsed}")
def crossover_wrapper(*args):
    """
    Forward all positional arguments to ``batch_crossover`` and return its result.

    NOTE(review): presumably kept as a separate plain-Python frame so the
    crossover call is easy to spot in profiler output; confirm before inlining.
    """
    crossed = batch_crossover(*args)
    return crossed

View File

@@ -1,4 +1,4 @@
from typing import List, Tuple, Dict, Optional from typing import List, Tuple, Dict, Union
from itertools import count from itertools import count
import jax import jax
@@ -45,7 +45,6 @@ class SpeciesController:
self.species_idxer = count(0) self.species_idxer = count(0)
self.species: Dict[int, Species] = {} # species_id -> species self.species: Dict[int, Species] = {} # species_id -> species
self.genome_to_species: Dict[int, int] = {}
self.o2m_distance_func = jax.vmap(distance, in_axes=(None, None, 0, 0)) # one to many self.o2m_distance_func = jax.vmap(distance, in_axes=(None, None, 0, 0)) # one to many
# self.o2o_distance_func = np_distance # one to one # self.o2o_distance_func = np_distance # one to one
@@ -79,15 +78,15 @@ class SpeciesController:
# Partition population into species based on genetic similarity. # Partition population into species based on genetic similarity.
# First, fast match the population to previous species # First, fast match the population to previous species
if previous_species_list: # exist previous species
rid_list = [new_representatives[sid] for sid in previous_species_list] rid_list = [new_representatives[sid] for sid in previous_species_list]
res_pop_distance = [ res_pop_distance = [
jax.device_get( jax.device_get(
[
self.o2m_distance_func(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections) self.o2m_distance_func(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
)
for rid in rid_list for rid in rid_list
] ]
)
]
pop_res_distance = np.stack(res_pop_distance, axis=0).T pop_res_distance = np.stack(res_pop_distance, axis=0).T
for i in range(pop_res_distance.shape[0]): for i in range(pop_res_distance.shape[0]):
if not unspeciated[i]: if not unspeciated[i]:
@@ -102,13 +101,14 @@ class SpeciesController:
# Second, slowly match the lonely population to new-created species. # Second, slowly match the lonely population to new-created species.
# lonely genome is proved to be not compatible with any previous species, so they only need to be compared with # lonely genome is proved to be not compatible with any previous species, so they only need to be compared with
# the new representatives. # the new representatives.
new_species_list = []
for i in range(pop_nodes.shape[0]): for i in range(pop_nodes.shape[0]):
if not unspeciated[i]: if not unspeciated[i]:
continue continue
unspeciated[i] = False unspeciated[i] = False
if len(new_representatives) != 0: if len(new_representatives) != 0:
rid = [new_representatives[sid] for sid in new_representatives] # the representatives of new species # the representatives of new species
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
distances = [ distances = [
self.o2o_distance_func(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]) self.o2o_distance_func(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
for r in rid for r in rid
@@ -117,18 +117,17 @@ class SpeciesController:
min_idx = np.argmin(distances) min_idx = np.argmin(distances)
min_val = distances[min_idx] min_val = distances[min_idx]
if min_val <= self.compatibility_threshold: if min_val <= self.compatibility_threshold:
species_id = new_species_list[min_idx] species_id = sid[min_idx]
new_members[species_id].append(i) new_members[species_id].append(i)
continue continue
# create a new species # create a new species
species_id = next(self.species_idxer) species_id = next(self.species_idxer)
new_species_list.append(species_id)
new_representatives[species_id] = i new_representatives[species_id] = i
new_members[species_id] = [i] new_members[species_id] = [i]
assert np.all(~unspeciated) assert np.all(~unspeciated)
# Update species collection based on new speciation. # Update species collection based on new speciation.
self.genome_to_species = {}
for sid, rid in new_representatives.items(): for sid, rid in new_representatives.items():
s = self.species.get(sid) s = self.species.get(sid)
if s is None: if s is None:
@@ -136,12 +135,7 @@ class SpeciesController:
self.species[sid] = s self.species[sid] = s
members = new_members[sid] members = new_members[sid]
for gid in members:
self.genome_to_species[gid] = sid
s.update((pop_nodes[rid], pop_connections[rid]), members) s.update((pop_nodes[rid], pop_connections[rid]), members)
for s in self.species.values():
print(s.members)
def update_species_fitnesses(self, fitnesses): def update_species_fitnesses(self, fitnesses):
""" """
@@ -189,11 +183,11 @@ class SpeciesController:
result.append((sid, s, is_stagnant)) result.append((sid, s, is_stagnant))
return result return result
def reproduce(self, generation: int) -> List[Optional[int, Tuple[int, int]]]: def reproduce(self, generation: int) -> List[Union[int, Tuple[int, int]]]:
""" """
code modified from neat-python! code modified from neat-python!
:param generation: :param generation:
:return: next population indices. :return: crossover_pair for next generation.
# int -> idx in the pop_nodes, pop_connections of elitism # int -> idx in the pop_nodes, pop_connections of elitism
# (int, int) -> the father and mother idx to be crossover # (int, int) -> the father and mother idx to be crossover
""" """
@@ -235,7 +229,7 @@ class SpeciesController:
self.species = {} self.species = {}
# int -> idx in the pop_nodes, pop_connections of elitism # int -> idx in the pop_nodes, pop_connections of elitism
# (int, int) -> the father and mother idx to be crossover # (int, int) -> the father and mother idx to be crossover
new_population: List[Optional[int, Tuple[int, int]]] = [] crossover_pair: List[Union[int, Tuple[int, int]]] = []
for spawn, s in zip(spawn_amounts, remaining_species): for spawn, s in zip(spawn_amounts, remaining_species):
assert spawn >= self.genome_elitism assert spawn >= self.genome_elitism
@@ -248,7 +242,7 @@ class SpeciesController:
sorted_members, sorted_fitnesses = sort_element_with_fitnesses(old_members, fitnesses) sorted_members, sorted_fitnesses = sort_element_with_fitnesses(old_members, fitnesses)
if self.genome_elitism > 0: if self.genome_elitism > 0:
for m in sorted_members[:self.genome_elitism]: for m in sorted_members[:self.genome_elitism]:
new_population.append(m) crossover_pair.append(m)
spawn -= 1 spawn -= 1
if spawn <= 0: if spawn <= 0:
@@ -262,16 +256,16 @@ class SpeciesController:
# Randomly choose parents and produce the number of offspring allotted to the species. # Randomly choose parents and produce the number of offspring allotted to the species.
for _ in range(spawn): for _ in range(spawn):
assert len(sorted_members) >= 2 # allow to replace, for the case that the species only has one genome
c1, c2 = np.random.choice(len(sorted_members), size=2, replace=False) c1, c2 = np.random.choice(len(sorted_members), size=2, replace=True)
idx1, fitness1 = sorted_members[c1], sorted_fitnesses[c1] idx1, fitness1 = sorted_members[c1], sorted_fitnesses[c1]
idx2, fitness2 = sorted_members[c2], sorted_fitnesses[c2] idx2, fitness2 = sorted_members[c2], sorted_fitnesses[c2]
if fitness1 >= fitness2: if fitness1 >= fitness2:
new_population.append((idx1, idx2)) crossover_pair.append((idx1, idx2))
else: else:
new_population.append((idx2, idx1)) crossover_pair.append((idx2, idx1))
return new_population return crossover_pair
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size): def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):

View File

@@ -5,33 +5,10 @@ from jax import random
from jax import vmap, jit from jax import vmap, jit
def plus1(x): seed = jax.random.PRNGKey(42)
return x + 1 seed, *subkeys = random.split(seed, 3)
def minus1(x): c = random.split(seed, 1)
return x - 1 print(seed, subkeys)
print(c)
def func(rand_key, x):
r = jax.random.uniform(rand_key, shape=())
return jax.lax.cond(r > 0.5, plus1, minus1, x)
def func2(rand_key):
r = jax.random.uniform(rand_key, ())
if r < 0.3:
return 1
elif r < 0.5:
return 2
else:
return 3
key = random.PRNGKey(0)
print(func(key, 0))
batch_func = vmap(jit(func))
keys = random.split(key, 100)
print(batch_func(keys, jnp.zeros(100)))

28
examples/time_utils.py Normal file
View File

@@ -0,0 +1,28 @@
import cProfile
from io import StringIO
import pstats
def using_cprofile(func, root_abs_path=None, replace_pattern=None, save_path=None):
    """
    Decorator that profiles every call of ``func`` with cProfile.

    After a successful call the statistics, sorted by cumulative time, are
    printed or written to a file. The wrapped function's return value is
    passed through unchanged; if ``func`` raises, the profiler is switched off
    and the exception propagates without a report.

    :param func: the function to profile
    :param root_abs_path: if given, passed to ``print_stats`` as a restriction
        so only matching entries are reported
    :param replace_pattern: if given, every occurrence of this string is
        removed from the report
    :param save_path: if given, the report is written to this file instead of
        being printed
    :return: the wrapping function
    """
    from functools import wraps

    @wraps(func)  # keep the name and docstring of the profiled function
    def inner(*args, **kwargs):
        pr = cProfile.Profile()
        pr.enable()
        try:
            ret = func(*args, **kwargs)
        finally:
            # always release the profiler, otherwise a failing call leaves it
            # active and blocks later profiling
            pr.disable()

        profile_stats = StringIO()
        stats = pstats.Stats(pr, stream=profile_stats)
        if root_abs_path is not None:
            stats.sort_stats('cumulative').print_stats(root_abs_path)
        else:
            stats.sort_stats('cumulative').print_stats()
        output = profile_stats.getvalue()
        if replace_pattern is not None:
            output = output.replace(replace_pattern, "")
        if save_path is None:
            print(output)
        else:
            with open(save_path, "w", encoding="utf-8") as f:
                f.write(output)
        return ret

    return inner

View File

@@ -1,10 +1,12 @@
from typing import Callable, List from typing import Callable, List
from functools import partial
import jax import jax
import numpy as np import numpy as np
from utils import Configer from utils import Configer
from algorithms.neat import Pipeline from algorithms.neat import Pipeline
from time_utils import using_cprofile
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]]) xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
xor_outputs = np.array([[0], [1], [1], [0]]) xor_outputs = np.array([[0], [1], [1], [0]])
@@ -17,22 +19,24 @@ def evaluate(forward_func: Callable) -> List[float]:
""" """
outs = forward_func(xor_inputs) outs = forward_func(xor_inputs)
outs = jax.device_get(outs) outs = jax.device_get(outs)
fitnesses = np.mean((outs - xor_outputs) ** 2, axis=(1, 2)) fitnesses = -np.mean((outs - xor_outputs) ** 2, axis=(1, 2))
# print(fitnesses)
return fitnesses.tolist() # returns a list return fitnesses.tolist() # returns a list
# @using_cprofile
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
def main(): def main():
config = Configer.load_config() config = Configer.load_config()
pipeline = Pipeline(config) pipeline = Pipeline(config)
forward_func = pipeline.ask(batch=True) pipeline.auto_run(evaluate)
fitnesses = evaluate(forward_func)
pipeline.tell(fitnesses)
# for _ in range(100):
# for i in range(100): # s = time.time()
# forward_func = pipeline.ask(batch=True) # forward_func = pipeline.ask(batch=True)
# fitnesses = evaluate(forward_func) # fitnesses = evaluate(forward_func)
# pipeline.tell(fitnesses) # pipeline.tell(fitnesses)
# print(time.time() - s)
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -2,15 +2,15 @@
"basic": { "basic": {
"num_inputs": 2, "num_inputs": 2,
"num_outputs": 1, "num_outputs": 1,
"init_maximum_nodes": 20, "init_maximum_nodes": 10,
"expands_coe": 1.5 "expands_coe": 2
}, },
"neat": { "neat": {
"population": { "population": {
"fitness_criterion": "max", "fitness_criterion": "max",
"fitness_threshold": 43.9999, "fitness_threshold": 3,
"generation_limit": 100, "generation_limit": 100,
"pop_size": 1000, "pop_size": 20,
"reset_on_extinction": "False" "reset_on_extinction": "False"
}, },
"gene": { "gene": {
@@ -73,7 +73,7 @@
"node_delete_prob": 0.2 "node_delete_prob": 0.2
}, },
"species": { "species": {
"compatibility_threshold": 3.5, "compatibility_threshold": 8,
"species_fitness_func": "max", "species_fitness_func": "max",
"max_stagnation": 20, "max_stagnation": 20,
"species_elitism": 2, "species_elitism": 2,