This commit is contained in:
wls2002
2023-05-06 18:33:30 +08:00
parent 73ac1bcfe0
commit 14fed83193
10 changed files with 206 additions and 22 deletions

View File

@@ -42,14 +42,14 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
# crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
nodes2 = align_array(keys1, keys2, nodes2, 'node')
new_nodes = jnp.where(jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
# crossover connections
cons1 = flatten_connections(keys1, connections1)
cons2 = flatten_connections(keys2, connections2)
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
new_cons = jnp.where(jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
new_cons = unflatten_connections(len(keys1), new_cons)
return new_nodes, new_cons

View File

@@ -35,6 +35,15 @@ def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2)
nodes = jnp.concatenate((ar1, ar2), axis=0)
sorted_nodes = nodes[n_sorted_indices]
if gene_type == 'node':
node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 1:]), axis=1)
else:
node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 2:]), axis=1)
n_intersect_mask = n_intersect_mask & node_exist_mask
n_union_mask = n_union_mask & node_exist_mask
fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:]
non_homologous_cnt = jnp.sum(n_union_mask) - jnp.sum(n_intersect_mask)
@@ -48,9 +57,11 @@ def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
gene_cnt1 = jnp.sum(jnp.all(~jnp.isnan(ar1), axis=1))
gene_cnt2 = jnp.sum(jnp.all(~jnp.isnan(ar2), axis=1))
max_cnt = jnp.maximum(gene_cnt1, gene_cnt2)
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
return val / jnp.where(gene_cnt1 > gene_cnt2, gene_cnt1, gene_cnt2)
return jnp.where(max_cnt == 0, 0, val / max_cnt) # consider the case that both genome has no gene
@partial(vmap, in_axes=(0, 0))

View File

