add method 'create_crossover_function' and 'create_distance_function'

This commit is contained in:
wls2002
2023-05-07 22:16:27 +08:00
parent cec40b254f
commit 47bb593a53
7 changed files with 45 additions and 15 deletions

View File

@@ -1,6 +1,7 @@
from typing import Callable, List
from functools import partial
import jax
import numpy as np
from utils import Configer
@@ -17,12 +18,13 @@ def evaluate(forward_func: Callable) -> List[float]:
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses.tolist() # returns a list
# @using_cprofile
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
@using_cprofile
# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
def main():
config = Configer.load_config()
pipeline = Pipeline(config, seed=11323)