Current Progress: After final design presentation

This commit is contained in:
wls2002
2023-06-19 15:17:56 +08:00
parent acedd67617
commit 5cbe3c14bb
34 changed files with 533 additions and 558 deletions

View File

@@ -1,45 +0,0 @@
import numpy as np
import jax
from utils import Configer
from neat import Pipeline
from neat import FunctionFactory
from problems import EnhanceLogic
import time
def evaluate(problem, func):
inputs = problem.ask_for_inputs()
pop_predict = jax.device_get(func(inputs))
# print(pop_predict)
fitnesses = []
for predict in pop_predict:
f = problem.evaluate_predict(predict)
fitnesses.append(f)
return np.array(fitnesses)
# @using_cprofile
# @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/")
def main():
tic = time.time()
config = Configer.load_config()
problem = EnhanceLogic("xor", n=3)
problem.refactor_config(config)
function_factory = FunctionFactory(config)
evaluate_func = lambda func: evaluate(problem, func)
pipeline = Pipeline(config, function_factory, seed=33413)
print("start run")
pipeline.auto_run(evaluate_func)
total_time = time.time() - tic
compile_time = pipeline.function_factory.compile_time
total_it = pipeline.generation
mean_time_per_it = (total_time - compile_time) / total_it
evaluate_time = pipeline.evaluate_time
print(
f"total time: {total_time:.2f}s, compile time: {compile_time:.2f}s, real_time: {total_time - compile_time:.2f}s, evaluate time: {evaluate_time:.2f}s")
print(f"total it: {total_it}, mean time per it: {mean_time_per_it:.2f}s")
if __name__ == '__main__':
main()

View File

@@ -1,37 +0,0 @@
import jax
import numpy as np
from neat import FunctionFactory
from neat.genome.debug.tools import check_array_valid
from utils import Configer
if __name__ == '__main__':
config = Configer.load_config()
function_factory = FunctionFactory(config, debug=True)
initialize_func = function_factory.create_initialize()
pop_nodes, pop_connections, input_idx, output_idx = initialize_func()
mutate_func = function_factory.create_mutate(pop_nodes.shape[1], pop_connections.shape[1])
crossover_func = function_factory.create_crossover(pop_nodes.shape[1], pop_connections.shape[1])
key = jax.random.PRNGKey(0)
new_node_idx = 100
while True:
key, subkey = jax.random.split(key)
mutate_keys = jax.random.split(subkey, len(pop_nodes))
new_nodes = np.arange(new_node_idx, new_node_idx + len(pop_nodes))
new_node_idx += len(pop_nodes)
pop_nodes, pop_connections = mutate_func(mutate_keys, pop_nodes, pop_connections, new_nodes)
pop_nodes, pop_connections = jax.device_get([pop_nodes, pop_connections])
idx1 = np.random.permutation(len(pop_nodes))
idx2 = np.random.permutation(len(pop_nodes))
n1, c1 = pop_nodes[idx1], pop_connections[idx1]
n2, c2 = pop_nodes[idx2], pop_connections[idx2]
crossover_keys = jax.random.split(subkey, len(pop_nodes))
pop_nodes, pop_connections = crossover_func(crossover_keys, n1, c1, n2, c2)
for i in range(len(pop_nodes)):
check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx)
print(new_node_idx)

View File

@@ -1,59 +1,21 @@
from functools import partial
import jax
import jax.numpy as jnp
from jax import jit, vmap
from time import time
import numpy as np
from jax import jit
@jit
def jax_mutate(seed, x):
noise = jax.random.normal(seed, x.shape) * 0.1
return x + noise
def numpy_mutate(x):
noise = np.random.normal(size=x.shape) * 0.1
return x + noise
def jax_mutate_population(seed, pop_x):
seeds = jax.random.split(seed, len(pop_x))
func = vmap(jax_mutate, in_axes=(0, 0))
return func(seeds, pop_x)
def numpy_mutate_population(pop_x):
return np.stack([numpy_mutate(x) for x in pop_x])
def numpy_mutate_population_vmap(pop_x):
noise = np.random.normal(size=pop_x.shape) * 0.1
return pop_x + noise
from configs import Configer
from neat.pipeline_ import Pipeline
def main():
seed = jax.random.PRNGKey(0)
i = 10
while i < 200000:
pop_x = jnp.ones((i, 100, 100))
jax_pop_func = jit(jax_mutate_population).lower(seed, pop_x).compile()
config = Configer.load_config("xor.ini")
print(config)
pipeline = Pipeline(config)
tic = time()
res = jax.device_get(jax_pop_func(seed, pop_x))
jax_time = time() - tic
tic = time()
res = numpy_mutate_population(pop_x)
numpy_time = time() - tic
tic = time()
res = numpy_mutate_population_vmap(pop_x)
numpy_time_vmap = time() - tic
# print(f'POP_SIZE: {i} | JAX: {jax_time:.4f} | Numpy: {numpy_time:.4f} | Speedup: {numpy_time / jax_time:.4f}')
print(f'POP_SIZE: {i} | JAX: {jax_time:.4f} | Numpy: {numpy_time:.4f} | Numpy Vmap: {numpy_time_vmap:.4f}')
i = int(i * 1.3)
@jit
def f(x, jit_config):
return x + jit_config["bias_mutate_rate"]
if __name__ == '__main__':

2
examples/xor.ini Normal file
View File

@@ -0,0 +1,2 @@
[population]
fitness_threshold = -1e-2

View File

@@ -1,29 +1,43 @@
from neat import FunctionFactory
from utils import Configer
from neat import Pipeline
from problems import Xor
from typing import Callable, List
import time
import numpy as np
from configs import Configer
from neat import Pipeline
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
xor_outputs = np.array([[0], [1], [1], [0]])
def evaluate(forward_func: Callable) -> List[float]:
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
# print(fitnesses)
return fitnesses.tolist() # returns a list
# @using_cprofile
# @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/")
def main():
tic = time.time()
config = Configer.load_config()
config = Configer.load_config("xor.ini")
print(config)
assert False
problem = Xor()
problem.refactor_config(config)
function_factory = FunctionFactory(config)
pipeline = Pipeline(config, function_factory, seed=6)
nodes, cons = pipeline.auto_run(problem.evaluate)
nodes, cons = pipeline.auto_run(evaluate)
print(nodes, cons)
total_time = time.time() - tic
compile_time = pipeline.function_factory.compile_time
total_it = pipeline.generation
mean_time_per_it = (total_time - compile_time) / total_it
evaluate_time = pipeline.evaluate_time
print(f"total time: {total_time:.2f}s, compile time: {compile_time:.2f}s, real_time: {total_time - compile_time:.2f}s, evaluate time: {evaluate_time:.2f}s")
print(
f"total time: {total_time:.2f}s, compile time: {compile_time:.2f}s, real_time: {total_time - compile_time:.2f}s, evaluate time: {evaluate_time:.2f}s")
print(f"total it: {total_it}, mean time per it: {mean_time_per_it:.2f}s")