diff --git a/examples/xor.py b/examples/xor.py index bfec783..2bd48ab 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -1,34 +1,35 @@ -from typing import Callable, List from functools import partial -import jax -import numpy as np - from utils import Configer from algorithms.neat import Pipeline from time_utils import using_cprofile - -xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) -xor_outputs = np.array([[0], [1], [1], [0]]) +from problems import Sin, Xor -def evaluate(forward_func: Callable) -> List[float]: - """ - :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) - :return: - """ - outs = forward_func(xor_inputs) - outs = jax.device_get(outs) - fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) - return fitnesses.tolist() # returns a list +# xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) +# xor_outputs = np.array([[0], [1], [1], [0]]) +# +# +# def evaluate(forward_func: Callable) -> List[float]: +# """ +# :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) +# :return: +# """ +# outs = forward_func(xor_inputs) +# outs = jax.device_get(outs) +# fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) +# return fitnesses.tolist() # returns a list # @using_cprofile @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") def main(): config = Configer.load_config() + # problem = Xor() + problem = Sin() + problem.refactor_config(config) pipeline = Pipeline(config, seed=11454) - pipeline.auto_run(evaluate) + pipeline.auto_run(problem.evaluate) if __name__ == '__main__': diff --git a/problems/__init__.py b/problems/__init__.py new file mode 100644 index 0000000..b9a4800 --- /dev/null +++ b/problems/__init__.py @@ -0,0 +1,3 @@ +from .problem import Problem +from .function_fitting import * +from .gym import * diff --git a/problems/function_fitting/__init__.py b/problems/function_fitting/__init__.py new file mode 100644 index 0000000..4d597e0 --- /dev/null +++ b/problems/function_fitting/__init__.py @@ -0,0 +1,3 @@ +from .function_fitting_problem import FunctionFittingProblem +from .xor import * +from .sin import * diff --git a/problems/function_fitting/function_fitting_problem.py b/problems/function_fitting/function_fitting_problem.py new file mode 100644 index 0000000..c06132d --- /dev/null +++ b/problems/function_fitting/function_fitting_problem.py @@ -0,0 +1,22 @@ +import numpy as np +import jax + +from problems import Problem + + +class FunctionFittingProblem(Problem): + def __init__(self, num_inputs, num_outputs, batch, inputs, target, loss='MSE'): + self.forward_way = 'pop_batch' + self.num_inputs = num_inputs + self.num_outputs = num_outputs + self.batch = batch + self.inputs = inputs + self.target = target + self.loss = loss + super().__init__(self.forward_way, self.num_inputs, self.num_outputs, self.batch) + + def evaluate(self, batch_forward_func): + out = batch_forward_func(self.inputs) + out = jax.device_get(out) + fitnesses = 1 - np.mean((self.target - out) ** 2, axis=(1, 2)) + return fitnesses.tolist() diff --git a/problems/function_fitting/sin.py b/problems/function_fitting/sin.py new file mode 100644 index 0000000..2b1ee03 --- /dev/null +++ b/problems/function_fitting/sin.py @@ -0,0 +1,14 @@ +import numpy as np + +from . import FunctionFittingProblem + + +class Sin(FunctionFittingProblem): + def __init__(self, size=100): + self.num_inputs = 1 + self.num_outputs = 1 + self.batch = size + self.inputs = np.linspace(0, np.pi, self.batch)[:, None] + self.target = np.sin(self.inputs) + print(self.inputs, self.target) + super().__init__(self.num_inputs, self.num_outputs, self.batch, self.inputs, self.target) diff --git a/problems/function_fitting/xor.py b/problems/function_fitting/xor.py new file mode 100644 index 0000000..250e04c --- /dev/null +++ b/problems/function_fitting/xor.py @@ -0,0 +1,13 @@ +import numpy as np + +from . import FunctionFittingProblem + + +class Xor(FunctionFittingProblem): + def __init__(self): + self.num_inputs = 2 + self.num_outputs = 1 + self.batch = 4 + self.inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) + self.target = np.array([[0], [1], [1], [0]], dtype=np.float32) + super().__init__(self.num_inputs, self.num_outputs, self.batch, self.inputs, self.target) diff --git a/problems/gym/__init__.py b/problems/gym/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/problems/gym/gym_problem.py b/problems/gym/gym_problem.py new file mode 100644 index 0000000..e69de29 diff --git a/problems/problem.py b/problems/problem.py new file mode 100644 index 0000000..f5a7ee5 --- /dev/null +++ b/problems/problem.py @@ -0,0 +1,15 @@ +class Problem: + def __init__(self, forward_way, num_inputs, num_outputs, batch): + self.forward_way = forward_way + self.batch = batch + self.num_inputs = num_inputs + self.num_outputs = num_outputs + + def refactor_config(self, config): + config.basic.forward_way = self.forward_way + config.basic.num_inputs = self.num_inputs + config.basic.num_outputs = self.num_outputs + config.basic.problem_batch = self.batch + + def evaluate(self, batch_forward_func): + pass diff --git a/utils/default_config.json b/utils/default_config.json index 1b2be0e..273e304 100644 --- a/utils/default_config.json +++ b/utils/default_config.json @@ -5,14 +5,15 @@ "problem_batch": 4, "init_maximum_nodes": 10, "expands_coe": 2, - "pre_compile_times": 3 + "pre_compile_times": 3, + "forward_way": "pop_batch" }, "neat": { "population": { "fitness_criterion": "max", "fitness_threshold": 76, "generation_limit": 100, - "pop_size": 2000, + "pop_size": 1000, "reset_on_extinction": "False" }, "gene": { @@ -56,9 +57,9 @@ "compatibility_weight_coefficient": 0.5, "single_structural_mutation": "False", "conn_add_prob": 0.5, - "conn_delete_prob": 0.5, - "node_add_prob": 0.2, - "node_delete_prob": 0.2 + "conn_delete_prob": 0, + "node_add_prob": 0.1, + "node_delete_prob": 0 }, "species": { "compatibility_threshold": 3,