add method 'create_crossover_function' and 'create_distance_function'
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from typing import Callable, List
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from utils import Configer
|
||||
@@ -17,12 +18,13 @@ def evaluate(forward_func: Callable) -> List[float]:
|
||||
:return:
|
||||
"""
|
||||
outs = forward_func(xor_inputs)
|
||||
outs = jax.device_get(outs)
|
||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||
return fitnesses.tolist() # returns a list
|
||||
|
||||
|
||||
# @using_cprofile
|
||||
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
||||
@using_cprofile
|
||||
# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
||||
def main():
|
||||
config = Configer.load_config()
|
||||
pipeline = Pipeline(config, seed=11323)
|
||||
|
||||
Reference in New Issue
Block a user