This commit is contained in:
wls2002
2023-05-06 18:33:30 +08:00
parent 73ac1bcfe0
commit 14fed83193
10 changed files with 206 additions and 22 deletions

View File

@@ -42,14 +42,14 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
# crossover nodes # crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0] keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
nodes2 = align_array(keys1, keys2, nodes2, 'node') nodes2 = align_array(keys1, keys2, nodes2, 'node')
new_nodes = jnp.where(jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2)) new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
# crossover connections # crossover connections
cons1 = flatten_connections(keys1, connections1) cons1 = flatten_connections(keys1, connections1)
cons2 = flatten_connections(keys2, connections2) cons2 = flatten_connections(keys2, connections2)
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2] con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
cons2 = align_array(con_keys1, con_keys2, cons2, 'connection') cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
new_cons = jnp.where(jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2)) new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
new_cons = unflatten_connections(len(keys1), new_cons) new_cons = unflatten_connections(len(keys1), new_cons)
return new_nodes, new_cons return new_nodes, new_cons

View File

@@ -35,6 +35,15 @@ def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2) n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2)
nodes = jnp.concatenate((ar1, ar2), axis=0) nodes = jnp.concatenate((ar1, ar2), axis=0)
sorted_nodes = nodes[n_sorted_indices] sorted_nodes = nodes[n_sorted_indices]
if gene_type == 'node':
node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 1:]), axis=1)
else:
node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 2:]), axis=1)
n_intersect_mask = n_intersect_mask & node_exist_mask
n_union_mask = n_union_mask & node_exist_mask
fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:] fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:]
non_homologous_cnt = jnp.sum(n_union_mask) - jnp.sum(n_intersect_mask) non_homologous_cnt = jnp.sum(n_union_mask) - jnp.sum(n_intersect_mask)
@@ -48,9 +57,11 @@ def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
gene_cnt1 = jnp.sum(jnp.all(~jnp.isnan(ar1), axis=1)) gene_cnt1 = jnp.sum(jnp.all(~jnp.isnan(ar1), axis=1))
gene_cnt2 = jnp.sum(jnp.all(~jnp.isnan(ar2), axis=1)) gene_cnt2 = jnp.sum(jnp.all(~jnp.isnan(ar2), axis=1))
max_cnt = jnp.maximum(gene_cnt1, gene_cnt2)
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
return val / jnp.where(gene_cnt1 > gene_cnt2, gene_cnt1, gene_cnt2)
return jnp.where(max_cnt == 0, 0, val / max_cnt) # consider the case that both genome has no gene
@partial(vmap, in_axes=(0, 0)) @partial(vmap, in_axes=(0, 0))

View File

@@ -27,7 +27,7 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
""" """
if debug: if debug:
cal_seqs = topological_sort(nodes, connections) cal_seqs = topological_sort_debug(nodes, connections)
return lambda inputs: forward_single_debug(inputs, N, input_idx, output_idx, return lambda inputs: forward_single_debug(inputs, N, input_idx, output_idx,
cal_seqs, nodes, connections) cal_seqs, nodes, connections)

View File

