update to test in servers
This commit is contained in:
@@ -114,7 +114,7 @@ class FunctionFactory:
|
|||||||
self.compile_mutate(n)
|
self.compile_mutate(n)
|
||||||
self.compile_distance(n)
|
self.compile_distance(n)
|
||||||
self.compile_crossover(n)
|
self.compile_crossover(n)
|
||||||
self.compile_topological_sort(n)
|
self.compile_topological_sort_batch(n)
|
||||||
self.compile_pop_batch_forward(n)
|
self.compile_pop_batch_forward(n)
|
||||||
n = int(self.expand_coe * n)
|
n = int(self.expand_coe * n)
|
||||||
|
|
||||||
@@ -259,9 +259,8 @@ class FunctionFactory:
|
|||||||
|
|
||||||
def compile_topological_sort(self, n):
|
def compile_topological_sort(self, n):
|
||||||
func = self.topological_sort_with_args
|
func = self.topological_sort_with_args
|
||||||
func = vmap(func)
|
nodes_lower = np.zeros((n, 5))
|
||||||
nodes_lower = np.zeros((self.pop_size, n, 5))
|
connections_lower = np.zeros((2, n, n))
|
||||||
connections_lower = np.zeros((self.pop_size, 2, n, n))
|
|
||||||
func = jit(func).lower(nodes_lower, connections_lower).compile()
|
func = jit(func).lower(nodes_lower, connections_lower).compile()
|
||||||
self.compiled_function[('topological_sort', n)] = func
|
self.compiled_function[('topological_sort', n)] = func
|
||||||
|
|
||||||
@@ -271,6 +270,20 @@ class FunctionFactory:
|
|||||||
self.compile_topological_sort(n)
|
self.compile_topological_sort(n)
|
||||||
return self.compiled_function[key]
|
return self.compiled_function[key]
|
||||||
|
|
||||||
|
def compile_topological_sort_batch(self, n):
|
||||||
|
func = self.topological_sort_with_args
|
||||||
|
func = vmap(func)
|
||||||
|
nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||||
|
connections_lower = np.zeros((self.pop_size, 2, n, n))
|
||||||
|
func = jit(func).lower(nodes_lower, connections_lower).compile()
|
||||||
|
self.compiled_function[('topological_sort_batch', n)] = func
|
||||||
|
|
||||||
|
def create_topological_sort_batch(self, n):
|
||||||
|
key = ('topological_sort_batch', n)
|
||||||
|
if key not in self.compiled_function:
|
||||||
|
self.compile_topological_sort_batch(n)
|
||||||
|
return self.compiled_function[key]
|
||||||
|
|
||||||
def create_single_forward_with_args(self):
|
def create_single_forward_with_args(self):
|
||||||
func = partial(
|
func = partial(
|
||||||
forward_single,
|
forward_single,
|
||||||
@@ -315,6 +328,18 @@ class FunctionFactory:
|
|||||||
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
|
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
|
||||||
self.compiled_function[('batch_forward', n)] = func
|
self.compiled_function[('batch_forward', n)] = func
|
||||||
|
|
||||||
|
def create_batch_forward(self, n):
|
||||||
|
key = ('batch_forward', n)
|
||||||
|
if key not in self.compiled_function:
|
||||||
|
self.compile_batch_forward(n)
|
||||||
|
if self.debug:
|
||||||
|
def debug_batch_forward(*args):
|
||||||
|
return self.compiled_function[key](*args).block_until_ready()
|
||||||
|
|
||||||
|
return debug_batch_forward
|
||||||
|
else:
|
||||||
|
return self.compiled_function[key]
|
||||||
|
|
||||||
def compile_pop_batch_forward(self, n):
|
def compile_pop_batch_forward(self, n):
|
||||||
func = self.single_forward_with_args
|
func = self.single_forward_with_args
|
||||||
func = vmap(func, in_axes=(0, None, None, None)) # batch_forward
|
func = vmap(func, in_axes=(0, None, None, None)) # batch_forward
|
||||||
@@ -340,9 +365,9 @@ class FunctionFactory:
|
|||||||
else:
|
else:
|
||||||
return self.compiled_function[key]
|
return self.compiled_function[key]
|
||||||
|
|
||||||
def ask(self, pop_nodes, pop_connections):
|
def ask_pop_batch_forward(self, pop_nodes, pop_connections):
|
||||||
n = pop_nodes.shape[1]
|
n = pop_nodes.shape[1]
|
||||||
ts = self.create_topological_sort(n)
|
ts = self.create_topological_sort_batch(n)
|
||||||
pop_cal_seqs = ts(pop_nodes, pop_connections)
|
pop_cal_seqs = ts(pop_nodes, pop_connections)
|
||||||
|
|
||||||
forward_func = self.create_pop_batch_forward(n)
|
forward_func = self.create_pop_batch_forward(n)
|
||||||
@@ -352,9 +377,13 @@ class FunctionFactory:
|
|||||||
|
|
||||||
return debug_forward
|
return debug_forward
|
||||||
|
|
||||||
# return partial(
|
def ask_batch_forward(self, nodes, connections):
|
||||||
# forward_func,
|
n = nodes.shape[0]
|
||||||
# cal_seqs=pop_cal_seqs,
|
ts = self.create_topological_sort(n)
|
||||||
# nodes=pop_nodes,
|
cal_seqs = ts(nodes, connections)
|
||||||
# connections=pop_connections
|
forward_func = self.create_batch_forward(n)
|
||||||
# )
|
|
||||||
|
def debug_forward(inputs):
|
||||||
|
return forward_func(inputs, cal_seqs, nodes, connections)
|
||||||
|
|
||||||
|
return debug_forward
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ def clamped_act(z):
|
|||||||
|
|
||||||
@jit
|
@jit
|
||||||
def inv_act(z):
|
def inv_act(z):
|
||||||
|
z = jnp.maximum(z, 1e-7)
|
||||||
return 1 / z
|
return 1 / z
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import numpy as np
|
|||||||
from .species import SpeciesController
|
from .species import SpeciesController
|
||||||
from .genome import expand, expand_single
|
from .genome import expand, expand_single
|
||||||
from .function_factory import FunctionFactory
|
from .function_factory import FunctionFactory
|
||||||
from examples.time_utils import using_cprofile
|
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
@@ -16,7 +15,9 @@ class Pipeline:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config, seed=42):
|
def __init__(self, config, seed=42):
|
||||||
|
self.time_dict = {}
|
||||||
self.function_factory = FunctionFactory(config, debug=True)
|
self.function_factory = FunctionFactory(config, debug=True)
|
||||||
|
|
||||||
self.randkey = jax.random.PRNGKey(seed)
|
self.randkey = jax.random.PRNGKey(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
|
|
||||||
@@ -35,6 +36,7 @@ class Pipeline:
|
|||||||
self.species_controller.init_speciate(self.pop_nodes, self.pop_connections)
|
self.species_controller.init_speciate(self.pop_nodes, self.pop_connections)
|
||||||
|
|
||||||
self.best_fitness = float('-inf')
|
self.best_fitness = float('-inf')
|
||||||
|
self.best_genome = None
|
||||||
self.generation_timestamp = time.time()
|
self.generation_timestamp = time.time()
|
||||||
|
|
||||||
def ask(self):
|
def ask(self):
|
||||||
@@ -43,7 +45,7 @@ class Pipeline:
|
|||||||
:return:
|
:return:
|
||||||
Algorithm gives the population a forward function, then environment gives back the fitnesses.
|
Algorithm gives the population a forward function, then environment gives back the fitnesses.
|
||||||
"""
|
"""
|
||||||
return self.function_factory.ask(self.pop_nodes, self.pop_connections)
|
return self.function_factory.ask_pop_batch_forward(self.pop_nodes, self.pop_connections)
|
||||||
|
|
||||||
def tell(self, fitnesses):
|
def tell(self, fitnesses):
|
||||||
|
|
||||||
@@ -72,10 +74,14 @@ class Pipeline:
|
|||||||
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
|
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
|
||||||
analysis(fitnesses)
|
analysis(fitnesses)
|
||||||
|
|
||||||
|
if max(fitnesses) >= self.config.neat.population.fitness_threshold:
|
||||||
|
print("Fitness limit reached!")
|
||||||
|
return self.best_genome
|
||||||
|
|
||||||
self.tell(fitnesses)
|
self.tell(fitnesses)
|
||||||
print("Generation limit reached!")
|
print("Generation limit reached!")
|
||||||
|
return self.best_genome
|
||||||
|
|
||||||
# @using_cprofile
|
|
||||||
def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None:
|
def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None:
|
||||||
"""
|
"""
|
||||||
create the next generation
|
create the next generation
|
||||||
@@ -152,5 +158,10 @@ class Pipeline:
|
|||||||
cost_time = new_timestamp - self.generation_timestamp
|
cost_time = new_timestamp - self.generation_timestamp
|
||||||
self.generation_timestamp = new_timestamp
|
self.generation_timestamp = new_timestamp
|
||||||
|
|
||||||
|
max_idx = np.argmax(fitnesses)
|
||||||
|
if fitnesses[max_idx] > self.best_fitness:
|
||||||
|
self.best_fitness = fitnesses[max_idx]
|
||||||
|
self.best_genome = (self.pop_nodes[max_idx], self.pop_connections[max_idx])
|
||||||
|
|
||||||
print(f"Generation: {self.generation}",
|
print(f"Generation: {self.generation}",
|
||||||
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")
|
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from functools import partial
|
|||||||
from utils import Configer
|
from utils import Configer
|
||||||
from algorithms.neat import Pipeline
|
from algorithms.neat import Pipeline
|
||||||
from time_utils import using_cprofile
|
from time_utils import using_cprofile
|
||||||
from problems import Sin, Xor
|
from problems import Sin, Xor, DIY
|
||||||
|
|
||||||
|
|
||||||
# xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
# xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||||
@@ -25,11 +25,16 @@ from problems import Sin, Xor
|
|||||||
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
||||||
def main():
|
def main():
|
||||||
config = Configer.load_config()
|
config = Configer.load_config()
|
||||||
# problem = Xor()
|
config.neat.population.pop_size = 50
|
||||||
problem = Sin()
|
problem = Xor()
|
||||||
|
# problem = Sin()
|
||||||
|
# problem = DIY(func=lambda x: (np.sin(x) + np.exp(x) - x ** 2) / (np.cos(x) + np.sqrt(x)) - np.log(x + 1))
|
||||||
problem.refactor_config(config)
|
problem.refactor_config(config)
|
||||||
pipeline = Pipeline(config, seed=11454)
|
pipeline = Pipeline(config, seed=0)
|
||||||
pipeline.auto_run(problem.evaluate)
|
best_nodes, best_connections = pipeline.auto_run(problem.evaluate)
|
||||||
|
# print(best_nodes, best_connections)
|
||||||
|
# func = pipeline.function_factory.ask_batch_forward(best_nodes, best_connections)
|
||||||
|
# problem.print(func)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
from .function_fitting_problem import FunctionFittingProblem
|
from .function_fitting_problem import FunctionFittingProblem
|
||||||
from .xor import *
|
from .xor import *
|
||||||
from .sin import *
|
from .sin import *
|
||||||
|
from .diy import *
|
||||||
14
problems/function_fitting/diy.py
Normal file
14
problems/function_fitting/diy.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from . import FunctionFittingProblem
|
||||||
|
|
||||||
|
|
||||||
|
class DIY(FunctionFittingProblem):
|
||||||
|
def __init__(self, func, size=100):
|
||||||
|
self.num_inputs = 1
|
||||||
|
self.num_outputs = 1
|
||||||
|
self.batch = size
|
||||||
|
self.inputs = np.linspace(0, 1, self.batch)[:, None]
|
||||||
|
self.target = func(self.inputs)
|
||||||
|
print(self.inputs, self.target)
|
||||||
|
super().__init__(self.num_inputs, self.num_outputs, self.batch, self.inputs, self.target)
|
||||||
@@ -15,8 +15,25 @@ class FunctionFittingProblem(Problem):
|
|||||||
self.loss = loss
|
self.loss = loss
|
||||||
super().__init__(self.forward_way, self.num_inputs, self.num_outputs, self.batch)
|
super().__init__(self.forward_way, self.num_inputs, self.num_outputs, self.batch)
|
||||||
|
|
||||||
def evaluate(self, batch_forward_func):
|
def evaluate(self, pop_batch_forward):
|
||||||
out = batch_forward_func(self.inputs)
|
outs = pop_batch_forward(self.inputs)
|
||||||
out = jax.device_get(out)
|
outs = jax.device_get(outs)
|
||||||
fitnesses = 1 - np.mean((self.target - out) ** 2, axis=(1, 2))
|
fitnesses = -np.mean((self.target - outs) ** 2, axis=(1, 2))
|
||||||
return fitnesses.tolist()
|
return fitnesses.tolist()
|
||||||
|
|
||||||
|
def draw(self, batch_func):
|
||||||
|
outs = batch_func(self.inputs)
|
||||||
|
outs = jax.device_get(outs)
|
||||||
|
print(outs)
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
plt.xlabel('x')
|
||||||
|
plt.ylabel('y')
|
||||||
|
plt.plot(self.inputs, self.target, color='red', label='target')
|
||||||
|
plt.plot(self.inputs, outs, color='blue', label='predict')
|
||||||
|
plt.legend()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
def print(self, batch_func):
|
||||||
|
outs = batch_func(self.inputs)
|
||||||
|
outs = jax.device_get(outs)
|
||||||
|
print(outs)
|
||||||
@@ -8,7 +8,7 @@ class Sin(FunctionFittingProblem):
|
|||||||
self.num_inputs = 1
|
self.num_inputs = 1
|
||||||
self.num_outputs = 1
|
self.num_outputs = 1
|
||||||
self.batch = size
|
self.batch = size
|
||||||
self.inputs = np.linspace(0, np.pi, self.batch)[:, None]
|
self.inputs = np.linspace(0, 2 * np.pi, self.batch)[:, None]
|
||||||
self.target = np.sin(self.inputs)
|
self.target = np.sin(self.inputs)
|
||||||
print(self.inputs, self.target)
|
print(self.inputs, self.target)
|
||||||
super().__init__(self.num_inputs, self.num_outputs, self.batch, self.inputs, self.target)
|
super().__init__(self.num_inputs, self.num_outputs, self.batch, self.inputs, self.target)
|
||||||
|
|||||||
@@ -11,9 +11,9 @@
|
|||||||
"neat": {
|
"neat": {
|
||||||
"population": {
|
"population": {
|
||||||
"fitness_criterion": "max",
|
"fitness_criterion": "max",
|
||||||
"fitness_threshold": 76,
|
"fitness_threshold": -0.001,
|
||||||
"generation_limit": 100,
|
"generation_limit": 1000,
|
||||||
"pop_size": 1000,
|
"pop_size": 30,
|
||||||
"reset_on_extinction": "False"
|
"reset_on_extinction": "False"
|
||||||
},
|
},
|
||||||
"gene": {
|
"gene": {
|
||||||
@@ -33,12 +33,12 @@
|
|||||||
},
|
},
|
||||||
"activation": {
|
"activation": {
|
||||||
"default": "sigmoid",
|
"default": "sigmoid",
|
||||||
"options": ["sigmoid"],
|
"options": "sigmoid",
|
||||||
"mutate_rate": 0.1
|
"mutate_rate": 0.1
|
||||||
},
|
},
|
||||||
"aggregation": {
|
"aggregation": {
|
||||||
"default": "sum",
|
"default": "sum",
|
||||||
"options": ["sum"],
|
"options": "sum",
|
||||||
"mutate_rate": 0.1
|
"mutate_rate": 0.1
|
||||||
},
|
},
|
||||||
"weight": {
|
"weight": {
|
||||||
@@ -57,12 +57,12 @@
|
|||||||
"compatibility_weight_coefficient": 0.5,
|
"compatibility_weight_coefficient": 0.5,
|
||||||
"single_structural_mutation": "False",
|
"single_structural_mutation": "False",
|
||||||
"conn_add_prob": 0.5,
|
"conn_add_prob": 0.5,
|
||||||
"conn_delete_prob": 0,
|
"conn_delete_prob": 0.5,
|
||||||
"node_add_prob": 0.1,
|
"node_add_prob": 0.2,
|
||||||
"node_delete_prob": 0
|
"node_delete_prob": 0.2
|
||||||
},
|
},
|
||||||
"species": {
|
"species": {
|
||||||
"compatibility_threshold": 3,
|
"compatibility_threshold": 2.5,
|
||||||
"species_fitness_func": "max",
|
"species_fitness_func": "max",
|
||||||
"max_stagnation": 20,
|
"max_stagnation": 20,
|
||||||
"species_elitism": 2,
|
"species_elitism": 2,
|
||||||
|
|||||||
Reference in New Issue
Block a user