prepare for experiment
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from .function_fitting_problem import FunctionFittingProblem
|
||||
from .xor import *
|
||||
from .sin import *
|
||||
from .diy import *
|
||||
from .diy import *
|
||||
from .enhance_logic import *
|
||||
54
problems/function_fitting/enhance_logic.py
Normal file
54
problems/function_fitting/enhance_logic.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""
|
||||
xor problem in multiple dimensions
|
||||
"""
|
||||
|
||||
from itertools import product
|
||||
import numpy as np
|
||||
|
||||
|
||||
class EnhanceLogic:
|
||||
def __init__(self, name="xor", n=2):
|
||||
self.name = name
|
||||
self.n = n
|
||||
self.num_inputs = n
|
||||
self.num_outputs = 1
|
||||
self.batch = 2 ** n
|
||||
self.forward_way = 'pop_batch'
|
||||
|
||||
self.inputs = np.array(generate_permutations(n), dtype=np.float32)
|
||||
|
||||
if self.name == "xor":
|
||||
self.outputs = np.sum(self.inputs, axis=1) % 2
|
||||
elif self.name == "and":
|
||||
self.outputs = np.all(self.inputs==1, axis=1)
|
||||
elif self.name == "or":
|
||||
self.outputs = np.any(self.inputs==1, axis=1)
|
||||
else:
|
||||
raise NotImplementedError("Only support xor, and, or")
|
||||
self.outputs = self.outputs[:, np.newaxis]
|
||||
|
||||
|
||||
def refactor_config(self, config):
|
||||
config.basic.forward_way = self.forward_way
|
||||
config.basic.num_inputs = self.num_inputs
|
||||
config.basic.num_outputs = self.num_outputs
|
||||
config.basic.problem_batch = self.batch
|
||||
|
||||
|
||||
def ask_for_inputs(self):
|
||||
return self.inputs
|
||||
|
||||
def evaluate_predict(self, predict):
|
||||
# print((predict - self.outputs) ** 2)
|
||||
return -np.mean((predict - self.outputs) ** 2)
|
||||
|
||||
|
||||
|
||||
def generate_permutations(n):
|
||||
permutations = [list(i) for i in product([0, 1], repeat=n)]
|
||||
|
||||
return permutations
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
_ = EnhanceLogic(4)
|
||||
Reference in New Issue
Block a user