Fix a bug in stagnation: stagnated species now get a fitness of -inf instead of NaN
This commit is contained in:
@@ -23,6 +23,7 @@ def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, cente
|
|||||||
update_species(k1, fitness, species_info, idx2species, center_nodes,
|
update_species(k1, fitness, species_info, idx2species, center_nodes,
|
||||||
center_cons, generation, jit_config)
|
center_cons, generation, jit_config)
|
||||||
|
|
||||||
|
|
||||||
pop_nodes, pop_cons = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
|
pop_nodes, pop_cons = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
|
||||||
elite_mask, generation, jit_config)
|
elite_mask, generation, jit_config)
|
||||||
|
|
||||||
@@ -30,6 +31,7 @@ def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, cente
|
|||||||
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation,
|
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation,
|
||||||
jit_config)
|
jit_config)
|
||||||
|
|
||||||
|
|
||||||
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation
|
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation
|
||||||
|
|
||||||
|
|
||||||
@@ -111,7 +113,7 @@ def stagnation(species_fitness, species_info, center_nodes, center_cons, generat
|
|||||||
species_info = jnp.where(spe_st[:, None], jnp.nan, species_info)
|
species_info = jnp.where(spe_st[:, None], jnp.nan, species_info)
|
||||||
center_nodes = jnp.where(spe_st[:, None, None], jnp.nan, center_nodes)
|
center_nodes = jnp.where(spe_st[:, None, None], jnp.nan, center_nodes)
|
||||||
center_cons = jnp.where(spe_st[:, None, None], jnp.nan, center_cons)
|
center_cons = jnp.where(spe_st[:, None, None], jnp.nan, center_cons)
|
||||||
species_fitness = jnp.where(spe_st, jnp.nan, species_fitness)
|
species_fitness = jnp.where(spe_st, -jnp.inf, species_fitness)
|
||||||
|
|
||||||
return species_fitness, species_info, center_nodes, center_cons
|
return species_fitness, species_info, center_nodes, center_cons
|
||||||
|
|
||||||
@@ -269,6 +271,7 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
|||||||
# part 2: assign members to each species
|
# part 2: assign members to each species
|
||||||
def cond_func(carry):
|
def cond_func(carry):
|
||||||
i, i2s, cn, cc, si, o2c, ck = carry # si is short for species_info, ck is short for current key
|
i, i2s, cn, cc, si, o2c, ck = carry # si is short for species_info, ck is short for current key
|
||||||
|
jax.debug.print("{}, {}", i, i2s)
|
||||||
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
||||||
not_reach_species_upper_bounds = i < species_size
|
not_reach_species_upper_bounds = i < species_size
|
||||||
return not_all_assigned & not_reach_species_upper_bounds
|
return not_all_assigned & not_reach_species_upper_bounds
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ num_inputs = 2
|
|||||||
num_outputs = 1
|
num_outputs = 1
|
||||||
init_maximum_nodes = 50
|
init_maximum_nodes = 50
|
||||||
init_maximum_connections = 50
|
init_maximum_connections = 50
|
||||||
init_maximum_species = 100
|
init_maximum_species = 10
|
||||||
expand_coe = 1.5
|
expand_coe = 1.5
|
||||||
pre_expand_threshold = 0.75
|
pre_expand_threshold = 0.75
|
||||||
forward_way = "pop"
|
forward_way = "pop"
|
||||||
|
|||||||
117
examples/debug.py
Normal file
117
examples/debug.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
import pickle
|
||||||
|
|
||||||
|
import jax
|
||||||
|
from jax import numpy as jnp, jit, vmap
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from configs import Configer
|
||||||
|
from algorithms.neat import initialize_genomes
|
||||||
|
from algorithms.neat import tell
|
||||||
|
from algorithms.neat import unflatten_connections, topological_sort, create_forward_function
|
||||||
|
|
||||||
|
jax.config.update("jax_disable_jit", True)
|
||||||
|
|
||||||
|
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||||
|
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||||
|
|
||||||
|
def evaluate(forward_func):
    """
    Score the outputs of ``forward_func`` on the XOR truth table.

    :param forward_func: callable applied to ``xor_inputs``; per the original note it maps
        (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
    :return: ``4`` minus the squared error against ``xor_outputs`` summed over axes 1 and 2
    """
    # Bring the result back to the host before doing the numpy arithmetic.
    predictions = jax.device_get(forward_func(xor_inputs))
    squared_error = (predictions - xor_outputs) ** 2
    return 4 - np.sum(squared_error, axis=(1, 2))
|
||||||
|
|
||||||
|
|
||||||
|
def get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topological_sort, forward_func):
    """
    Compute the XOR fitness of the population described by ``pop_nodes`` / ``pop_cons``.

    :param pop_nodes: population node array, passed through to the helpers unchanged
    :param pop_cons: population connection array in its flat form
    :param pop_unflatten_connections: callable ``(pop_nodes, pop_cons) -> unflattened connections``
    :param pop_topological_sort: callable ``(pop_nodes, unflattened connections) -> sequences``
    :param forward_func: callable invoked as ``forward_func(x, seqs, pop_nodes, unflattened connections)``
    :return: whatever ``evaluate`` returns for the bound forward function
    """
    unflattened = pop_unflatten_connections(pop_nodes, pop_cons)
    orderings = pop_topological_sort(pop_nodes, unflattened)

    def bound_forward(x):
        # Only the inputs vary; the population description is fixed for this call.
        return forward_func(x, orderings, pop_nodes, unflattened)

    return evaluate(bound_forward)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def equal(ar1, ar2):
    """
    Compare two arrays element-wise, treating NaN in the same position as a match.

    :param ar1: first array (must expose ``.shape``)
    :param ar2: second array (must expose ``.shape``)
    :return: ``False`` when the shapes differ; otherwise the result of ``jnp.all`` over the
        per-element comparison
    """
    same_shape = ar1.shape == ar2.shape
    if not same_shape:
        return False

    # NaN != NaN under IEEE rules, so positions where both sides are NaN are accepted explicitly.
    both_nan = jnp.isnan(ar1) & jnp.isnan(ar2)
    matches = (ar1 == ar2) | both_nan
    return jnp.all(matches)
