update to test in servers
This commit is contained in:
@@ -114,7 +114,7 @@ class FunctionFactory:
|
||||
self.compile_mutate(n)
|
||||
self.compile_distance(n)
|
||||
self.compile_crossover(n)
|
||||
self.compile_topological_sort(n)
|
||||
self.compile_topological_sort_batch(n)
|
||||
self.compile_pop_batch_forward(n)
|
||||
n = int(self.expand_coe * n)
|
||||
|
||||
@@ -259,9 +259,8 @@ class FunctionFactory:
|
||||
|
||||
def compile_topological_sort(self, n):
|
||||
func = self.topological_sort_with_args
|
||||
func = vmap(func)
|
||||
nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||
connections_lower = np.zeros((self.pop_size, 2, n, n))
|
||||
nodes_lower = np.zeros((n, 5))
|
||||
connections_lower = np.zeros((2, n, n))
|
||||
func = jit(func).lower(nodes_lower, connections_lower).compile()
|
||||
self.compiled_function[('topological_sort', n)] = func
|
||||
|
||||
@@ -271,6 +270,20 @@ class FunctionFactory:
|
||||
self.compile_topological_sort(n)
|
||||
return self.compiled_function[key]
|
||||
|
||||
def compile_topological_sort_batch(self, n):
|
||||
func = self.topological_sort_with_args
|
||||
func = vmap(func)
|
||||
nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||
connections_lower = np.zeros((self.pop_size, 2, n, n))
|
||||
func = jit(func).lower(nodes_lower, connections_lower).compile()
|
||||
self.compiled_function[('topological_sort_batch', n)] = func
|
||||
|
||||
def create_topological_sort_batch(self, n):
|
||||
key = ('topological_sort_batch', n)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_topological_sort_batch(n)
|
||||
return self.compiled_function[key]
|
||||
|
||||
def create_single_forward_with_args(self):
|
||||
func = partial(
|
||||
forward_single,
|
||||
@@ -315,6 +328,18 @@ class FunctionFactory:
|
||||
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
|
||||
self.compiled_function[('batch_forward', n)] = func
|
||||
|
||||
def create_batch_forward(self, n):
|
||||
key = ('batch_forward', n)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_batch_forward(n)
|
||||
if self.debug:
|
||||
def debug_batch_forward(*args):
|
||||
return self.compiled_function[key](*args).block_until_ready()
|
||||
|
||||
return debug_batch_forward
|
||||
else:
|
||||
return self.compiled_function[key]
|
||||
|
||||
def compile_pop_batch_forward(self, n):
|
||||
func = self.single_forward_with_args
|
||||
func = vmap(func, in_axes=(0, None, None, None)) # batch_forward
|
||||
@@ -340,9 +365,9 @@ class FunctionFactory:
|
||||
else:
|
||||
return self.compiled_function[key]
|
||||
|
||||
def ask(self, pop_nodes, pop_connections):
|
||||
def ask_pop_batch_forward(self, pop_nodes, pop_connections):
|
||||
n = pop_nodes.shape[1]
|
||||
ts = self.create_topological_sort(n)
|
||||
ts = self.create_topological_sort_batch(n)
|
||||
pop_cal_seqs = ts(pop_nodes, pop_connections)
|
||||
|
||||
forward_func = self.create_pop_batch_forward(n)
|
||||
@@ -352,9 +377,13 @@ class FunctionFactory:
|
||||
|
||||
return debug_forward
|
||||
|
||||
# return partial(
|
||||
# forward_func,
|
||||
# cal_seqs=pop_cal_seqs,
|
||||
# nodes=pop_nodes,
|
||||
# connections=pop_connections
|
||||
# )
|
||||
def ask_batch_forward(self, nodes, connections):
|
||||
n = nodes.shape[0]
|
||||
ts = self.create_topological_sort(n)
|
||||
cal_seqs = ts(nodes, connections)
|
||||
forward_func = self.create_batch_forward(n)
|
||||
|
||||
def debug_forward(inputs):
|
||||
return forward_func(inputs, cal_seqs, nodes, connections)
|
||||
|
||||
return debug_forward
|
||||
|
||||
@@ -68,6 +68,7 @@ def clamped_act(z):
|
||||
|
||||
@jit
|
||||
def inv_act(z):
|
||||
z = jnp.maximum(z, 1e-7)
|
||||
return 1 / z
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import numpy as np
|
||||
from .species import SpeciesController
|
||||
from .genome import expand, expand_single
|
||||
from .function_factory import FunctionFactory
|
||||
from examples.time_utils import using_cprofile
|
||||
|
||||
|
||||
class Pipeline:
|
||||
@@ -16,7 +15,9 @@ class Pipeline:
|
||||
"""
|
||||
|
||||
def __init__(self, config, seed=42):
|
||||
self.time_dict = {}
|
||||
self.function_factory = FunctionFactory(config, debug=True)
|
||||
|
||||
self.randkey = jax.random.PRNGKey(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
@@ -35,6 +36,7 @@ class Pipeline:
|
||||
self.species_controller.init_speciate(self.pop_nodes, self.pop_connections)
|
||||
|
||||
self.best_fitness = float('-inf')
|
||||
self.best_genome = None
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
def ask(self):
|
||||
@@ -43,7 +45,7 @@ class Pipeline:
|
||||
:return:
|
||||
Algorithm gives the population a forward function, then environment gives back the fitnesses.
|
||||
"""
|
||||
return self.function_factory.ask(self.pop_nodes, self.pop_connections)
|
||||
return self.function_factory.ask_pop_batch_forward(self.pop_nodes, self.pop_connections)
|
||||
|
||||
def tell(self, fitnesses):
|
||||
|
||||
@@ -72,10 +74,14 @@ class Pipeline:
|
||||
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
|
||||
analysis(fitnesses)
|
||||
|
||||
if max(fitnesses) >= self.config.neat.population.fitness_threshold:
|
||||
print("Fitness limit reached!")
|
||||
return self.best_genome
|
||||
|
||||
self.tell(fitnesses)
|
||||
print("Generation limit reached!")
|
||||
return self.best_genome
|
||||
|
||||
# @using_cprofile
|
||||
def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None:
|
||||
"""
|
||||
create the next generation
|
||||
@@ -152,5 +158,10 @@ class Pipeline:
|
||||
cost_time = new_timestamp - self.generation_timestamp
|
||||
self.generation_timestamp = new_timestamp
|
||||
|
||||
max_idx = np.argmax(fitnesses)
|
||||
if fitnesses[max_idx] > self.best_fitness:
|
||||
self.best_fitness = fitnesses[max_idx]
|
||||
self.best_genome = (self.pop_nodes[max_idx], self.pop_connections[max_idx])
|
||||
|
||||
print(f"Generation: {self.generation}",
|
||||
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")
|
||||
|
||||
Reference in New Issue
Block a user