@@ -12,9 +12,10 @@ status.
Empty nodes or connections are represented using np.nan. Empty nodes or connections are represented using np.nan.
""" """
from typing import Tuple from typing import Tuple, Dict
from functools import partial from functools import partial
import jax
import numpy as np import numpy as np
from numpy.typing import NDArray from numpy.typing import NDArray
from jax import numpy as jnp from jax import numpy as jnp
@@ -131,6 +132,83 @@ def expand_single(nodes: NDArray, connections: NDArray, new_N: int) -> Tuple[NDA
return new_nodes, new_connections return new_nodes, new_connections
def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \
        Tuple[Dict[int, Tuple[float, float, int, int]], Dict[Tuple[int, int], Tuple[float, bool]]]:
    """
    Convert a genome from its padded-array representation to plain Python dicts.

    :param nodes: (N, 5) array; column 0 is the node key, columns 1-4 are
        (bias, response, act, agg). Empty slots are all-NaN rows.
    :param connections: (2, N, N) array; plane 0 holds weights and plane 1 the
        enabled flag, both indexed by (from_idx, to_idx). NaN means absent.
    :param input_keys: keys of the input nodes (their attributes must all be NaN).
    :param output_keys: keys of the output nodes.
    :return: nodes_dict[key: (bias, response, act, agg)],
        connections_dict[(f_key, t_key): (weight, enabled)]
    :raises AssertionError: if the genome violates a structural invariant; the
        offending arrays are printed before re-raising.
    """
    try:
        nodes_dict = {}
        idx2key = {}  # row index in `nodes` -> node key
        for i, node in enumerate(nodes):
            if np.isnan(node[0]):
                continue  # empty (padding) slot
            key = int(node[0])
            assert key not in nodes_dict, f"Duplicate node key: {key}!"
            # NaN attributes become None in the dict representation.
            bias = node[1] if not np.isnan(node[1]) else None
            response = node[2] if not np.isnan(node[2]) else None
            act = node[3] if not np.isnan(node[3]) else None
            agg = node[4] if not np.isnan(node[4]) else None
            nodes_dict[key] = (bias, response, act, agg)
            idx2key[i] = key

        # check nodes_dict: inputs carry no attributes, outputs must exist,
        # and every non-input node must have all attributes set.
        for i in input_keys:
            assert i in nodes_dict, f"Input node {i} not found in nodes_dict!"
            bias, response, act, agg = nodes_dict[i]
            assert bias is None and response is None and act is None and agg is None, \
                f"Input node {i} must has None bias, response, act, or agg!"
        for o in output_keys:
            assert o in nodes_dict, f"Output node {o} not found in nodes_dict!"
        for k, v in nodes_dict.items():
            if k not in input_keys:
                bias, response, act, agg = v
                assert bias is not None and response is not None and act is not None and agg is not None, \
                    f"Normal node {k} must has non-None bias, response, act, or agg!"

        # update connections: a slot is a real connection when either plane is
        # non-NaN, and then both weight and enabled must be present.
        connections_dict = {}
        for i in range(connections.shape[1]):
            for j in range(connections.shape[2]):
                if np.isnan(connections[0, i, j]) and np.isnan(connections[1, i, j]):
                    continue
                assert i in idx2key, f"Node index {i} not found in idx2key:{idx2key}!"
                assert j in idx2key, f"Node index {j} not found in idx2key:{idx2key}!"
                key = (idx2key[i], idx2key[j])
                weight = connections[0, i, j] if not np.isnan(connections[0, i, j]) else None
                enabled = (connections[1, i, j] == 1) if not np.isnan(connections[1, i, j]) else None
                assert weight is not None, f"Connection {key} must has non-None weight!"
                assert enabled is not None, f"Connection {key} must has non-None enabled!"
                connections_dict[key] = (weight, enabled)
        return nodes_dict, connections_dict
    except AssertionError:
        # Dump the offending genome for debugging, then re-raise the ORIGINAL
        # assertion with a bare `raise` — the previous `raise AssertionError`
        # discarded the specific failure message.
        print(nodes)
        print(connections)
        raise
def pop_analysis(pop_nodes, pop_connections, input_keys, output_keys):
    """Run `analysis` over every genome in the population.

    Pulls both population arrays off the device first so the per-genome loop
    works on host NumPy data, and returns a list with one
    (nodes_dict, connections_dict) pair per genome.
    """
    host_nodes, host_connections = jax.device_get((pop_nodes, pop_connections))
    return [
        analysis(genome_nodes, genome_connections, input_keys, output_keys)
        for genome_nodes, genome_connections in zip(host_nodes, host_connections)
    ]
@jit @jit
def add_node(new_node_key: int, nodes: Array, connections: Array, def add_node(new_node_key: int, nodes: Array, connections: Array,
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]: bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]:
@@ -158,11 +236,12 @@ def delete_node_by_idx(idx: int, nodes: Array, connections: Array) -> Tuple[Arra
""" """
delete a node from the genome. only delete the node, regardless of connections. delete a node from the genome. only delete the node, regardless of connections.
""" """
node_keys = nodes[:, 0] # node_keys = nodes[:, 0]
nodes = nodes.at[idx].set(EMPTY_NODE)
# move the last node to the deleted node's position # move the last node to the deleted node's position
last_real_idx = fetch_last(~jnp.isnan(node_keys)) # last_real_idx = fetch_last(~jnp.isnan(node_keys))
nodes = nodes.at[idx].set(nodes[last_real_idx]) # nodes = nodes.at[idx].set(nodes[last_real_idx])
nodes = nodes.at[last_real_idx].set(EMPTY_NODE) # nodes = nodes.at[last_real_idx].set(EMPTY_NODE)
return nodes, connections return nodes, connections
@@ -206,7 +285,3 @@ def delete_connection_by_idx(from_idx: int, to_idx: int, nodes: Array, connectio
""" """
connections = connections.at[:, from_idx, to_idx].set(np.nan) connections = connections.at[:, from_idx, to_idx].set(np.nan)
return nodes, connections return nodes, connections
# if __name__ == '__main__':
# pop_nodes, pop_connections, input_keys, output_keys = initialize_genomes(100, 5, 2, 1)
# print(pop_nodes, pop_connections)

View File

@@ -148,7 +148,9 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
check_cycles(nodes, connections, 0, 3) -> False check_cycles(nodes, connections, 0, 3) -> False
check_cycles(nodes, connections, 1, 0) -> False check_cycles(nodes, connections, 1, 0) -> False
""" """
connections_enable = connections[1, :, :] == 1 # connections_enable = connections[0, :, :] == 1
connections_enable = ~jnp.isnan(connections[0, :, :])
connections_enable = connections_enable.at[from_idx, to_idx].set(True) connections_enable = connections_enable.at[from_idx, to_idx].set(True)
nodes_visited = jnp.full(nodes.shape[0], False) nodes_visited = jnp.full(nodes.shape[0], False)
nodes_visited = nodes_visited.at[to_idx].set(True) nodes_visited = nodes_visited.at[to_idx].set(True)

View File

@@ -7,7 +7,9 @@ import numpy as np
from .species import SpeciesController from .species import SpeciesController
from .genome import create_initialize_function, create_mutate_function, create_forward_function from .genome import create_initialize_function, create_mutate_function, create_forward_function
from .genome import batch_crossover from .genome import batch_crossover
from .genome.crossover import crossover
from .genome import expand, expand_single from .genome import expand, expand_single
from algorithms.neat.genome.genome import pop_analysis, analysis
class Pipeline: class Pipeline:
@@ -51,12 +53,22 @@ class Pipeline:
def tell(self, fitnesses): def tell(self, fitnesses):
self.generation += 1 self.generation += 1
for i, f in enumerate(fitnesses):
if np.isnan(f):
print("fuck!!!!!!!!!!!!!!")
error_nodes, error_connections = self.pop_nodes[i], self.pop_connections[i]
np.save('error_nodes.npy', error_nodes)
np.save('error_connections.npy', error_connections)
assert False
self.species_controller.update_species_fitnesses(fitnesses) self.species_controller.update_species_fitnesses(fitnesses)
crossover_pair = self.species_controller.reproduce(self.generation) crossover_pair = self.species_controller.reproduce(self.generation)
self.update_next_generation(crossover_pair) self.update_next_generation(crossover_pair)
# print(pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx))
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation) self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)
self.expand() self.expand()
@@ -103,16 +115,22 @@ class Pipeline:
crossover_rand_keys = jax.random.split(k1, self.pop_size) crossover_rand_keys = jax.random.split(k1, self.pop_size)
# npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections # npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
npn, npc = crossover_wrapper(crossover_rand_keys, wpn, wpc, lpn, lpc) npn, npc = crossover_wrapper(crossover_rand_keys, wpn, wpc, lpn, lpc)
# print(pop_analysis(npn, npc, self.input_idx, self.output_idx))
# mutate # mutate
new_node_keys = np.array(self.fetch_new_node_keys()) new_node_keys = np.array(self.fetch_new_node_keys())
mutate_rand_keys = jax.random.split(k2, self.pop_size) mutate_rand_keys = jax.random.split(k2, self.pop_size)
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc) m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc)
# print(pop_analysis(m_npn, m_npc, self.input_idx, self.output_idx))
# elitism don't mutate # elitism don't mutate
# (pop_size, ) to (pop_size, 1, 1) # (pop_size, ) to (pop_size, 1, 1)
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn) self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
# (pop_size, ) to (pop_size, 1, 1, 1) # (pop_size, ) to (pop_size, 1, 1, 1)
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc) self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
# print(pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx))
# recycle unused node keys # recycle unused node keys
unused = [] unused = []
@@ -138,8 +156,8 @@ class Pipeline:
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N) self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N)
# don't forget to expand representation genome in species # don't forget to expand representation genome in species
for s in self.species_controller.species: for s in self.species_controller.species.values():
s.representative = expand(*s.representative, self.N) s.representative = expand_single(*s.representative, self.N)
def fetch_new_node_keys(self): def fetch_new_node_keys(self):
# if remain unused keys are not enough, create new keys # if remain unused keys are not enough, create new keys
@@ -164,6 +182,19 @@ class Pipeline:
print(f"Generation: {self.generation}", print(f"Generation: {self.generation}",
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}") f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")
# def crossover_wrapper(self, crossover_rand_keys, wpn, wpc, lpn, lpc):
# pop_nodes, pop_connections = [], []
# for randkey, wn, wc, ln, lc in zip(crossover_rand_keys, wpn, wpc, lpn, lpc):
# new_nodes, new_connections = crossover(randkey, wn, wc, ln, lc)
# pop_nodes.append(new_nodes)
# pop_connections.append(new_connections)
# try:
# print(analysis(new_nodes, new_connections, self.input_idx, self.output_idx))
# except AssertionError:
# new_nodes, new_connections = crossover(randkey, wn, wc, ln, lc)
# return np.stack(pop_nodes), np.stack(pop_connections)
# return batch_crossover(*args)
def crossover_wrapper(*args): def crossover_wrapper(*args):
return batch_crossover(*args) return batch_crossover(*args)

