modify pipeline for "update_by_data";

fix bug in speciate. currently, node_delete and conn_delete can successfully work
This commit is contained in:
wls2002
2024-05-31 15:32:56 +08:00
parent 3ea9986bd4
commit 6aa9011043
12 changed files with 132 additions and 45 deletions

View File

@@ -49,7 +49,10 @@ class FuncFit(BaseProblem):
state, self.inputs, params
)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
loss = self.evaluate(state, randkey, act_func, params)
if self.return_data:
loss, _ = self.evaluate(state, randkey, act_func, params)
else:
loss = self.evaluate(state, randkey, act_func, params)
loss = -loss
msg = ""