虽然xor问题还是跑不出来,但至少已经确定不是distance的错了

This commit is contained in:
wls2002
2023-05-06 23:26:13 +08:00
parent a85e6eba78
commit 414b620dc8
5 changed files with 151 additions and 7 deletions

View File

@@ -1,4 +1,4 @@
from .genome import create_initialize_function, expand, expand_single from .genome import create_initialize_function, expand, expand_single, analysis
from .distance import distance from .distance import distance
from .mutate import create_mutate_function from .mutate import create_mutate_function
from .forward import create_forward_function from .forward import create_forward_function

View File

@@ -5,6 +5,9 @@ from numpy.typing import NDArray
from algorithms.neat.genome.utils import flatten_connections, set_operation_analysis from algorithms.neat.genome.utils import flatten_connections, set_operation_analysis
EMPTY_NODE = np.full((1, 5), np.nan)
EMPTY_CON = np.full((1, 4), np.nan)
def distance(nodes1: NDArray, connections1: NDArray, nodes2: NDArray, connections2: NDArray) -> NDArray: def distance(nodes1: NDArray, connections1: NDArray, nodes2: NDArray, connections2: NDArray) -> NDArray:
""" """
@@ -13,15 +16,70 @@ def distance(nodes1: NDArray, connections1: NDArray, nodes2: NDArray, connection
connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable] connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
""" """
node_distance = gene_distance(nodes1, nodes2, 'node') nd = node_distance(nodes1, nodes2) # node distance
# refactor connections # refactor connections
keys1, keys2 = nodes1[:, 0], nodes2[:, 0] keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
cons1 = flatten_connections(keys1, connections1) cons1 = flatten_connections(keys1, connections1)
cons2 = flatten_connections(keys2, connections2) cons2 = flatten_connections(keys2, connections2)
cd = connection_distance(cons1, cons2) # connection distance
return nd + cd
connection_distance = gene_distance(cons1, cons2, 'connection')
return node_distance + connection_distance def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
# Compatibility distance between the node genes of two genomes.
#
# nodes1 / nodes2 are 2-d arrays whose column 0 holds the node key; a NaN key
# marks an unused (padding) row.
# NOTE(review): the remaining columns are presumably the node attributes
# (EMPTY_NODE has 5 columns) -- confirm against the genome layout.
#
# Result: (disjoint_coe * number of non-shared nodes
#          + compatibility_coe * summed distance of shared nodes)
#         / max(node count of either genome), or 0 when both genomes are empty.