View File

@@ -67,7 +67,7 @@ class SpeciesController:
for sid, species in self.species.items(): for sid, species in self.species.items():
# calculate the distance between the representative and the population # calculate the distance between the representative and the population
r_nodes, r_connections = species.representative r_nodes, r_connections = species.representative
distances = self.o2m_distance_func(r_nodes, r_connections, pop_nodes, pop_connections) distances = self.o2m_distance_wrapper(r_nodes, r_connections, pop_nodes, pop_connections)
distances = jax.device_get(distances) # fetch the data from gpu distances = jax.device_get(distances) # fetch the data from gpu
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
@@ -82,7 +82,7 @@ class SpeciesController:
rid_list = [new_representatives[sid] for sid in previous_species_list] rid_list = [new_representatives[sid] for sid in previous_species_list]
res_pop_distance = [ res_pop_distance = [
jax.device_get( jax.device_get(
self.o2m_distance_func(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections) self.o2m_distance_wrapper(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
) )
for rid in rid_list for rid in rid_list
] ]
@@ -110,7 +110,7 @@ class SpeciesController:
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()])) sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
distances = [ distances = [
self.o2o_distance_func(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]) self.o2o_distance_wrapper(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
for r in rid for r in rid
] ]
distances = np.array(distances) distances = np.array(distances)
@@ -267,6 +267,36 @@ class SpeciesController:
return crossover_pair return crossover_pair
def o2m_distance_wrapper(self, r_nodes, r_connections, pop_nodes, pop_connections):
    """One-to-many genome distance with a NaN / runaway-value guard.

    Computes the distance from one representative genome to every genome in
    the population, one pair at a time via `o2o_distance_func`. If any
    distance is NaN or implausibly large (> 20 — presumably a corrupted
    genome; threshold inherited from the debug session), both genomes are
    dumped to .npy files for offline reproduction and an AssertionError is
    raised.

    :param r_nodes: representative genome's node array.
    :param r_connections: representative genome's connection array.
    :param pop_nodes: stacked node arrays for the whole population.
    :param pop_connections: stacked connection arrays for the whole population.
    :return: 1-D array of distances, one per population member.
    """
    distances = []
    for nodes, connections in zip(pop_nodes, pop_connections):
        d = self.o2o_distance_func(r_nodes, r_connections, nodes, connections)
        if np.isnan(d) or d > 20:
            # Persist both genomes so the failure can be replayed offline
            # (explicit .npy suffix; np.save would append it anyway).
            np.save("too_large_distance_r_nodes.npy", r_nodes)
            np.save("too_large_distance_r_connections.npy", r_connections)
            np.save("too_large_distance_nodes.npy", nodes)
            np.save("too_large_distance_connections.npy", connections)
            # Raise explicitly: `assert False` is stripped under `python -O`,
            # which would let corrupted distances propagate silently.
            raise AssertionError(f"distance is NaN or too large: {d}")
        distances.append(d)
    return np.stack(distances, axis=0)
