# Source: tensorneat-mend/problems/function_fitting/enhance_logic.py
# Last modified: 2023-05-14 15:27:17 +08:00 (55 lines, 1.4 KiB, Python)
"""
xor problem in multiple dimensions
"""
from itertools import product
import numpy as np
class EnhanceLogic:
    """Fit a boolean function over the full truth table of ``n`` binary inputs.

    The input set is every 0/1 row of length ``n`` (``2 ** n`` rows). The
    target is the chosen logic function applied to each row.

    Args:
        name: Logic function to fit: ``"xor"`` (parity), ``"and"`` or ``"or"``.
        n: Number of binary inputs.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported functions.
    """

    def __init__(self, name="xor", n=2):
        self.name = name
        self.n = n
        self.num_inputs = n
        self.num_outputs = 1
        self.batch = 2 ** n
        self.forward_way = 'pop_batch'
        # All binary rows of length n as float32, shape (2 ** n, n).
        self.inputs = np.array(generate_permutations(n), dtype=np.float32)
        # Float32 column vector of targets, shape (2 ** n, 1).
        self.outputs = self._compute_outputs(self.name, self.inputs)

    @staticmethod
    def _compute_outputs(name, inputs):
        """Apply logic function ``name`` row-wise to a 2-D array of 0/1 values.

        Returns:
            A float32 array of shape ``(len(inputs), 1)``.

        Raises:
            NotImplementedError: If ``name`` is not ``"xor"``, ``"and"`` or ``"or"``.
        """
        if name == "xor":
            # Parity: 1 when the row contains an odd number of ones.
            outputs = np.sum(inputs, axis=1) % 2
        elif name == "and":
            outputs = np.all(inputs == 1, axis=1)
        elif name == "or":
            outputs = np.any(inputs == 1, axis=1)
        else:
            raise NotImplementedError("Only support xor, and, or")
        # "and"/"or" produce bool arrays; cast so every problem yields the
        # same float32 targets as "xor" does.
        return outputs.astype(np.float32)[:, np.newaxis]

    def refactor_config(self, config):
        """Write this problem's shape/forward settings into ``config.basic``."""
        config.basic.forward_way = self.forward_way
        config.basic.num_inputs = self.num_inputs
        config.basic.num_outputs = self.num_outputs
        config.basic.problem_batch = self.batch

    def ask_for_inputs(self):
        """Return the ``(2 ** n, n)`` float32 truth-table inputs."""
        return self.inputs

    def evaluate_predict(self, predict):
        """Return the negative mean squared error of ``predict`` (higher is better).

        NOTE(review): ``predict`` is broadcast against targets of shape
        ``(2 ** n, 1)``. A 1-D ``predict`` of length ``2 ** n`` would broadcast
        to ``(2 ** n, 2 ** n)`` and give a wrong score. Confirm the caller's
        output shape.
        """
        return -np.mean((predict - self.outputs) ** 2)
def generate_permutations(n):
    """Return every length-``n`` sequence over ``{0, 1}`` as a list of lists.

    Despite the name, these are not permutations. The result is the Cartesian
    product ``{0, 1} ** n`` (all ``2 ** n`` binary rows) in lexicographic order.
    For example, ``n == 2`` gives ``[[0, 0], [0, 1], [1, 0], [1, 1]]``.
    """
    return list(map(list, product((0, 1), repeat=n)))
if __name__ == '__main__':
    # Pass n by keyword: the first positional parameter is ``name``, so the
    # former ``EnhanceLogic(4)`` always raised NotImplementedError.
    _ = EnhanceLogic(n=4)