move o2o_distance and o2m_distance to pipelines

This commit is contained in:
wls2002
2023-05-08 01:19:45 +08:00
parent c705b5cfe2
commit 497d89fc69
3 changed files with 55 additions and 23 deletions

View File

@@ -6,9 +6,9 @@ import jax.numpy as jnp
import numpy as np
from .species import SpeciesController
from .genome import create_initialize_function, create_mutate_function, create_forward_function
from .genome import create_crossover_function
from .genome import expand, expand_single
from .genome import create_initialize_function, create_mutate_function, create_forward_function, \
create_distance_function, create_crossover_function
class Pipeline:
@@ -30,9 +30,12 @@ class Pipeline:
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True)
self.crossover_func = create_crossover_function(batch=True)
self.o2o_distance = create_distance_function(self.config, type='o2o')
self.o2m_distance = create_distance_function(self.config, type='o2m')
self.generation = 0
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)
self.species_controller.speciate(self.pop_nodes, self.pop_connections,
self.generation, self.o2o_distance, self.o2m_distance)
self.best_fitness = float('-inf')
@@ -57,7 +60,8 @@ class Pipeline:
self.update_next_generation(crossover_pair)
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation,
self.o2o_distance, self.o2m_distance)
self.expand()
@@ -119,8 +123,6 @@ class Pipeline:
# (pop_size, ) to (pop_size, 1, 1, 1)
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
def expand(self):
"""
Expand the population if needed.

View File

@@ -1,13 +1,10 @@
from typing import List, Tuple, Dict, Union
from typing import List, Tuple, Dict, Union, Callable
from itertools import count
import jax
import numpy as np
from numpy.typing import NDArray
from .genome import create_distance_function
class Species(object):
def __init__(self, key, generation):
@@ -47,14 +44,14 @@ class SpeciesController:
self.species_idxer = count(0)
self.species: Dict[int, Species] = {} # species_id -> species
self.o2o_distance = create_distance_function(self.config, type='o2o')
self.o2m_distance = create_distance_function(self.config, type='o2m')
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None:
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int,
o2o_distance: Callable, o2m_distance: Callable) -> None:
"""
:param pop_nodes:
:param pop_connections:
:param generation: use to flag the created time of new species
:param o2o_distance: distance function for one-to-one comparison
:param o2m_distance: distance function for one-to-many comparison
:return:
"""
unspeciated = np.full((pop_nodes.shape[0],), True, dtype=bool)
@@ -67,7 +64,7 @@ class SpeciesController:
for sid, species in self.species.items():
# calculate the distance between the representative and the population
r_nodes, r_connections = species.representative
distances = self.o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections)
distances = o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections)
distances = jax.device_get(distances)
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
@@ -81,7 +78,7 @@ class SpeciesController:
if previous_species_list: # exist previous species
rid_list = [new_representatives[sid] for sid in previous_species_list]
res_pop_distance = [
jax.device_get(self.o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections))
jax.device_get(o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections))
for rid in rid_list
]
@@ -107,7 +104,7 @@ class SpeciesController:
# the representatives of new species
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
distances = [
self.o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
for r in rid
]
distances = np.array(distances)

View File

@@ -4,11 +4,44 @@ import numpy as np
from jax import random
from jax import vmap, jit
seed = jax.random.PRNGKey(42)
seed, *subkeys = random.split(seed, 3)
from examples.time_utils import using_cprofile
c = random.split(seed, 1)
print(seed, subkeys)
print(c)
def func(x, y):
"""
:param x: (100, )
:param y: (100,
:return:
"""
return x * y
# @using_cprofile
def main():
key = jax.random.PRNGKey(42)
x1, y1 = jax.random.normal(key, shape=(100,)), jax.random.normal(key, shape=(100,))
jit_func = jit(func)
z = jit_func(x1, y1)
print(z)
jit_lower_func = jit(func).lower(x1, y1).compile()
print(type(jit_lower_func))
import pickle
with open('jit_function.pkl', 'wb') as f:
pickle.dump(jit_lower_func, f)
new_jit_lower_func = pickle.load(open('jit_function.pkl', 'rb'))
print(jit_lower_func(x1, y1))
print(new_jit_lower_func(x1, y1))
# x2, y2 = jax.random.normal(key, shape=(200,)), jax.random.normal(key, shape=(200,))
# print(jit_lower_func(x2, y2))
if __name__ == '__main__':
main()