def o2o_distance_wrapper(self, *keys):
    """One-to-one genome distance with a NaN guard.

    Delegates to `o2o_distance_func` and raises if the result is NaN, since a
    NaN distance indicates a corrupted genome pair.

    :param keys: passed through unchanged to `o2o_distance_func`.
    :return: the computed distance.
    :raises AssertionError: if the distance is NaN.
    """
    d = self.o2o_distance_func(*keys)
    if np.isnan(d):
        # Raise explicitly instead of print + `assert False`: asserts are
        # stripped under `python -O`, and the debug print carried no
        # diagnostic information.
        raise AssertionError("o2o distance is NaN")
    return d
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size): def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
""" """

View File

@@ -0,0 +1,24 @@
# Debug script: replay a genome that produced a NaN fitness.
# Loads the arrays dumped by the main pipeline (`error_nodes.npy` /
# `error_connections.npy`), converts them to dicts for inspection, then runs
# the debug forward pass on one sample input.
import numpy as np
from jax import numpy as jnp
from algorithms.neat.genome.genome import analysis
from algorithms.neat.genome import create_forward_function

# Arrays saved by the pipeline when a NaN fitness was detected.
error_nodes = np.load('error_nodes.npy')
error_connections = np.load('error_connections.npy')

# Human-readable view of the genome; input keys 0,1 and output key 2
# (XOR-style topology — TODO confirm against the experiment config).
node_dict, connection_dict = analysis(error_nodes, error_connections, np.array([0, 1]), np.array([2, ]))
print(node_dict, connection_dict, sep='\n')