# Number of real (non-NaN-key) nodes in each genome.
node_cnt1 = np.sum(~np.isnan(nodes1[:, 0]))
node_cnt2 = np.sum(~np.isnan(nodes2[:, 0]))
max_cnt = np.maximum(node_cnt1, node_cnt2)
# Stack both genomes and sort by key so that a node present in both genomes
# ends up in two adjacent rows (np.argsort places NaN keys last).
# NOTE(review): this assumes a key occurs at most once per genome -- confirm.
nodes = np.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0]
sorted_indices = np.argsort(keys, axis=0)
nodes = nodes[sorted_indices]
# Append one all-NaN row so every row has a "next row" to be compared with.
nodes = np.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end
# fr[i] and sr[i] are consecutive rows of the sorted stack.
fr, sr = nodes[:-1], nodes[1:] # first row, second row
nan_mask = np.isnan(nodes[:, 0])
# True where two consecutive rows carry the same key, i.e. the node exists in
# both genomes. NaN never compares equal to NaN, so padding rows cannot match;
# the explicit ~nan_mask term is an additional guard on the first row.
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~nan_mask[:-1]
# Every shared node accounts for two of the counted rows; the rest are
# nodes found in only one genome.
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * np.sum(intersect_mask)
# Distance for every consecutive row pair (computed by a helper not shown in
# this view); NaN results are zeroed and only matched pairs are summed.
nd = batch_homologous_node_distance(fr, sr)
nd = np.where(np.isnan(nd), 0, nd)
homologous_distance = np.sum(nd * intersect_mask)
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
# Avoid dividing by zero when neither genome has any node.
if max_cnt == 0: # consider the case that both genome has no gene
return 0
else:
return val / max_cnt
def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
    """Compatibility distance between the connection genes of two genomes.

    Both arguments are 2-d arrays with one row per connection slot. Columns
    0 and 1 form the connection key and column 2 is the weight; a NaN weight
    means the connection does not exist.
    NOTE(review): column 3 is presumably the enabled flag (EMPTY_CON has four
    columns) -- confirm against ``flatten_connections``.

    :param cons1: flattened connections of the first genome
    :param cons2: flattened connections of the second genome
    :param disjoint_coe: weight applied to connections found in one genome only
    :param compatibility_coe: weight applied to the summed distance of
        connections found in both genomes
    :return: the weighted sum divided by the larger connection count, or 0
        when neither genome has a connection
    """
    # A connection is counted only where its weight is not NaN.
    count_a = np.sum(~np.isnan(cons1[:, 2]))
    count_b = np.sum(~np.isnan(cons2[:, 2]))
    largest = np.maximum(count_a, count_b)

    merged = np.concatenate((cons1, cons2), axis=0)
    # np.lexsort uses the LAST key as the primary one, so the two key columns
    # are reversed to order rows by column 0 first and column 1 second.
    order = np.lexsort(merged[:, :2].T[::-1])
    merged = merged[order]
    # Trailing all-NaN row: gives every row a successor to be compared with.
    merged = np.concatenate([merged, EMPTY_CON], axis=0)
    upper, lower = merged[:-1], merged[1:]

    # A pair is homologous when both rows share the key and both weights exist.
    same_key = np.all(upper[:, :2] == lower[:, :2], axis=1)
    shared = same_key & ~np.isnan(upper[:, 2]) & ~np.isnan(lower[:, 2])
    # Each shared connection uses up one counted row from each genome.
    unmatched = count_a + count_b - 2 * np.sum(shared)

    pair_dist = batch_homologous_connection_distance(upper, lower)
    pair_dist = np.where(np.isnan(pair_dist), 0, pair_dist)
    matched_total = np.sum(pair_dist * shared)

    total = unmatched * disjoint_coe + matched_total * compatibility_coe
    if largest == 0:
        # Neither genome has a connection: avoid dividing by zero.
        return 0
    return total / largest
def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.): def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):

View File

@@ -319,6 +319,10 @@ def o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections):
distances = [] distances = []
for nodes, connections in zip(pop_nodes, pop_connections): for nodes, connections in zip(pop_nodes, pop_connections):
d = distance(r_nodes, r_connections, nodes, connections) d = distance(r_nodes, r_connections, nodes, connections)
if d < 0:
d = distance(r_nodes, r_connections, nodes, connections)
print(d)
assert False
distances.append(d) distances.append(d)
distances = np.stack(distances, axis=0) distances = np.stack(distances, axis=0)
return distances return distances

84
examples/distance_test.py Normal file
View File

