add method 'create_crossover_function' and 'create_distance_function'

This commit is contained in:
wls2002
2023-05-07 22:16:27 +08:00
parent cec40b254f
commit 47bb593a53
7 changed files with 45 additions and 15 deletions

View File

@@ -6,7 +6,7 @@ import numpy as np
from .species import SpeciesController
from .genome import create_initialize_function, create_mutate_function, create_forward_function
from .genome import batch_crossover
from .genome import create_crossover_function
from .genome import expand, expand_single
@@ -27,6 +27,7 @@ class Pipeline:
self.initialize_func = create_initialize_function(config)
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True)
self.crossover_func = create_crossover_function(batch=True)
self.generation = 0
@@ -102,7 +103,7 @@ class Pipeline:
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
# mutate
mutate_rand_keys = jax.random.split(k2, self.pop_size)