debuging
This commit is contained in:
@@ -42,14 +42,14 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
|
|||||||
# crossover nodes
|
# crossover nodes
|
||||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||||
nodes2 = align_array(keys1, keys2, nodes2, 'node')
|
nodes2 = align_array(keys1, keys2, nodes2, 'node')
|
||||||
new_nodes = jnp.where(jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
|
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
|
||||||
|
|
||||||
# crossover connections
|
# crossover connections
|
||||||
cons1 = flatten_connections(keys1, connections1)
|
cons1 = flatten_connections(keys1, connections1)
|
||||||
cons2 = flatten_connections(keys2, connections2)
|
cons2 = flatten_connections(keys2, connections2)
|
||||||
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
|
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
|
||||||
cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
|
cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
|
||||||
new_cons = jnp.where(jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
|
new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
|
||||||
new_cons = unflatten_connections(len(keys1), new_cons)
|
new_cons = unflatten_connections(len(keys1), new_cons)
|
||||||
|
|
||||||
return new_nodes, new_cons
|
return new_nodes, new_cons
|
||||||
|
|||||||
@@ -35,6 +35,15 @@ def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
|
|||||||
n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2)
|
n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2)
|
||||||
nodes = jnp.concatenate((ar1, ar2), axis=0)
|
nodes = jnp.concatenate((ar1, ar2), axis=0)
|
||||||
sorted_nodes = nodes[n_sorted_indices]
|
sorted_nodes = nodes[n_sorted_indices]
|
||||||
|
|
||||||
|
if gene_type == 'node':
|
||||||
|
node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 1:]), axis=1)
|
||||||
|
else:
|
||||||
|
node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 2:]), axis=1)
|
||||||
|
|
||||||
|
n_intersect_mask = n_intersect_mask & node_exist_mask
|
||||||
|
n_union_mask = n_union_mask & node_exist_mask
|
||||||
|
|
||||||
fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:]
|
fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:]
|
||||||
|
|
||||||
non_homologous_cnt = jnp.sum(n_union_mask) - jnp.sum(n_intersect_mask)
|
non_homologous_cnt = jnp.sum(n_union_mask) - jnp.sum(n_intersect_mask)
|
||||||
@@ -48,9 +57,11 @@ def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
|
|||||||
|
|
||||||
gene_cnt1 = jnp.sum(jnp.all(~jnp.isnan(ar1), axis=1))
|
gene_cnt1 = jnp.sum(jnp.all(~jnp.isnan(ar1), axis=1))
|
||||||
gene_cnt2 = jnp.sum(jnp.all(~jnp.isnan(ar2), axis=1))
|
gene_cnt2 = jnp.sum(jnp.all(~jnp.isnan(ar2), axis=1))
|
||||||
|
max_cnt = jnp.maximum(gene_cnt1, gene_cnt2)
|
||||||
|
|
||||||
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
|
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
|
||||||
return val / jnp.where(gene_cnt1 > gene_cnt2, gene_cnt1, gene_cnt2)
|
|
||||||
|
return jnp.where(max_cnt == 0, 0, val / max_cnt) # consider the case that both genome has no gene
|
||||||
|
|
||||||
|
|
||||||
@partial(vmap, in_axes=(0, 0))
|
@partial(vmap, in_axes=(0, 0))
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if debug:
|
if debug:
|
||||||
cal_seqs = topological_sort(nodes, connections)
|
cal_seqs = topological_sort_debug(nodes, connections)
|
||||||
return lambda inputs: forward_single_debug(inputs, N, input_idx, output_idx,
|
return lambda inputs: forward_single_debug(inputs, N, input_idx, output_idx,
|
||||||
cal_seqs, nodes, connections)
|
cal_seqs, nodes, connections)
|
||||||
|
|
||||||
|
|||||||
@@ -12,9 +12,10 @@ status.
|
|||||||
Empty nodes or connections are represented using np.nan.
|
Empty nodes or connections are represented using np.nan.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from typing import Tuple
|
from typing import Tuple, Dict
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
import jax
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
@@ -131,6 +132,83 @@ def expand_single(nodes: NDArray, connections: NDArray, new_N: int) -> Tuple[NDA
|
|||||||
|
|
||||||
return new_nodes, new_connections
|
return new_nodes, new_connections
|
||||||
|
|
||||||
|
|
||||||
|
def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \
|
||||||
|
Tuple[Dict[int, Tuple[float, float, int, int]], Dict[Tuple[int, int], Tuple[float, bool]]]:
|
||||||
|
"""
|
||||||
|
Convert a genome from array to dict.
|
||||||
|
:param nodes: (N, 5)
|
||||||
|
:param connections: (2, N, N)
|
||||||
|
:param output_keys:
|
||||||
|
:param input_keys:
|
||||||
|
:return: nodes_dict[key: (bias, response, act, agg)], connections_dict[(f_key, t_key): (weight, enabled)]
|
||||||
|
"""
|
||||||
|
# update nodes_dict
|
||||||
|
try:
|
||||||
|
nodes_dict = {}
|
||||||
|
idx2key = {}
|
||||||
|
for i, node in enumerate(nodes):
|
||||||
|
if np.isnan(node[0]):
|
||||||
|
continue
|
||||||
|
key = int(node[0])
|
||||||
|
assert key not in nodes_dict, f"Duplicate node key: {key}!"
|
||||||
|
|
||||||
|
bias = node[1] if not np.isnan(node[1]) else None
|
||||||
|
response = node[2] if not np.isnan(node[2]) else None
|
||||||
|
act = node[3] if not np.isnan(node[3]) else None
|
||||||
|
agg = node[4] if not np.isnan(node[4]) else None
|
||||||
|
nodes_dict[key] = (bias, response, act, agg)
|
||||||
|
idx2key[i] = key
|
||||||
|
|
||||||
|
# check nodes_dict
|
||||||
|
for i in input_keys:
|
||||||
|
assert i in nodes_dict, f"Input node {i} not found in nodes_dict!"
|
||||||
|
bias, response, act, agg = nodes_dict[i]
|
||||||
|
assert bias is None and response is None and act is None and agg is None, \
|
||||||
|
f"Input node {i} must has None bias, response, act, or agg!"
|
||||||
|
|
||||||
|
for o in output_keys:
|
||||||
|
assert o in nodes_dict, f"Output node {o} not found in nodes_dict!"
|
||||||
|
|
||||||
|
for k, v in nodes_dict.items():
|
||||||
|
if k not in input_keys:
|
||||||
|
bias, response, act, agg = v
|
||||||
|
assert bias is not None and response is not None and act is not None and agg is not None, \
|
||||||
|
f"Normal node {k} must has non-None bias, response, act, or agg!"
|
||||||
|
|
||||||
|
# update connections
|
||||||
|
connections_dict = {}
|
||||||
|
for i in range(connections.shape[1]):
|
||||||
|
for j in range(connections.shape[2]):
|
||||||
|
if np.isnan(connections[0, i, j]) and np.isnan(connections[1, i, j]):
|
||||||
|
continue
|
||||||
|
assert i in idx2key, f"Node index {i} not found in idx2key:{idx2key}!"
|
||||||
|
assert j in idx2key, f"Node index {j} not found in idx2key:{idx2key}!"
|
||||||
|
key = (idx2key[i], idx2key[j])
|
||||||
|
|
||||||
|
weight = connections[0, i, j] if not np.isnan(connections[0, i, j]) else None
|
||||||
|
enabled = (connections[1, i, j] == 1) if not np.isnan(connections[1, i, j]) else None
|
||||||
|
|
||||||
|
assert weight is not None, f"Connection {key} must has non-None weight!"
|
||||||
|
assert enabled is not None, f"Connection {key} must has non-None enabled!"
|
||||||
|
connections_dict[key] = (weight, enabled)
|
||||||
|
|
||||||
|
return nodes_dict, connections_dict
|
||||||
|
except AssertionError:
|
||||||
|
print(nodes)
|
||||||
|
print(connections)
|
||||||
|
raise AssertionError
|
||||||
|
|
||||||
|
|
||||||
|
def pop_analysis(pop_nodes, pop_connections, input_keys, output_keys):
|
||||||
|
pop_nodes, pop_connections = jax.device_get((pop_nodes, pop_connections))
|
||||||
|
res = []
|
||||||
|
for nodes, connections in zip(pop_nodes, pop_connections):
|
||||||
|
res.append(analysis(nodes, connections, input_keys, output_keys))
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def add_node(new_node_key: int, nodes: Array, connections: Array,
|
def add_node(new_node_key: int, nodes: Array, connections: Array,
|
||||||
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]:
|
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]:
|
||||||
@@ -158,11 +236,12 @@ def delete_node_by_idx(idx: int, nodes: Array, connections: Array) -> Tuple[Arra
|
|||||||
"""
|
"""
|
||||||
delete a node from the genome. only delete the node, regardless of connections.
|
delete a node from the genome. only delete the node, regardless of connections.
|
||||||
"""
|
"""
|
||||||
node_keys = nodes[:, 0]
|
# node_keys = nodes[:, 0]
|
||||||
|
nodes = nodes.at[idx].set(EMPTY_NODE)
|
||||||
# move the last node to the deleted node's position
|
# move the last node to the deleted node's position
|
||||||
last_real_idx = fetch_last(~jnp.isnan(node_keys))
|
# last_real_idx = fetch_last(~jnp.isnan(node_keys))
|
||||||
nodes = nodes.at[idx].set(nodes[last_real_idx])
|
# nodes = nodes.at[idx].set(nodes[last_real_idx])
|
||||||
nodes = nodes.at[last_real_idx].set(EMPTY_NODE)
|
# nodes = nodes.at[last_real_idx].set(EMPTY_NODE)
|
||||||
return nodes, connections
|
return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
@@ -206,7 +285,3 @@ def delete_connection_by_idx(from_idx: int, to_idx: int, nodes: Array, connectio
|
|||||||
"""
|
"""
|
||||||
connections = connections.at[:, from_idx, to_idx].set(np.nan)
|
connections = connections.at[:, from_idx, to_idx].set(np.nan)
|
||||||
return nodes, connections
|
return nodes, connections
|
||||||
|
|
||||||
# if __name__ == '__main__':
|
|
||||||
# pop_nodes, pop_connections, input_keys, output_keys = initialize_genomes(100, 5, 2, 1)
|
|
||||||
# print(pop_nodes, pop_connections)
|
|
||||||
|
|||||||
@@ -148,7 +148,9 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
|
|||||||
check_cycles(nodes, connections, 0, 3) -> False
|
check_cycles(nodes, connections, 0, 3) -> False
|
||||||
check_cycles(nodes, connections, 1, 0) -> False
|
check_cycles(nodes, connections, 1, 0) -> False
|
||||||
"""
|
"""
|
||||||
connections_enable = connections[1, :, :] == 1
|
# connections_enable = connections[0, :, :] == 1
|
||||||
|
connections_enable = ~jnp.isnan(connections[0, :, :])
|
||||||
|
|
||||||
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
|
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
|
||||||
nodes_visited = jnp.full(nodes.shape[0], False)
|
nodes_visited = jnp.full(nodes.shape[0], False)
|
||||||
nodes_visited = nodes_visited.at[to_idx].set(True)
|
nodes_visited = nodes_visited.at[to_idx].set(True)
|
||||||
|
|||||||
@@ -7,7 +7,9 @@ import numpy as np
|
|||||||
from .species import SpeciesController
|
from .species import SpeciesController
|
||||||
from .genome import create_initialize_function, create_mutate_function, create_forward_function
|
from .genome import create_initialize_function, create_mutate_function, create_forward_function
|
||||||
from .genome import batch_crossover
|
from .genome import batch_crossover
|
||||||
|
from .genome.crossover import crossover
|
||||||
from .genome import expand, expand_single
|
from .genome import expand, expand_single
|
||||||
|
from algorithms.neat.genome.genome import pop_analysis, analysis
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
@@ -51,12 +53,22 @@ class Pipeline:
|
|||||||
def tell(self, fitnesses):
|
def tell(self, fitnesses):
|
||||||
self.generation += 1
|
self.generation += 1
|
||||||
|
|
||||||
|
for i, f in enumerate(fitnesses):
|
||||||
|
if np.isnan(f):
|
||||||
|
print("fuck!!!!!!!!!!!!!!")
|
||||||
|
error_nodes, error_connections = self.pop_nodes[i], self.pop_connections[i]
|
||||||
|
np.save('error_nodes.npy', error_nodes)
|
||||||
|
np.save('error_connections.npy', error_connections)
|
||||||
|
assert False
|
||||||
|
|
||||||
self.species_controller.update_species_fitnesses(fitnesses)
|
self.species_controller.update_species_fitnesses(fitnesses)
|
||||||
|
|
||||||
crossover_pair = self.species_controller.reproduce(self.generation)
|
crossover_pair = self.species_controller.reproduce(self.generation)
|
||||||
|
|
||||||
self.update_next_generation(crossover_pair)
|
self.update_next_generation(crossover_pair)
|
||||||
|
|
||||||
|
# print(pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx))
|
||||||
|
|
||||||
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)
|
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)
|
||||||
|
|
||||||
self.expand()
|
self.expand()
|
||||||
@@ -103,16 +115,22 @@ class Pipeline:
|
|||||||
crossover_rand_keys = jax.random.split(k1, self.pop_size)
|
crossover_rand_keys = jax.random.split(k1, self.pop_size)
|
||||||
# npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
# npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
||||||
npn, npc = crossover_wrapper(crossover_rand_keys, wpn, wpc, lpn, lpc)
|
npn, npc = crossover_wrapper(crossover_rand_keys, wpn, wpc, lpn, lpc)
|
||||||
|
# print(pop_analysis(npn, npc, self.input_idx, self.output_idx))
|
||||||
|
|
||||||
# mutate
|
# mutate
|
||||||
new_node_keys = np.array(self.fetch_new_node_keys())
|
new_node_keys = np.array(self.fetch_new_node_keys())
|
||||||
mutate_rand_keys = jax.random.split(k2, self.pop_size)
|
mutate_rand_keys = jax.random.split(k2, self.pop_size)
|
||||||
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys)
|
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
|
||||||
m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc)
|
m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc)
|
||||||
|
|
||||||
|
# print(pop_analysis(m_npn, m_npc, self.input_idx, self.output_idx))
|
||||||
|
|
||||||
# elitism don't mutate
|
# elitism don't mutate
|
||||||
# (pop_size, ) to (pop_size, 1, 1)
|
# (pop_size, ) to (pop_size, 1, 1)
|
||||||
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
|
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
|
||||||
# (pop_size, ) to (pop_size, 1, 1, 1)
|
# (pop_size, ) to (pop_size, 1, 1, 1)
|
||||||
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
|
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
|
||||||
|
# print(pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx))
|
||||||
|
|
||||||
# recycle unused node keys
|
# recycle unused node keys
|
||||||
unused = []
|
unused = []
|
||||||
@@ -138,8 +156,8 @@ class Pipeline:
|
|||||||
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N)
|
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N)
|
||||||
|
|
||||||
# don't forget to expand representation genome in species
|
# don't forget to expand representation genome in species
|
||||||
for s in self.species_controller.species:
|
for s in self.species_controller.species.values():
|
||||||
s.representative = expand(*s.representative, self.N)
|
s.representative = expand_single(*s.representative, self.N)
|
||||||
|
|
||||||
def fetch_new_node_keys(self):
|
def fetch_new_node_keys(self):
|
||||||
# if remain unused keys are not enough, create new keys
|
# if remain unused keys are not enough, create new keys
|
||||||
@@ -164,6 +182,19 @@ class Pipeline:
|
|||||||
print(f"Generation: {self.generation}",
|
print(f"Generation: {self.generation}",
|
||||||
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")
|
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")
|
||||||
|
|
||||||
|
# def crossover_wrapper(self, crossover_rand_keys, wpn, wpc, lpn, lpc):
|
||||||
|
# pop_nodes, pop_connections = [], []
|
||||||
|
# for randkey, wn, wc, ln, lc in zip(crossover_rand_keys, wpn, wpc, lpn, lpc):
|
||||||
|
# new_nodes, new_connections = crossover(randkey, wn, wc, ln, lc)
|
||||||
|
# pop_nodes.append(new_nodes)
|
||||||
|
# pop_connections.append(new_connections)
|
||||||
|
# try:
|
||||||
|
# print(analysis(new_nodes, new_connections, self.input_idx, self.output_idx))
|
||||||
|
# except AssertionError:
|
||||||
|
# new_nodes, new_connections = crossover(randkey, wn, wc, ln, lc)
|
||||||
|
# return np.stack(pop_nodes), np.stack(pop_connections)
|
||||||
|
|
||||||
|
# return batch_crossover(*args)
|
||||||
|
|
||||||
def crossover_wrapper(*args):
|
def crossover_wrapper(*args):
|
||||||
return batch_crossover(*args)
|
return batch_crossover(*args)
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ class SpeciesController:
|
|||||||
for sid, species in self.species.items():
|
for sid, species in self.species.items():
|
||||||
# calculate the distance between the representative and the population
|
# calculate the distance between the representative and the population
|
||||||
r_nodes, r_connections = species.representative
|
r_nodes, r_connections = species.representative
|
||||||
distances = self.o2m_distance_func(r_nodes, r_connections, pop_nodes, pop_connections)
|
distances = self.o2m_distance_wrapper(r_nodes, r_connections, pop_nodes, pop_connections)
|
||||||
distances = jax.device_get(distances) # fetch the data from gpu
|
distances = jax.device_get(distances) # fetch the data from gpu
|
||||||
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
|
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
|
||||||
|
|
||||||
@@ -82,7 +82,7 @@ class SpeciesController:
|
|||||||
rid_list = [new_representatives[sid] for sid in previous_species_list]
|
rid_list = [new_representatives[sid] for sid in previous_species_list]
|
||||||
res_pop_distance = [
|
res_pop_distance = [
|
||||||
jax.device_get(
|
jax.device_get(
|
||||||
self.o2m_distance_func(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
|
self.o2m_distance_wrapper(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
|
||||||
)
|
)
|
||||||
for rid in rid_list
|
for rid in rid_list
|
||||||
]
|
]
|
||||||
@@ -110,7 +110,7 @@ class SpeciesController:
|
|||||||
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
|
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
|
||||||
|
|
||||||
distances = [
|
distances = [
|
||||||
self.o2o_distance_func(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
|
self.o2o_distance_wrapper(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
|
||||||
for r in rid
|
for r in rid
|
||||||
]
|
]
|
||||||
distances = np.array(distances)
|
distances = np.array(distances)
|
||||||
@@ -267,6 +267,36 @@ class SpeciesController:
|
|||||||
|
|
||||||
return crossover_pair
|
return crossover_pair
|
||||||
|
|
||||||
|
def o2m_distance_wrapper(self, r_nodes, r_connections, pop_nodes, pop_connections):
|
||||||
|
# distances = self.o2m_distance_func(r_nodes, r_connections, pop_nodes, pop_connections)
|
||||||
|
# for d in distances:
|
||||||
|
# if np.isnan(d):
|
||||||
|
# print("fuck!!!!!!!!!!!!!!")
|
||||||
|
# print(distances)
|
||||||
|
# assert False
|
||||||
|
# return distances
|
||||||
|
distances = []
|
||||||
|
for nodes, connections in zip(pop_nodes, pop_connections):
|
||||||
|
d = self.o2o_distance_func(r_nodes, r_connections, nodes, connections)
|
||||||
|
if np.isnan(d) or d > 20:
|
||||||
|
np.save("too_large_distance_r_nodes.npy", r_nodes)
|
||||||
|
np.save("too_large_distance_r_connections.npy", r_connections)
|
||||||
|
np.save("too_large_distance_nodes", nodes)
|
||||||
|
np.save("too_large_distance_connections.npy", connections)
|
||||||
|
d = self.o2o_distance_func(r_nodes, r_connections, nodes, connections)
|
||||||
|
assert False
|
||||||
|
distances.append(d)
|
||||||
|
distances = np.stack(distances, axis=0)
|
||||||
|
# print(distances)
|
||||||
|
return distances
|
||||||
|
|
||||||
|
def o2o_distance_wrapper(self, *keys):
|
||||||
|
d = self.o2o_distance_func(*keys)
|
||||||
|
if np.isnan(d):
|
||||||
|
print("fuck!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||||
|
assert False
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
|
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
|
||||||
"""
|
"""
|
||||||
|
|||||||
24
examples/error_forward_fix.py
Normal file
24
examples/error_forward_fix.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
import numpy as np
|
||||||
|
from jax import numpy as jnp
|
||||||
|
|
||||||
|
from algorithms.neat.genome.genome import analysis
|
||||||
|
from algorithms.neat.genome import create_forward_function
|
||||||
|
|
||||||
|
|
||||||
|
error_nodes = np.load('error_nodes.npy')
|
||||||
|
error_connections = np.load('error_connections.npy')
|
||||||
|
|
||||||
|
node_dict, connection_dict = analysis(error_nodes, error_connections, np.array([0, 1]), np.array([2, ]))
|
||||||
|
print(node_dict, connection_dict, sep='\n')
|
||||||
|
|
||||||
|
N = error_nodes.shape[0]
|
||||||
|
|
||||||
|
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
||||||
|
|
||||||
|
func = create_forward_function(error_nodes, error_connections, N, jnp.array([0, 1]), jnp.array([2, ]),
|
||||||
|
batch=True, debug=True)
|
||||||
|
out = func(np.array([1, 0]))
|
||||||
|
|
||||||
|
print(error_nodes)
|
||||||
|
print(error_connections)
|
||||||
|
print(out)
|
||||||
11
examples/fix_too_large_distance.py
Normal file
11
examples/fix_too_large_distance.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
import numpy as np
|
||||||
|
from algorithms.neat.genome import distance
|
||||||
|
|
||||||
|
r_nodes = np.load('too_large_distance_r_nodes.npy')
|
||||||
|
r_connections = np.load('too_large_distance_r_connections.npy')
|
||||||
|
nodes = np.load('too_large_distance_nodes.npy')
|
||||||
|
connections = np.load('too_large_distance_connections.npy')
|
||||||
|
|
||||||
|
d1 = distance(r_nodes, r_connections, nodes, connections)
|
||||||
|
d2 = distance(nodes, connections, r_nodes, r_connections)
|
||||||
|
print(d1, d2)
|
||||||
@@ -10,7 +10,7 @@
|
|||||||
"fitness_criterion": "max",
|
"fitness_criterion": "max",
|
||||||
"fitness_threshold": 3,
|
"fitness_threshold": 3,
|
||||||
"generation_limit": 100,
|
"generation_limit": 100,
|
||||||
"pop_size": 20,
|
"pop_size": 100,
|
||||||
"reset_on_extinction": "False"
|
"reset_on_extinction": "False"
|
||||||
},
|
},
|
||||||
"gene": {
|
"gene": {
|
||||||
@@ -73,7 +73,7 @@
|
|||||||
"node_delete_prob": 0.2
|
"node_delete_prob": 0.2
|
||||||
},
|
},
|
||||||
"species": {
|
"species": {
|
||||||
"compatibility_threshold": 8,
|
"compatibility_threshold": 3.5,
|
||||||
"species_fitness_func": "max",
|
"species_fitness_func": "max",
|
||||||
"max_stagnation": 20,
|
"max_stagnation": 20,
|
||||||
"species_elitism": 2,
|
"species_elitism": 2,
|
||||||
|
|||||||
Reference in New Issue
Block a user