bug down! Here it can solve xor successfully!

This commit is contained in:
wls2002
2023-05-07 16:03:52 +08:00
parent d1f54022bd
commit a3b9bca866
12 changed files with 120 additions and 254 deletions

View File

@@ -1,9 +1,11 @@
from typing import List, Tuple, Dict, Union
from itertools import count
import jax
import numpy as np
from numpy.typing import NDArray
from .genome.numpy import distance
from .genome import distance
class Species(object):
@@ -45,6 +47,9 @@ class SpeciesController:
self.species_idxer = count(0)
self.species: Dict[int, Species] = {} # species_id -> species
self.distance = distance
self.o2m_distance = jax.vmap(distance, in_axes=(None, None, 0, 0))
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None:
"""
:param pop_nodes:
@@ -62,7 +67,7 @@ class SpeciesController:
for sid, species in self.species.items():
# calculate the distance between the representative and the population
r_nodes, r_connections = species.representative
distances = o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections)
distances = self.o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections)
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
new_representatives[sid] = min_idx
@@ -75,7 +80,7 @@ class SpeciesController:
if previous_species_list: # exist previous species
rid_list = [new_representatives[sid] for sid in previous_species_list]
res_pop_distance = [
o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
self.o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
for rid in rid_list
]
@@ -102,7 +107,7 @@ class SpeciesController:
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
distances = [
distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
self.distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
for r in rid
]
distances = np.array(distances)
@@ -314,16 +319,3 @@ def sort_element_with_fitnesses(members: List[int], fitnesses: List[float]) \
sorted_members = [item[0] for item in sorted_combined]
sorted_fitnesses = [item[1] for item in sorted_combined]
return sorted_members, sorted_fitnesses
def o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections):
distances = []
for nodes, connections in zip(pop_nodes, pop_connections):
d = distance(r_nodes, r_connections, nodes, connections)
if d < 0:
d = distance(r_nodes, r_connections, nodes, connections)
print(d)
assert False
distances.append(d)
distances = np.stack(distances, axis=0)
return distances