add package problems

This commit is contained in:
wls2002
2023-05-10 19:30:12 +08:00
parent 097bbf6631
commit ce35b01896
10 changed files with 94 additions and 22 deletions

View File

@@ -1,34 +1,35 @@
from typing import Callable, List
from functools import partial from functools import partial
import jax
import numpy as np
from utils import Configer from utils import Configer
from algorithms.neat import Pipeline from algorithms.neat import Pipeline
from time_utils import using_cprofile from time_utils import using_cprofile
from problems import Sin, Xor
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]])
def evaluate(forward_func: Callable) -> List[float]: # xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
""" # xor_outputs = np.array([[0], [1], [1], [0]])
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) #
:return: #
""" # def evaluate(forward_func: Callable) -> List[float]:
outs = forward_func(xor_inputs) # """
outs = jax.device_get(outs) # :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) # :return:
return fitnesses.tolist() # returns a list # """
# outs = forward_func(xor_inputs)
# outs = jax.device_get(outs)
# fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
# return fitnesses.tolist() # returns a list
# @using_cprofile # @using_cprofile
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
def main(): def main():
config = Configer.load_config() config = Configer.load_config()
# problem = Xor()
problem = Sin()
problem.refactor_config(config)
pipeline = Pipeline(config, seed=11454) pipeline = Pipeline(config, seed=11454)
pipeline.auto_run(evaluate) pipeline.auto_run(problem.evaluate)
if __name__ == '__main__': if __name__ == '__main__':

3
problems/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .problem import Problem
from .function_fitting import *
from .gym import *

View File

@@ -0,0 +1,3 @@
from .function_fitting_problem import FunctionFittingProblem
from .xor import *
from .sin import *

View File

@@ -0,0 +1,22 @@
import numpy as np
import jax
from problems import Problem
class FunctionFittingProblem(Problem):
def __init__(self, num_inputs, num_outputs, batch, inputs, target, loss='MSE'):
self.forward_way = 'pop_batch'
self.num_inputs = num_inputs
self.num_outputs = num_outputs
self.batch = batch
self.inputs = inputs
self.target = target
self.loss = loss
super().__init__(self.forward_way, self.num_inputs, self.num_outputs, self.batch)
def evaluate(self, batch_forward_func):
out = batch_forward_func(self.inputs)
out = jax.device_get(out)
fitnesses = 1 - np.mean((self.target - out) ** 2, axis=(1, 2))
return fitnesses.tolist()

View File

@@ -0,0 +1,14 @@
import numpy as np
from . import FunctionFittingProblem
class Sin(FunctionFittingProblem):
def __init__(self, size=100):
self.num_inputs = 1
self.num_outputs = 1
self.batch = size
self.inputs = np.linspace(0, np.pi, self.batch)[:, None]
self.target = np.sin(self.inputs)
print(self.inputs, self.target)
super().__init__(self.num_inputs, self.num_outputs, self.batch, self.inputs, self.target)

View File

@@ -0,0 +1,13 @@
import numpy as np
from . import FunctionFittingProblem
class Xor(FunctionFittingProblem):
def __init__(self):
self.num_inputs = 2
self.num_outputs = 1
self.batch = 4
self.inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
self.target = np.array([[0], [1], [1], [0]], dtype=np.float32)
super().__init__(self.num_inputs, self.num_outputs, self.batch, self.inputs, self.target)

0
problems/gym/__init__.py Normal file
View File

View File

15
problems/problem.py Normal file
View File

@@ -0,0 +1,15 @@
class Problem:
def __init__(self, forward_way, num_inputs, num_outputs, batch):
self.forward_way = forward_way
self.batch = batch
self.num_inputs = num_inputs
self.num_outputs = num_outputs
def refactor_config(self, config):
config.basic.forward_way = self.forward_way
config.basic.num_inputs = self.num_inputs
config.basic.num_outputs = self.num_outputs
config.basic.problem_batch = self.batch
def evaluate(self, batch_forward_func):
pass

View File

@@ -5,14 +5,15 @@
"problem_batch": 4, "problem_batch": 4,
"init_maximum_nodes": 10, "init_maximum_nodes": 10,
"expands_coe": 2, "expands_coe": 2,
"pre_compile_times": 3 "pre_compile_times": 3,
"forward_way": "pop_batch"
}, },
"neat": { "neat": {
"population": { "population": {
"fitness_criterion": "max", "fitness_criterion": "max",
"fitness_threshold": 76, "fitness_threshold": 76,
"generation_limit": 100, "generation_limit": 100,
"pop_size": 2000, "pop_size": 1000,
"reset_on_extinction": "False" "reset_on_extinction": "False"
}, },
"gene": { "gene": {
@@ -56,9 +57,9 @@
"compatibility_weight_coefficient": 0.5, "compatibility_weight_coefficient": 0.5,
"single_structural_mutation": "False", "single_structural_mutation": "False",
"conn_add_prob": 0.5, "conn_add_prob": 0.5,
"conn_delete_prob": 0.5, "conn_delete_prob": 0,
"node_add_prob": 0.2, "node_add_prob": 0.1,
"node_delete_prob": 0.2 "node_delete_prob": 0
}, },
"species": { "species": {
"compatibility_threshold": 3, "compatibility_threshold": 3,