N = error_nodes.shape[0]  # padded genome capacity
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])  # NOTE(review): currently unused
# debug=True selects the slow, step-by-step forward implementation.
func = create_forward_function(error_nodes, error_connections, N, jnp.array([0, 1]), jnp.array([2, ]),
                               batch=True, debug=True)
out = func(np.array([1, 0]))
print(error_nodes)
print(error_connections)
print(out)

View File

@@ -0,0 +1,11 @@
import numpy as np
from algorithms.neat.genome import distance
r_nodes = np.load('too_large_distance_r_nodes.npy')
r_connections = np.load('too_large_distance_r_connections.npy')
nodes = np.load('too_large_distance_nodes.npy')
connections = np.load('too_large_distance_connections.npy')
d1 = distance(r_nodes, r_connections, nodes, connections)
d2 = distance(nodes, connections, r_nodes, r_connections)
print(d1, d2)

View File

@@ -10,7 +10,7 @@
"fitness_criterion": "max", "fitness_criterion": "max",
"fitness_threshold": 3, "fitness_threshold": 3,
"generation_limit": 100, "generation_limit": 100,
"pop_size": 20, "pop_size": 100,
"reset_on_extinction": "False" "reset_on_extinction": "False"
}, },
"gene": { "gene": {
@@ -73,7 +73,7 @@
"node_delete_prob": 0.2 "node_delete_prob": 0.2
}, },
"species": { "species": {
"compatibility_threshold": 8, "compatibility_threshold": 3.5,
"species_fitness_func": "max", "species_fitness_func": "max",
"max_stagnation": 20, "max_stagnation": 20,
"species_elitism": 2, "species_elitism": 2,