diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index 4dd8208..57848e5 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -2,6 +2,7 @@ from typing import List, Union, Tuple, Callable import time import jax +import jax.numpy as jnp import numpy as np from .species import SpeciesController @@ -16,6 +17,7 @@ class Pipeline: """ def __init__(self, config, seed=42): + self.generation_timestamp = time.time() self.randkey = jax.random.PRNGKey(seed) self.config = config @@ -32,7 +34,6 @@ class Pipeline: self.generation = 0 self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation) - self.generation_timestamp = time.time() self.best_fitness = float('-inf') def ask(self, batch: bool): @@ -87,6 +88,7 @@ class Pipeline: # crossover # prepare elitism mask and crossover pair elitism_mask = np.full(self.pop_size, False) + for i, pair in enumerate(crossover_pair): if not isinstance(pair, tuple): # elitism elitism_mask[i] = True @@ -94,13 +96,14 @@ class Pipeline: crossover_pair = np.array(crossover_pair) crossover_rand_keys = jax.random.split(k1, self.pop_size) - # batch crossover wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections - npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections + npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, + lpc) # new pop nodes, new pop connections + npn, npc = jax.device_get(npn), jax.device_get(npc) # mutate mutate_rand_keys = jax.random.split(k2, self.pop_size) @@ -111,16 +114,12 @@ class Pipeline: # elitism don't mutate # (pop_size, ) to (pop_size, 1, 1) - def aux_function1(): - nonlocal m_npn, m_npc - m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc) - self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn) - # (pop_size, ) to (pop_size, 1, 1, 1) - self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc) + m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc) + self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn) + # (pop_size, ) to (pop_size, 1, 1, 1) + self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc) - # print(pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx)) - aux_function1() def expand(self): """ diff --git a/algorithms/neat/species.py b/algorithms/neat/species.py index 75d0d0d..1e7ca50 100644 --- a/algorithms/neat/species.py +++ b/algorithms/neat/species.py @@ -68,6 +68,7 @@ class SpeciesController: # calculate the distance between the representative and the population r_nodes, r_connections = species.representative distances = self.o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections) + distances = jax.device_get(distances) min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance new_representatives[sid] = min_idx @@ -80,7 +81,7 @@ class SpeciesController: if previous_species_list: # exist previous species rid_list = [new_representatives[sid] for sid in previous_species_list] res_pop_distance = [ - self.o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections) + jax.device_get(self.o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)) for rid in rid_list ] diff --git a/utils/default_config.json b/utils/default_config.json index 8ee9902..1881b52 100644 --- a/utils/default_config.json +++ b/utils/default_config.json @@ -2,7 +2,7 @@ "basic": { "num_inputs": 2, "num_outputs": 1, - "init_maximum_nodes": 10, + "init_maximum_nodes": 25, "expands_coe": 2 }, "neat": {