@@ -0,0 +1,84 @@
from typing import Callable, List
from functools import partial
import numpy as np
from utils import Configer
from algorithms.neat.genome.numpy import analysis, distance
from algorithms.neat.genome.numpy import create_initialize_function, create_mutate_function
def real_distance(nodes1, connections1, nodes2, connections2, input_idx, output_idx):
    """Reference (pure-Python) genome distance used to check the NumPy version.

    Both genomes are first converted with ``analysis``; the results are used
    as mappings (``in``, ``.get``, ``.items`` and ``len`` are applied to them).
    NOTE(review): presumably ``analysis`` returns dicts keyed by node key and
    by connection key -- confirm against its definition.

    For nodes and for connections separately, the distance is::

        (0.5 * summed distance of genes present in both genomes
         + 1.0 * number of genes present in only one genome)
        / max(gene count of either genome)

    A side contributes 0 when both genomes are empty on that side.

    :return: node distance + connection distance
    """
    nodes1, connections1 = analysis(nodes1, connections1, input_idx, output_idx)
    nodes2, connections2 = analysis(nodes2, connections2, input_idx, output_idx)

    compatibility_coe = 0.5
    disjoint_coe = 1.

    node_part = 0.0
    if nodes1 or nodes2:  # otherwise, both are empty
        # Genes that exist only in the second genome.
        unmatched_nodes = sum(1 for key in nodes2 if key not in nodes1)
        for key, gene_a in nodes1.items():
            gene_b = nodes2.get(key)
            if gene_b is None:
                # Gene exists only in the first genome.
                unmatched_nodes += 1
                continue
            if gene_a[0] is None:
                # Nothing to compare when the first attribute is unset.
                continue
            # Fields 0 and 1 are compared numerically, 2 and 3 by equality.
            gene_gap = abs(gene_a[0] - gene_b[0]) + abs(gene_a[1] - gene_b[1])
            gene_gap += 1 if gene_a[2] != gene_b[2] else 0
            gene_gap += 1 if gene_a[3] != gene_b[3] else 0
            node_part += gene_gap
        most_nodes = max(len(nodes1), len(nodes2))
        node_part = (compatibility_coe * node_part + disjoint_coe * unmatched_nodes) / most_nodes

    conn_part = 0.0
    if connections1 or connections2:
        unmatched_conns = sum(1 for key in connections2 if key not in connections1)
        for key, conn_a in connections1.items():
            conn_b = connections2.get(key)
            if conn_b is None:
                unmatched_conns += 1
                continue
            # Homologous genes: field 0 numerically, field 1 by equality.
            conn_gap = abs(conn_a[0] - conn_b[0])
            conn_gap += 1 if conn_a[1] != conn_b[1] else 0
            conn_part += conn_gap
        most_conns = max(len(connections1), len(connections2))
        conn_part = (compatibility_coe * conn_part + disjoint_coe * unmatched_conns) / most_conns

    return node_part + conn_part
def main():
    """Endlessly mutate a population and cross-check the two distance functions.

    After every batch mutation, each ordered pair of genomes (including a
    genome with itself) is measured with the NumPy ``distance`` and with the
    reference ``real_distance``; the run aborts with an AssertionError on the
    first pair whose results are not close. The loop never terminates on its
    own.
    """
    config = Configer.load_config()
    # First key handed to the mutation step: one past the input/output count.
    next_key = config.basic.num_inputs + config.basic.num_outputs
    pop_size = config.neat.population.pop_size

    initialize = create_initialize_function(config)
    pop_nodes, pop_connections, input_idx, output_idx = initialize()
    mutate = create_mutate_function(config, input_idx, output_idx, batch=True)

    while True:
        # One fresh key per population member for this round of mutation.
        fresh_keys = list(range(next_key, next_key + pop_size))
        pop_nodes, pop_connections = mutate(pop_nodes, pop_connections, fresh_keys)
        next_key += pop_size

        for i in range(pop_size):
            for j in range(pop_size):
                first = (pop_nodes[i], pop_connections[i])
                second = (pop_nodes[j], pop_connections[j])
                numpy_d = distance(*first, *second)
                real_d = real_distance(*first, *second, input_idx, output_idx)
                assert np.isclose(numpy_d, real_d), f'{numpy_d} != {real_d}'
                print(numpy_d, real_d)
# Script entry point: only runs when the file is executed directly.
if __name__ == '__main__':
    # Fix NumPy's global random seed before starting.
    # NOTE(review): presumably so that initialisation/mutation is repeatable
    # between runs -- this only helps if they draw from np.random; confirm.
    np.random.seed(0)
    main()

View File

@@ -1,7 +1,6 @@
from typing import Callable, List from typing import Callable, List
from functools import partial from functools import partial
import jax
import numpy as np import numpy as np
from utils import Configer from utils import Configer
@@ -18,8 +17,7 @@ def evaluate(forward_func: Callable) -> List[float]:
:return: :return:
""" """
outs = forward_func(xor_inputs) outs = forward_func(xor_inputs)
outs = jax.device_get(outs) fitnesses = np.mean((outs - xor_outputs) ** 2, axis=(1, 2))
fitnesses = -np.mean((outs - xor_outputs) ** 2, axis=(1, 2))
# print(fitnesses) # print(fitnesses)
return fitnesses.tolist() # returns a list return fitnesses.tolist() # returns a list