prepare for experiment

This commit is contained in:
wls2002
2023-05-14 15:27:17 +08:00
parent 72c9d4167a
commit 2b79f2c903
11 changed files with 252 additions and 62 deletions

View File

@@ -1,4 +1,5 @@
from .function_fitting_problem import FunctionFittingProblem
from .xor import *
from .sin import *
from .diy import *
from .diy import *
from .enhance_logic import *

View File

@@ -0,0 +1,54 @@
"""
xor problem in multiple dimensions
"""
from itertools import product
import numpy as np
class EnhanceLogic:
def __init__(self, name="xor", n=2):
self.name = name
self.n = n
self.num_inputs = n
self.num_outputs = 1
self.batch = 2 ** n
self.forward_way = 'pop_batch'
self.inputs = np.array(generate_permutations(n), dtype=np.float32)
if self.name == "xor":
self.outputs = np.sum(self.inputs, axis=1) % 2
elif self.name == "and":
self.outputs = np.all(self.inputs==1, axis=1)
elif self.name == "or":
self.outputs = np.any(self.inputs==1, axis=1)
else:
raise NotImplementedError("Only support xor, and, or")
self.outputs = self.outputs[:, np.newaxis]
def refactor_config(self, config):
config.basic.forward_way = self.forward_way
config.basic.num_inputs = self.num_inputs
config.basic.num_outputs = self.num_outputs
config.basic.problem_batch = self.batch
def ask_for_inputs(self):
return self.inputs
def evaluate_predict(self, predict):
# print((predict - self.outputs) ** 2)
return -np.mean((predict - self.outputs) ** 2)
def generate_permutations(n):
permutations = [list(i) for i in product([0, 1], repeat=n)]
return permutations
if __name__ == '__main__':
_ = EnhanceLogic(4)