change repo structure; modify readme

This commit is contained in:
wls2002
2024-03-26 21:58:27 +08:00
parent 6970e6a6d5
commit 47dbcbea80
69 changed files with 74 additions and 60 deletions

View File

@@ -0,0 +1,3 @@
from .func_fit import FuncFit
from .xor import XOR
from .xor3d import XOR3d

View File

@@ -0,0 +1,67 @@
import jax
import jax.numpy as jnp
from utils import State
from .. import BaseProblem
class FuncFit(BaseProblem):
jitable = True
def __init__(self,
error_method: str = 'mse'
):
super().__init__()
assert error_method in {'mse', 'rmse', 'mae', 'mape'}
self.error_method = error_method
def setup(self, randkey, state: State = State()):
return state
def evaluate(self, randkey, state, act_func, params):
predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params)
if self.error_method == 'mse':
loss = jnp.mean((predict - self.targets) ** 2)
elif self.error_method == 'rmse':
loss = jnp.sqrt(jnp.mean((predict - self.targets) ** 2))
elif self.error_method == 'mae':
loss = jnp.mean(jnp.abs(predict - self.targets))
elif self.error_method == 'mape':
loss = jnp.mean(jnp.abs((predict - self.targets) / self.targets))
else:
raise NotImplementedError
return -loss
def show(self, randkey, state, act_func, params, *args, **kwargs):
predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
loss = -self.evaluate(randkey, state, act_func, params)
msg = ""
for i in range(inputs.shape[0]):
msg += f"input: {inputs[i]}, target: {target[i]}, predict: {predict[i]}\n"
msg += f"loss: {loss}\n"
print(msg)
@property
def inputs(self):
raise NotImplementedError
@property
def targets(self):
raise NotImplementedError
@property
def input_shape(self):
raise NotImplementedError
@property
def output_shape(self):
raise NotImplementedError

View File

@@ -0,0 +1,35 @@
import numpy as np
from .func_fit import FuncFit
class XOR(FuncFit):
def __init__(self, error_method: str = 'mse'):
super().__init__(error_method)
@property
def inputs(self):
return np.array([
[0, 0],
[0, 1],
[1, 0],
[1, 1]
])
@property
def targets(self):
return np.array([
[0],
[1],
[1],
[0]
])
@property
def input_shape(self):
return 4, 2
@property
def output_shape(self):
return 4, 1

View File

@@ -0,0 +1,43 @@
import numpy as np
from .func_fit import FuncFit
class XOR3d(FuncFit):
def __init__(self, error_method: str = 'mse'):
super().__init__(error_method)
@property
def inputs(self):
return np.array([
[0, 0, 0],
[0, 0, 1],
[0, 1, 0],
[0, 1, 1],
[1, 0, 0],
[1, 0, 1],
[1, 1, 0],
[1, 1, 1],
])
@property
def targets(self):
return np.array([
[0],
[1],
[1],
[0],
[1],
[0],
[0],
[1]
])
@property
def input_shape(self):
return 8, 3
@property
def output_shape(self):
return 8, 1