|
||||||
|
|
||||||
|
def main():
    """
    Run NEAT on XOR indefinitely and check that the best fitness never decreases.

    Before every call to ``tell`` the exact argument list is pickled to ``'list.pkl'``
    (overwriting the previous snapshot), so the step that trips the assertion can be
    replayed from that file. The loop only ends when the assertion fails.
    """
    # --- configuration -------------------------------------------------------------
    config = Configer.load_config("xor.ini")
    jit_config = Configer.create_jit_config(config)  # config used in jit-able functions

    pop_size = config['pop_size']
    max_nodes = config['init_maximum_nodes']
    max_cons = config['init_maximum_connections']
    max_species = config['init_maximum_species']

    # Fixed seeds for both JAX and numpy so a run can be reproduced.
    randkey = jax.random.PRNGKey(6)
    np.random.seed(6)

    # --- initial state: one species (slot 0) centred on genome 0 --------------------
    pop_nodes, pop_cons = initialize_genomes(max_nodes, max_cons, config)

    species_info = np.full((max_species, 4), np.nan)
    species_info[0] = (0, -np.inf, 0, pop_size)
    idx2species = np.zeros(pop_size, dtype=np.float32)

    center_nodes = np.full((max_species, max_nodes, 5), np.nan)
    center_cons = np.full((max_species, max_cons, 4), np.nan)
    center_nodes[0] = pop_nodes[0]
    center_cons[0] = pop_cons[0]

    generation = 0

    pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons = jax.device_put(
        [pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons])

    pop_unflatten_connections = jit(vmap(unflatten_connections))
    pop_topological_sort = jit(vmap(topological_sort))
    forward = create_forward_function(config)

    # --- evolve until the best fitness regresses ------------------------------------
    while True:
        fitness = get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topological_sort, forward)
        last_max = np.max(fitness)

        # The snapshot holds the arguments of tell() in call order.
        snapshot = [fitness, randkey, pop_nodes, pop_cons, species_info, idx2species,
                    center_nodes, center_cons, generation, jit_config]

        # Save the snapshot list with pickle's dump function.
        with open('list.pkl', 'wb') as snapshot_file:
            pickle.dump(snapshot, snapshot_file)

        (randkey, pop_nodes, pop_cons, species_info, idx2species,
         center_nodes, center_cons, generation) = tell(*snapshot)

        fitness = get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topological_sort, forward)
        current_max = np.max(fitness)

        print(last_max, current_max)
        assert current_max >= last_max, f"current_max: {current_max}, last_max: {last_max}"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # main() is left disabled here (it was commented out in the original). This entry point
    # replays a single tell() step from the 'list.pkl' snapshot that main() writes.
    config = Configer.load_config("xor.ini")
    forward = create_forward_function(config)
    pop_unflatten_connections = jit(vmap(unflatten_connections))
    pop_topological_sort = jit(vmap(topological_sort))

    # Restore the saved argument list with pickle's load function.
    # NOTE(review): pickle.load executes arbitrary code from the file; only open snapshots
    # produced locally by this script.
    with open('list.pkl', 'rb') as f:
        (fitness, randkey, pop_nodes, pop_cons, species_info, idx2species,
         center_nodes, center_cons, i, jit_config) = pickle.load(f)

    # Best fitness before the replayed step.
    print(np.max(fitness))

    # `i` is the generation counter stored in the snapshot; the updated counter is discarded.
    (randkey, pop_nodes, pop_cons, species_info, idx2species,
     center_nodes, center_cons, _) = tell(
        fitness, randkey, pop_nodes, pop_cons, species_info, idx2species,
        center_nodes, center_cons, i, jit_config)

    # Best fitness after the replayed step.
    fitness = get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topological_sort, forward)
    print(np.max(fitness))
|
||||||
@@ -39,7 +39,6 @@ class Pipeline:
|
|||||||
self.center_cons[0, :, :] = self.pop_cons[0, :, :]
|
self.center_cons[0, :, :] = self.pop_cons[0, :, :]
|
||||||
|
|
||||||
self.best_fitness = float('-inf')
|
self.best_fitness = float('-inf')
|
||||||
self.best_genome = None
|
|
||||||
self.generation_timestamp = time.time()
|
self.generation_timestamp = time.time()
|
||||||
|
|
||||||
self.evaluate_time = 0
|
self.evaluate_time = 0
|
||||||
|
|||||||
Reference in New Issue
Block a user