@@ -27,7 +27,7 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
"""
if debug:
cal_seqs = topological_sort(nodes, connections)
cal_seqs = topological_sort_debug(nodes, connections)
return lambda inputs: forward_single_debug(inputs, N, input_idx, output_idx,
cal_seqs, nodes, connections)

View File

@@ -12,9 +12,10 @@ status.
Empty nodes or connections are represented using np.nan.
"""
from typing import Tuple
from typing import Tuple, Dict
from functools import partial
import jax
import numpy as np
from numpy.typing import NDArray
from jax import numpy as jnp
@@ -131,6 +132,83 @@ def expand_single(nodes: NDArray, connections: NDArray, new_N: int) -> Tuple[NDA
return new_nodes, new_connections
def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \
        Tuple[Dict[int, Tuple[float, float, int, int]], Dict[Tuple[int, int], Tuple[float, bool]]]:
    """
    Convert a genome from array form to dict form, validating its structure.

    :param nodes: (N, 5) array; each row is (key, bias, response, act, agg), empty slots are NaN
    :param connections: (2, N, N) array; [0] holds weights, [1] holds enabled flags, empty slots are NaN
    :param input_keys: keys of input nodes (must exist and carry all-NaN attributes)
    :param output_keys: keys of output nodes (must exist)
    :return: nodes_dict[key: (bias, response, act, agg)], connections_dict[(f_key, t_key): (weight, enabled)]
    :raises AssertionError: when the genome violates any structural invariant
        (the offending arrays are printed before re-raising)
    """
    try:
        nodes_dict = {}
        idx2key = {}  # row index in `nodes` -> node key, used to decode connection indices
        for i, node in enumerate(nodes):
            if np.isnan(node[0]):
                continue  # empty node slot
            key = int(node[0])
            assert key not in nodes_dict, f"Duplicate node key: {key}!"
            # NaN attributes become None in the dict representation
            bias = node[1] if not np.isnan(node[1]) else None
            response = node[2] if not np.isnan(node[2]) else None
            act = node[3] if not np.isnan(node[3]) else None
            agg = node[4] if not np.isnan(node[4]) else None
            nodes_dict[key] = (bias, response, act, agg)
            idx2key[i] = key
        # check nodes_dict: input nodes carry no attributes, every other node carries all of them
        for i in input_keys:
            assert i in nodes_dict, f"Input node {i} not found in nodes_dict!"
            bias, response, act, agg = nodes_dict[i]
            assert bias is None and response is None and act is None and agg is None, \
                f"Input node {i} must has None bias, response, act, or agg!"
        for o in output_keys:
            assert o in nodes_dict, f"Output node {o} not found in nodes_dict!"
        for k, v in nodes_dict.items():
            if k not in input_keys:
                bias, response, act, agg = v
                assert bias is not None and response is not None and act is not None and agg is not None, \
                    f"Normal node {k} must has non-None bias, response, act, or agg!"
        # decode connections: a slot is real when either its weight or enabled flag is set,
        # and a real slot must have both
        connections_dict = {}
        for i in range(connections.shape[1]):
            for j in range(connections.shape[2]):
                if np.isnan(connections[0, i, j]) and np.isnan(connections[1, i, j]):
                    continue  # empty connection slot
                assert i in idx2key, f"Node index {i} not found in idx2key:{idx2key}!"
                assert j in idx2key, f"Node index {j} not found in idx2key:{idx2key}!"
                key = (idx2key[i], idx2key[j])
                weight = connections[0, i, j] if not np.isnan(connections[0, i, j]) else None
                enabled = (connections[1, i, j] == 1) if not np.isnan(connections[1, i, j]) else None
                assert weight is not None, f"Connection {key} must has non-None weight!"
                assert enabled is not None, f"Connection {key} must has non-None enabled!"
                connections_dict[key] = (weight, enabled)
        return nodes_dict, connections_dict
    except AssertionError:
        # dump the offending genome for debugging, then re-raise with the
        # ORIGINAL assertion message and traceback (bare `raise`, not a fresh
        # `raise AssertionError`, which would discard both)
        print(nodes)
        print(connections)
        raise
def pop_analysis(pop_nodes, pop_connections, input_keys, output_keys):
    """Decode every genome in the population to dict form via `analysis`."""
    # pull the whole population off the device once, then decode genome by genome
    pop_nodes, pop_connections = jax.device_get((pop_nodes, pop_connections))
    return [
        analysis(nodes, connections, input_keys, output_keys)
        for nodes, connections in zip(pop_nodes, pop_connections)
    ]
@jit
def add_node(new_node_key: int, nodes: Array, connections: Array,
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]:
@@ -158,11 +236,12 @@ def delete_node_by_idx(idx: int, nodes: Array, connections: Array) -> Tuple[Arra
"""
delete a node from the genome. only delete the node, regardless of connections.
"""
node_keys = nodes[:, 0]
# node_keys = nodes[:, 0]
nodes = nodes.at[idx].set(EMPTY_NODE)
# move the last node to the deleted node's position
last_real_idx = fetch_last(~jnp.isnan(node_keys))
nodes = nodes.at[idx].set(nodes[last_real_idx])
nodes = nodes.at[last_real_idx].set(EMPTY_NODE)
# last_real_idx = fetch_last(~jnp.isnan(node_keys))
# nodes = nodes.at[idx].set(nodes[last_real_idx])
# nodes = nodes.at[last_real_idx].set(EMPTY_NODE)
return nodes, connections
@@ -206,7 +285,3 @@ def delete_connection_by_idx(from_idx: int, to_idx: int, nodes: Array, connectio
"""
connections = connections.at[:, from_idx, to_idx].set(np.nan)
return nodes, connections
# if __name__ == '__main__':
# pop_nodes, pop_connections, input_keys, output_keys = initialize_genomes(100, 5, 2, 1)
# print(pop_nodes, pop_connections)

View File

@@ -148,7 +148,9 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
check_cycles(nodes, connections, 0, 3) -> False
check_cycles(nodes, connections, 1, 0) -> False
"""
connections_enable = connections[1, :, :] == 1
# connections_enable = connections[0, :, :] == 1
connections_enable = ~jnp.isnan(connections[0, :, :])
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
nodes_visited = jnp.full(nodes.shape[0], False)
nodes_visited = nodes_visited.at[to_idx].set(True)

View File

@@ -7,7 +7,9 @@ import numpy as np
from .species import SpeciesController
from .genome import create_initialize_function, create_mutate_function, create_forward_function
from .genome import batch_crossover
from .genome.crossover import crossover
from .genome import expand, expand_single
from algorithms.neat.genome.genome import pop_analysis, analysis
class Pipeline:
@@ -51,12 +53,22 @@ class Pipeline:
def tell(self, fitnesses):
self.generation += 1
for i, f in enumerate(fitnesses):
if np.isnan(f):
print("fuck!!!!!!!!!!!!!!")
error_nodes, error_connections = self.pop_nodes[i], self.pop_connections[i]
np.save('error_nodes.npy', error_nodes)
np.save('error_connections.npy', error_connections)
assert False
self.species_controller.update_species_fitnesses(fitnesses)
crossover_pair = self.species_controller.reproduce(self.generation)
self.update_next_generation(crossover_pair)
# print(pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx))
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)
self.expand()
@@ -103,16 +115,22 @@ class Pipeline:
crossover_rand_keys = jax.random.split(k1, self.pop_size)
# npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
npn, npc = crossover_wrapper(crossover_rand_keys, wpn, wpc, lpn, lpc)
# print(pop_analysis(npn, npc, self.input_idx, self.output_idx))
# mutate
new_node_keys = np.array(self.fetch_new_node_keys())
mutate_rand_keys = jax.random.split(k2, self.pop_size)
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys)
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc)
# print(pop_analysis(m_npn, m_npc, self.input_idx, self.output_idx))
# elitism don't mutate
# (pop_size, ) to (pop_size, 1, 1)
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
# (pop_size, ) to (pop_size, 1, 1, 1)
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
# print(pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx))
# recycle unused node keys
unused = []
@@ -138,8 +156,8 @@ class Pipeline:
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N)
# don't forget to expand representation genome in species
for s in self.species_controller.species:
s.representative = expand(*s.representative, self.N)
for s in self.species_controller.species.values():
s.representative = expand_single(*s.representative, self.N)
def fetch_new_node_keys(self):
# if remain unused keys are not enough, create new keys
@@ -164,6 +182,19 @@ class Pipeline:
print(f"Generation: {self.generation}",
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")
# def crossover_wrapper(self, crossover_rand_keys, wpn, wpc, lpn, lpc):
# pop_nodes, pop_connections = [], []
# for randkey, wn, wc, ln, lc in zip(crossover_rand_keys, wpn, wpc, lpn, lpc):
# new_nodes, new_connections = crossover(randkey, wn, wc, ln, lc)
# pop_nodes.append(new_nodes)
# pop_connections.append(new_connections)
# try:
# print(analysis(new_nodes, new_connections, self.input_idx, self.output_idx))
# except AssertionError:
# new_nodes, new_connections = crossover(randkey, wn, wc, ln, lc)
# return np.stack(pop_nodes), np.stack(pop_connections)
# return batch_crossover(*args)
def crossover_wrapper(*args):
    # Thin pass-through to the batched crossover. Kept as a named seam so a
    # per-genome Python-loop variant (kept commented out nearby) can be swapped
    # in when debugging crossover on individual genomes.
    return batch_crossover(*args)

View File

@@ -67,7 +67,7 @@ class SpeciesController:
for sid, species in self.species.items():
# calculate the distance between the representative and the population
r_nodes, r_connections = species.representative
distances = self.o2m_distance_func(r_nodes, r_connections, pop_nodes, pop_connections)
distances = self.o2m_distance_wrapper(r_nodes, r_connections, pop_nodes, pop_connections)
distances = jax.device_get(distances) # fetch the data from gpu
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
@@ -82,7 +82,7 @@ class SpeciesController:
rid_list = [new_representatives[sid] for sid in previous_species_list]
res_pop_distance = [
jax.device_get(
self.o2m_distance_func(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
self.o2m_distance_wrapper(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
)
for rid in rid_list
]
@@ -110,7 +110,7 @@ class SpeciesController:
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
distances = [
self.o2o_distance_func(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
self.o2o_distance_wrapper(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
for r in rid
]
distances = np.array(distances)
@@ -267,6 +267,36 @@ class SpeciesController:
return crossover_pair
def o2m_distance_wrapper(self, r_nodes, r_connections, pop_nodes, pop_connections):
    """
    Debugging wrapper around the one-to-many genome distance computation.

    Computes the distance from one representative genome to every genome in the
    population one pair at a time (instead of the batched
    ``o2m_distance_func``), so a NaN or suspiciously large (> 20) distance can
    be caught, its genome pair dumped to ``.npy`` files, and the run aborted.

    :param r_nodes: node array of the representative genome
    :param r_connections: connection array of the representative genome
    :param pop_nodes: population node arrays (leading axis = population)
    :param pop_connections: population connection arrays
    :return: stacked 1-D array of distances, one per population member
    :raises AssertionError: when any distance is NaN or > 20
    """
    # Batched variant kept for reference; replaced by the explicit loop below
    # so the individual failing pair can be identified and saved.
    # distances = self.o2m_distance_func(r_nodes, r_connections, pop_nodes, pop_connections)
    # for d in distances:
    #     if np.isnan(d):
    #         print("fuck!!!!!!!!!!!!!!")
    #         print(distances)
    #         assert False
    # return distances
    distances = []
    for nodes, connections in zip(pop_nodes, pop_connections):
        d = self.o2o_distance_func(r_nodes, r_connections, nodes, connections)
        if np.isnan(d) or d > 20:
            # dump the offending genome pair for offline inspection
            np.save("too_large_distance_r_nodes.npy", r_nodes)
            np.save("too_large_distance_r_connections.npy", r_connections)
            # np.save appends ".npy" automatically, so this still produces
            # "too_large_distance_nodes.npy" like the other dumps
            np.save("too_large_distance_nodes", nodes)
            np.save("too_large_distance_connections.npy", connections)
            # recomputed here, presumably so a debugger breakpoint can step
            # into the failing call — the result is unused; TODO confirm intent
            d = self.o2o_distance_func(r_nodes, r_connections, nodes, connections)
            assert False
        distances.append(d)
    distances = np.stack(distances, axis=0)
    # print(distances)
    return distances
def o2o_distance_wrapper(self, *keys):
    """
    Debugging wrapper around the one-to-one genome distance function.

    Fails fast when the computed distance is NaN instead of letting the NaN
    silently propagate into speciation.

    :param keys: arguments forwarded verbatim to ``self.o2o_distance_func``
    :return: the computed distance
    :raises AssertionError: if the distance is NaN
    """
    d = self.o2o_distance_func(*keys)
    if np.isnan(d):
        # raise explicitly (same exception type as the old `assert False`)
        # so the check survives running under `python -O`, and carry a
        # useful message instead of a print
        raise AssertionError("o2o distance is NaN for the given genome pair")
    return d
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
"""

View File

@@ -0,0 +1,24 @@
# Standalone debugging script: loads the genome arrays that the pipeline dumps
# when a NaN fitness is detected ('error_nodes.npy' / 'error_connections.npy'),
# decodes/validates them with `analysis`, and runs a debug forward pass.
import numpy as np
from jax import numpy as jnp
from algorithms.neat.genome.genome import analysis
from algorithms.neat.genome import create_forward_function

# arrays saved by the pipeline when a genome produced a NaN fitness
error_nodes = np.load('error_nodes.npy')
error_connections = np.load('error_connections.npy')
# input keys [0, 1], output key [2] — a two-input, one-output genome
node_dict, connection_dict = analysis(error_nodes, error_connections, np.array([0, 1]), np.array([2, ]))
print(node_dict, connection_dict, sep='\n')
N = error_nodes.shape[0]
# NOTE(review): xor_inputs is unused below — the forward pass is probed with a
# single hand-written input instead
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
# debug=True selects the debug forward/topological-sort implementation
func = create_forward_function(error_nodes, error_connections, N, jnp.array([0, 1]), jnp.array([2, ]),
                               batch=True, debug=True)
out = func(np.array([1, 0]))
print(error_nodes)
print(error_connections)
print(out)

View File

@@ -0,0 +1,11 @@
# Standalone debugging script: reloads the genome pair dumped by
# o2m_distance_wrapper when a distance came out NaN or > 20, and recomputes
# the distance in both argument orders to inspect the value (and its symmetry).
import numpy as np
from algorithms.neat.genome import distance

r_nodes = np.load('too_large_distance_r_nodes.npy')
r_connections = np.load('too_large_distance_r_connections.npy')
nodes = np.load('too_large_distance_nodes.npy')
connections = np.load('too_large_distance_connections.npy')
# distance in both directions — should agree if the metric is symmetric
d1 = distance(r_nodes, r_connections, nodes, connections)
d2 = distance(nodes, connections, r_nodes, r_connections)
print(d1, d2)

View File

@@ -10,7 +10,7 @@
"fitness_criterion": "max",
"fitness_threshold": 3,
"generation_limit": 100,
"pop_size": 20,
"pop_size": 100,
"reset_on_extinction": "False"
},
"gene": {
@@ -73,7 +73,7 @@
"node_delete_prob": 0.2
},
"species": {
"compatibility_threshold": 8,
"compatibility_threshold": 3.5,
"species_fitness_func": "max",
"max_stagnation": 20,
"species_elitism": 2,