From 497d89fc696fa8b8001d36a232197daf7768f79a Mon Sep 17 00:00:00 2001 From: wls2002 Date: Mon, 8 May 2023 01:19:45 +0800 Subject: [PATCH] move o2o_distance and o2m_distance to pipelines --- algorithms/neat/pipeline.py | 14 +++++++----- algorithms/neat/species.py | 19 +++++++--------- examples/jax_playground.py | 45 ++++++++++++++++++++++++++++++++----- 3 files changed, 55 insertions(+), 23 deletions(-) diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index 57848e5..b39be0a 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -6,9 +6,9 @@ import jax.numpy as jnp import numpy as np from .species import SpeciesController -from .genome import create_initialize_function, create_mutate_function, create_forward_function -from .genome import create_crossover_function from .genome import expand, expand_single +from .genome import create_initialize_function, create_mutate_function, create_forward_function, \ + create_distance_function, create_crossover_function class Pipeline: @@ -30,9 +30,12 @@ class Pipeline: self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func() self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True) self.crossover_func = create_crossover_function(batch=True) + self.o2o_distance = create_distance_function(self.config, type='o2o') + self.o2m_distance = create_distance_function(self.config, type='o2m') self.generation = 0 - self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation) + self.species_controller.speciate(self.pop_nodes, self.pop_connections, + self.generation, self.o2o_distance, self.o2m_distance) self.best_fitness = float('-inf') @@ -57,7 +60,8 @@ class Pipeline: self.update_next_generation(crossover_pair) - self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation) + self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation, + self.o2o_distance, self.o2m_distance) self.expand() @@ -119,8 +123,6 @@ class Pipeline: # (pop_size, ) to (pop_size, 1, 1, 1) self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc) - - def expand(self): """ Expand the population if needed. diff --git a/algorithms/neat/species.py b/algorithms/neat/species.py index 1e7ca50..620de50 100644 --- a/algorithms/neat/species.py +++ b/algorithms/neat/species.py @@ -1,13 +1,10 @@ -from typing import List, Tuple, Dict, Union +from typing import List, Tuple, Dict, Union, Callable from itertools import count import jax import numpy as np from numpy.typing import NDArray -from .genome import create_distance_function - - class Species(object): def __init__(self, key, generation): @@ -47,14 +44,14 @@ class SpeciesController: self.species_idxer = count(0) self.species: Dict[int, Species] = {} # species_id -> species - self.o2o_distance = create_distance_function(self.config, type='o2o') - self.o2m_distance = create_distance_function(self.config, type='o2m') - - def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None: + def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int, + o2o_distance: Callable, o2m_distance: Callable) -> None: """ :param pop_nodes: :param pop_connections: :param generation: use to flag the created time of new species + :param o2o_distance: distance function for one-to-one comparison + :param o2m_distance: distance function for one-to-many comparison :return: """ unspeciated = np.full((pop_nodes.shape[0],), True, dtype=bool) @@ -67,7 +64,7 @@ class SpeciesController: for sid, species in self.species.items(): # calculate the distance between the representative and the population r_nodes, r_connections = species.representative - distances = self.o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections) + distances = o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections) distances = jax.device_get(distances) min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance @@ -81,7 +78,7 @@ class SpeciesController: if previous_species_list: # exist previous species rid_list = [new_representatives[sid] for sid in previous_species_list] res_pop_distance = [ - jax.device_get(self.o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)) + jax.device_get(o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)) for rid in rid_list ] @@ -107,7 +104,7 @@ class SpeciesController: # the representatives of new species sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()])) distances = [ - self.o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]) + o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]) for r in rid ] distances = np.array(distances) diff --git a/examples/jax_playground.py b/examples/jax_playground.py index e8f8374..2554ab9 100644 --- a/examples/jax_playground.py +++ b/examples/jax_playground.py @@ -4,11 +4,44 @@ import numpy as np from jax import random from jax import vmap, jit - -seed = jax.random.PRNGKey(42) -seed, *subkeys = random.split(seed, 3) +from examples.time_utils import using_cprofile -c = random.split(seed, 1) -print(seed, subkeys) -print(c) \ No newline at end of file +def func(x, y): + """ + :param x: (100, ) + :param y: (100, + :return: + """ + return x * y + + +# @using_cprofile +def main(): + key = jax.random.PRNGKey(42) + + x1, y1 = jax.random.normal(key, shape=(100,)), jax.random.normal(key, shape=(100,)) + + jit_func = jit(func) + + z = jit_func(x1, y1) + print(z) + + jit_lower_func = jit(func).lower(x1, y1).compile() + print(type(jit_lower_func)) + import pickle + + with open('jit_function.pkl', 'wb') as f: + pickle.dump(jit_lower_func, f) + + new_jit_lower_func = pickle.load(open('jit_function.pkl', 'rb')) + + print(jit_lower_func(x1, y1)) + print(new_jit_lower_func(x1, y1)) + + # x2, y2 = jax.random.normal(key, shape=(200,)), jax.random.normal(key, shape=(200,)) + # print(jit_lower_func(x2, y2)) + + +if __name__ == '__main__': + main()