precompile jax.random.split
This commit is contained in:
@@ -3,6 +3,7 @@ Lowers, compiles, and creates functions used in the NEAT pipeline.
|
|||||||
"""
|
"""
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
import jax.random
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from jax import jit, vmap
|
from jax import jit, vmap
|
||||||
|
|
||||||
@@ -113,6 +114,12 @@ class FunctionFactory:
|
|||||||
self.compile_topological_sort(n)
|
self.compile_topological_sort(n)
|
||||||
self.compile_pop_batch_forward(n)
|
self.compile_pop_batch_forward(n)
|
||||||
n = int(self.expand_coe * n)
|
n = int(self.expand_coe * n)
|
||||||
|
|
||||||
|
# precompile other functions used in jax
|
||||||
|
key = jax.random.PRNGKey(0)
|
||||||
|
_ = jax.random.split(key, 2)
|
||||||
|
_ = jax.random.split(key, self.pop_size * 2)
|
||||||
|
|
||||||
print("end precompile")
|
print("end precompile")
|
||||||
|
|
||||||
def create_mutate_with_args(self):
|
def create_mutate_with_args(self):
|
||||||
|
|||||||
@@ -5,8 +5,9 @@ import jax
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .species import SpeciesController
|
from .species import SpeciesController
|
||||||
from .genome import expand, expand_single, create_forward_function
|
from .genome import expand, expand_single
|
||||||
from .function_factory import FunctionFactory
|
from .function_factory import FunctionFactory
|
||||||
|
from examples.time_utils import using_cprofile
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
@@ -15,12 +16,11 @@ class Pipeline:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config, seed=42):
|
def __init__(self, config, seed=42):
|
||||||
|
self.function_factory = FunctionFactory(config)
|
||||||
self.randkey = jax.random.PRNGKey(seed)
|
self.randkey = jax.random.PRNGKey(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.function_factory = FunctionFactory(config)
|
|
||||||
|
|
||||||
self.N = config.basic.init_maximum_nodes
|
self.N = config.basic.init_maximum_nodes
|
||||||
self.expand_coe = config.basic.expands_coe
|
self.expand_coe = config.basic.expands_coe
|
||||||
self.pop_size = config.neat.population.pop_size
|
self.pop_size = config.neat.population.pop_size
|
||||||
@@ -45,11 +45,6 @@ class Pipeline:
|
|||||||
"""
|
"""
|
||||||
return self.function_factory.ask(self.pop_nodes, self.pop_connections)
|
return self.function_factory.ask(self.pop_nodes, self.pop_connections)
|
||||||
|
|
||||||
#
|
|
||||||
# func = create_forward_function(self.pop_nodes, self.pop_connections, self.N, self.input_idx, self.output_idx,
|
|
||||||
# batch=batch)
|
|
||||||
# return func
|
|
||||||
|
|
||||||
def tell(self, fitnesses):
|
def tell(self, fitnesses):
|
||||||
|
|
||||||
self.generation += 1
|
self.generation += 1
|
||||||
@@ -80,6 +75,7 @@ class Pipeline:
|
|||||||
self.tell(fitnesses)
|
self.tell(fitnesses)
|
||||||
print("Generation limit reached!")
|
print("Generation limit reached!")
|
||||||
|
|
||||||
|
# @using_cprofile
|
||||||
def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None:
|
def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None:
|
||||||
"""
|
"""
|
||||||
create the next generation
|
create the next generation
|
||||||
@@ -87,7 +83,7 @@ class Pipeline:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
assert self.pop_nodes.shape[0] == self.pop_size
|
assert self.pop_nodes.shape[0] == self.pop_size
|
||||||
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
|
k, self.randkey = jax.random.split(self.randkey, 2)
|
||||||
|
|
||||||
# crossover
|
# crossover
|
||||||
# prepare elitism mask and crossover pair
|
# prepare elitism mask and crossover pair
|
||||||
@@ -99,7 +95,10 @@ class Pipeline:
|
|||||||
crossover_pair[i] = (pair, pair)
|
crossover_pair[i] = (pair, pair)
|
||||||
crossover_pair = np.array(crossover_pair)
|
crossover_pair = np.array(crossover_pair)
|
||||||
|
|
||||||
crossover_rand_keys = jax.random.split(k1, self.pop_size)
|
total_keys = jax.random.split(k, self.pop_size * 2)
|
||||||
|
crossover_rand_keys = total_keys[:self.pop_size, :]
|
||||||
|
mutate_rand_keys = total_keys[self.pop_size:, :]
|
||||||
|
|
||||||
# batch crossover
|
# batch crossover
|
||||||
wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes
|
wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes
|
||||||
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
|
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
|
||||||
@@ -109,7 +108,6 @@ class Pipeline:
|
|||||||
lpc) # new pop nodes, new pop connections
|
lpc) # new pop nodes, new pop connections
|
||||||
|
|
||||||
# mutate
|
# mutate
|
||||||
mutate_rand_keys = jax.random.split(k2, self.pop_size)
|
|
||||||
new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size)
|
new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size)
|
||||||
|
|
||||||
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
|
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ class SpeciesController:
|
|||||||
new_members[species_id].append(i)
|
new_members[species_id].append(i)
|
||||||
unspeciated[i] = False
|
unspeciated[i] = False
|
||||||
|
|
||||||
# Second, slowly match the lonely population to new-created species.
|
# Second, slowly match the lonely population to new-created species.s
|
||||||
# lonely genome is proved to be not compatible with any previous species, so they only need to be compared with
|
# lonely genome is proved to be not compatible with any previous species, so they only need to be compared with
|
||||||
# the new representatives.
|
# the new representatives.
|
||||||
for i in range(pop_nodes.shape[0]):
|
for i in range(pop_nodes.shape[0]):
|
||||||
|
|||||||
@@ -3,9 +3,9 @@
|
|||||||
"num_inputs": 2,
|
"num_inputs": 2,
|
||||||
"num_outputs": 1,
|
"num_outputs": 1,
|
||||||
"problem_batch": 4,
|
"problem_batch": 4,
|
||||||
"init_maximum_nodes": 20,
|
"init_maximum_nodes": 10,
|
||||||
"expands_coe": 2,
|
"expands_coe": 2,
|
||||||
"pre_compile_times": 0
|
"pre_compile_times": 2
|
||||||
},
|
},
|
||||||
"neat": {
|
"neat": {
|
||||||
"population": {
|
"population": {
|
||||||
|
|||||||
Reference in New Issue
Block a user