虽然xor问题还是跑不出来,但至少已经确定不是distance的错了
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
from typing import Callable, List
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from utils import Configer
|
||||
@@ -18,8 +17,7 @@ def evaluate(forward_func: Callable) -> List[float]:
|
||||
:return:
|
||||
"""
|
||||
outs = forward_func(xor_inputs)
|
||||
outs = jax.device_get(outs)
|
||||
fitnesses = -np.mean((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||
fitnesses = np.mean((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||
# print(fitnesses)
|
||||
return fitnesses.tolist() # returns a list
|
||||
|
||||
|
||||
Reference in New Issue
Block a user