modify pipeline for "update_by_data";
fix bug in speciate. currently, node_delete and conn_delete can successfully work
This commit is contained in:
@@ -49,7 +49,10 @@ class FuncFit(BaseProblem):
|
||||
state, self.inputs, params
|
||||
)
|
||||
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
||||
loss = self.evaluate(state, randkey, act_func, params)
|
||||
if self.return_data:
|
||||
loss, _ = self.evaluate(state, randkey, act_func, params)
|
||||
else:
|
||||
loss = self.evaluate(state, randkey, act_func, params)
|
||||
loss = -loss
|
||||
|
||||
msg = ""
|
||||
|
||||
@@ -4,14 +4,19 @@ from .func_fit import FuncFit
|
||||
|
||||
|
||||
class XOR(FuncFit):
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
||||
return np.array(
|
||||
[[0, 0], [0, 1], [1, 0], [1, 1]],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
@property
|
||||
def targets(self):
|
||||
return np.array([[0], [1], [1], [0]])
|
||||
return np.array(
|
||||
[[0], [1], [1], [0]],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
|
||||
@@ -16,12 +16,16 @@ class XOR3d(FuncFit):
|
||||
[1, 0, 1],
|
||||
[1, 1, 0],
|
||||
[1, 1, 1],
|
||||
]
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
@property
|
||||
def targets(self):
|
||||
return np.array([[0], [1], [1], [0], [1], [0], [0], [1]])
|
||||
return np.array(
|
||||
[[0], [1], [1], [0], [1], [0], [0], [1]],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
|
||||
Reference in New Issue
Block a user