虽然xor问题还是跑不出来,但至少已经确定不是distance的错了

This commit is contained in:
wls2002
2023-05-06 23:26:13 +08:00
parent a85e6eba78
commit 414b620dc8
5 changed files with 151 additions and 7 deletions

View File

@@ -1,7 +1,6 @@
from typing import Callable, List
from functools import partial
import jax
import numpy as np
from utils import Configer
@@ -18,8 +17,7 @@ def evaluate(forward_func: Callable) -> List[float]:
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = -np.mean((outs - xor_outputs) ** 2, axis=(1, 2))
fitnesses = np.mean((outs - xor_outputs) ** 2, axis=(1, 2))
# print(fitnesses)
return fitnesses.tolist() # returns a list