add method 'create_crossover_function' and 'create_distance_function'
This commit is contained in:
@@ -6,7 +6,7 @@ import numpy as np
|
||||
|
||||
from .species import SpeciesController
|
||||
from .genome import create_initialize_function, create_mutate_function, create_forward_function
|
||||
from .genome import batch_crossover
|
||||
from .genome import create_crossover_function
|
||||
from .genome import expand, expand_single
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ class Pipeline:
|
||||
self.initialize_func = create_initialize_function(config)
|
||||
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
|
||||
self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True)
|
||||
self.crossover_func = create_crossover_function(batch=True)
|
||||
|
||||
self.generation = 0
|
||||
|
||||
@@ -102,7 +103,7 @@ class Pipeline:
|
||||
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
|
||||
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
|
||||
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
|
||||
npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
||||
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
||||
|
||||
# mutate
|
||||
mutate_rand_keys = jax.random.split(k2, self.pop_size)
|
||||
|
||||
Reference in New Issue
Block a user