use black format all files;
remove "return state" for functions which will be executed in vmap; recover randkey as args in mutation methods
This commit is contained in:
@@ -6,23 +6,23 @@ from .base import BaseSpecies
|
||||
|
||||
|
||||
class DefaultSpecies(BaseSpecies):
|
||||
|
||||
def __init__(self,
|
||||
genome: BaseGenome,
|
||||
pop_size,
|
||||
species_size,
|
||||
compatibility_disjoint: float = 1.0,
|
||||
compatibility_weight: float = 0.4,
|
||||
max_stagnation: int = 15,
|
||||
species_elitism: int = 2,
|
||||
spawn_number_change_rate: float = 0.5,
|
||||
genome_elitism: int = 2,
|
||||
survival_threshold: float = 0.2,
|
||||
min_species_size: int = 1,
|
||||
compatibility_threshold: float = 3.,
|
||||
initialize_method: str = 'one_hidden_node',
|
||||
# {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'}
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
genome: BaseGenome,
|
||||
pop_size,
|
||||
species_size,
|
||||
compatibility_disjoint: float = 1.0,
|
||||
compatibility_weight: float = 0.4,
|
||||
max_stagnation: int = 15,
|
||||
species_elitism: int = 2,
|
||||
spawn_number_change_rate: float = 0.5,
|
||||
genome_elitism: int = 2,
|
||||
survival_threshold: float = 0.2,
|
||||
min_species_size: int = 1,
|
||||
compatibility_threshold: float = 3.0,
|
||||
initialize_method: str = "one_hidden_node",
|
||||
# {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'}
|
||||
):
|
||||
self.genome = genome
|
||||
self.pop_size = pop_size
|
||||
self.species_size = species_size
|
||||
@@ -40,21 +40,38 @@ class DefaultSpecies(BaseSpecies):
|
||||
|
||||
self.species_arange = jnp.arange(self.species_size)
|
||||
|
||||
def setup(self, key, state=State()):
|
||||
k1, k2 = jax.random.split(key, 2)
|
||||
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome, k1, self.initialize_method)
|
||||
def setup(self, state=State()):
|
||||
state = self.genome.setup(state)
|
||||
k1, randkey = jax.random.split(state.randkey, 2)
|
||||
pop_nodes, pop_conns = initialize_population(
|
||||
self.pop_size, self.genome, k1, self.initialize_method
|
||||
)
|
||||
|
||||
species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species
|
||||
best_fitness = jnp.full((self.species_size,), jnp.nan) # the best fitness of each species
|
||||
last_improved = jnp.full((self.species_size,), jnp.nan) # the last generation that the species improved
|
||||
member_count = jnp.full((self.species_size,), jnp.nan) # the number of members of each species
|
||||
species_keys = jnp.full(
|
||||
(self.species_size,), jnp.nan
|
||||
) # the unique index (primary key) for each species
|
||||
best_fitness = jnp.full(
|
||||
(self.species_size,), jnp.nan
|
||||
) # the best fitness of each species
|
||||
last_improved = jnp.full(
|
||||
(self.species_size,), jnp.nan
|
||||
) # the last 1 that the species improved
|
||||
member_count = jnp.full(
|
||||
(self.species_size,), jnp.nan
|
||||
) # the number of members of each species
|
||||
idx2species = jnp.zeros(self.pop_size) # the species index of each individual
|
||||
|
||||
# nodes for each center genome of each species
|
||||
center_nodes = jnp.full((self.species_size, self.genome.max_nodes, self.genome.node_gene.length), jnp.nan)
|
||||
center_nodes = jnp.full(
|
||||
(self.species_size, self.genome.max_nodes, self.genome.node_gene.length),
|
||||
jnp.nan,
|
||||
)
|
||||
|
||||
# connections for each center genome of each species
|
||||
center_conns = jnp.full((self.species_size, self.genome.max_conns, self.genome.conn_gene.length), jnp.nan)
|
||||
center_conns = jnp.full(
|
||||
(self.species_size, self.genome.max_conns, self.genome.conn_gene.length),
|
||||
jnp.nan,
|
||||
)
|
||||
|
||||
species_keys = species_keys.at[0].set(0)
|
||||
best_fitness = best_fitness.at[0].set(-jnp.inf)
|
||||
@@ -66,7 +83,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
|
||||
|
||||
return state.register(
|
||||
species_randkey=k2,
|
||||
randkey=randkey,
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
species_keys=species_keys,
|
||||
@@ -80,14 +97,14 @@ class DefaultSpecies(BaseSpecies):
|
||||
)
|
||||
|
||||
def ask(self, state):
|
||||
return state.pop_nodes, state.pop_conns
|
||||
return state, state.pop_nodes, state.pop_conns
|
||||
|
||||
def update_species(self, state, fitness, generation):
|
||||
def update_species(self, state, fitness):
|
||||
# update the fitness of each species
|
||||
species_fitness = self.update_species_fitness(state, fitness)
|
||||
state, species_fitness = self.update_species_fitness(state, fitness)
|
||||
|
||||
# stagnation species
|
||||
state, species_fitness = self.stagnation(state, generation, species_fitness)
|
||||
state, species_fitness = self.stagnation(state, species_fitness)
|
||||
|
||||
# sort species_info by their fitness. (also push nan to the end)
|
||||
sort_indices = jnp.argsort(species_fitness)[::-1]
|
||||
@@ -101,11 +118,13 @@ class DefaultSpecies(BaseSpecies):
|
||||
)
|
||||
|
||||
# decide the number of members of each species by their fitness
|
||||
spawn_number = self.cal_spawn_numbers(state)
|
||||
state, spawn_number = self.cal_spawn_numbers(state)
|
||||
|
||||
k1, k2 = jax.random.split(state.randkey)
|
||||
# crossover info
|
||||
winner, loser, elite_mask = self.create_crossover_pair(state, k1, spawn_number, fitness)
|
||||
winner, loser, elite_mask = self.create_crossover_pair(
|
||||
state, k1, spawn_number, fitness
|
||||
)
|
||||
|
||||
return state.update(randkey=k2), winner, loser, elite_mask
|
||||
|
||||
@@ -116,42 +135,50 @@ class DefaultSpecies(BaseSpecies):
|
||||
"""
|
||||
|
||||
def aux_func(idx):
|
||||
s_fitness = jnp.where(state.idx2species == state.species_keys[idx], fitness, -jnp.inf)
|
||||
s_fitness = jnp.where(
|
||||
state.idx2species == state.species_keys[idx], fitness, -jnp.inf
|
||||
)
|
||||
val = jnp.max(s_fitness)
|
||||
return val
|
||||
|
||||
return jax.vmap(aux_func)(self.species_arange)
|
||||
return state, jax.vmap(aux_func)(self.species_arange)
|
||||
|
||||
def stagnation(self, state, generation, species_fitness):
|
||||
def stagnation(self, state, species_fitness):
|
||||
"""
|
||||
stagnation species.
|
||||
those species whose fitness is not better than the best fitness of the species for a long time will be stagnation.
|
||||
elitism species never stagnation
|
||||
|
||||
generation: the current generation
|
||||
"""
|
||||
|
||||
def check_stagnation(idx):
|
||||
# determine whether the species stagnation
|
||||
st = (
|
||||
(species_fitness[idx] <= state.best_fitness[idx]) & # not better than the best fitness of the species
|
||||
(generation - state.last_improved[idx] > self.max_stagnation) # for a long time
|
||||
)
|
||||
species_fitness[idx] <= state.best_fitness[idx]
|
||||
) & ( # not better than the best fitness of the species
|
||||
state.generation - state.last_improved[idx] > self.max_stagnation
|
||||
) # for a long time
|
||||
|
||||
# update last_improved and best_fitness
|
||||
li, bf = jax.lax.cond(
|
||||
species_fitness[idx] > state.best_fitness[idx],
|
||||
lambda: (generation, species_fitness[idx]), # update
|
||||
lambda: (state.last_improved[idx], state.best_fitness[idx]) # not update
|
||||
lambda: (state.generation, species_fitness[idx]), # update
|
||||
lambda: (
|
||||
state.last_improved[idx],
|
||||
state.best_fitness[idx],
|
||||
), # not update
|
||||
)
|
||||
|
||||
return st, bf, li
|
||||
|
||||
spe_st, best_fitness, last_improved = jax.vmap(check_stagnation)(self.species_arange)
|
||||
spe_st, best_fitness, last_improved = jax.vmap(check_stagnation)(
|
||||
self.species_arange
|
||||
)
|
||||
|
||||
# elite species will not be stagnation
|
||||
species_rank = rank_elements(species_fitness)
|
||||
spe_st = jnp.where(species_rank < self.species_elitism, False, spe_st) # elitism never stagnation
|
||||
spe_st = jnp.where(
|
||||
species_rank < self.species_elitism, False, spe_st
|
||||
) # elitism never stagnation
|
||||
|
||||
# set stagnation species to nan
|
||||
def update_func(idx):
|
||||
@@ -173,8 +200,8 @@ class DefaultSpecies(BaseSpecies):
|
||||
state.member_count[idx],
|
||||
species_fitness[idx],
|
||||
state.center_nodes[idx],
|
||||
state.center_conns[idx]
|
||||
) # not stagnation species
|
||||
state.center_conns[idx],
|
||||
), # not stagnation species
|
||||
)
|
||||
|
||||
(
|
||||
@@ -184,18 +211,20 @@ class DefaultSpecies(BaseSpecies):
|
||||
member_count,
|
||||
species_fitness,
|
||||
center_nodes,
|
||||
center_conns
|
||||
) = (
|
||||
jax.vmap(update_func)(self.species_arange))
|
||||
center_conns,
|
||||
) = jax.vmap(update_func)(self.species_arange)
|
||||
|
||||
return state.update(
|
||||
species_keys=species_keys,
|
||||
best_fitness=best_fitness,
|
||||
last_improved=last_improved,
|
||||
member_count=member_count,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
), species_fitness
|
||||
return (
|
||||
state.update(
|
||||
species_keys=species_keys,
|
||||
best_fitness=best_fitness,
|
||||
last_improved=last_improved,
|
||||
member_count=member_count,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
),
|
||||
species_fitness,
|
||||
)
|
||||
|
||||
def cal_spawn_numbers(self, state):
|
||||
"""
|
||||
@@ -209,17 +238,26 @@ class DefaultSpecies(BaseSpecies):
|
||||
|
||||
is_species_valid = ~jnp.isnan(species_keys)
|
||||
valid_species_num = jnp.sum(is_species_valid)
|
||||
denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6
|
||||
denominator = (
|
||||
(valid_species_num + 1) * valid_species_num / 2
|
||||
) # obtain 3 + 2 + 1 = 6
|
||||
|
||||
rank_score = valid_species_num - self.species_arange # obtain [3, 2, 1]
|
||||
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
|
||||
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
|
||||
spawn_number_rate = jnp.where(
|
||||
is_species_valid, spawn_number_rate, 0
|
||||
) # set invalid species to 0
|
||||
|
||||
target_spawn_number = jnp.floor(spawn_number_rate * self.pop_size) # calculate member
|
||||
target_spawn_number = jnp.floor(
|
||||
spawn_number_rate * self.pop_size
|
||||
) # calculate member
|
||||
|
||||
# Avoid too much variation of numbers for a species
|
||||
previous_size = state.member_count
|
||||
spawn_number = previous_size + (target_spawn_number - previous_size) * self.spawn_number_change_rate
|
||||
spawn_number = (
|
||||
previous_size
|
||||
+ (target_spawn_number - previous_size) * self.spawn_number_change_rate
|
||||
)
|
||||
spawn_number = spawn_number.astype(jnp.int32)
|
||||
|
||||
# must control the sum of spawn_number to be equal to pop_size
|
||||
@@ -228,9 +266,9 @@ class DefaultSpecies(BaseSpecies):
|
||||
# add error to the first species to control the sum of spawn_number
|
||||
spawn_number = spawn_number.at[0].add(error)
|
||||
|
||||
return spawn_number
|
||||
return state, spawn_number
|
||||
|
||||
def create_crossover_pair(self, state, randkey, spawn_number, fitness):
|
||||
def create_crossover_pair(self, state, spawn_number, fitness):
|
||||
s_idx = self.species_arange
|
||||
p_idx = jnp.arange(self.pop_size)
|
||||
|
||||
@@ -241,10 +279,18 @@ class DefaultSpecies(BaseSpecies):
|
||||
members_fitness = jnp.where(members, fitness, -jnp.inf)
|
||||
sorted_member_indices = jnp.argsort(members_fitness)[::-1]
|
||||
|
||||
survive_size = jnp.floor(self.survival_threshold * members_num).astype(jnp.int32)
|
||||
survive_size = jnp.floor(self.survival_threshold * members_num).astype(
|
||||
jnp.int32
|
||||
)
|
||||
|
||||
select_pro = (p_idx < survive_size) / survive_size
|
||||
fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, self.pop_size), replace=True, p=select_pro)
|
||||
fa, ma = jax.random.choice(
|
||||
key,
|
||||
sorted_member_indices,
|
||||
shape=(2, self.pop_size),
|
||||
replace=True,
|
||||
p=select_pro,
|
||||
)
|
||||
|
||||
# elite
|
||||
fa = jnp.where(p_idx < self.genome_elitism, sorted_member_indices, fa)
|
||||
@@ -252,7 +298,10 @@ class DefaultSpecies(BaseSpecies):
|
||||
elite = jnp.where(p_idx < self.genome_elitism, True, False)
|
||||
return fa, ma, elite
|
||||
|
||||
fas, mas, elites = jax.vmap(aux_func)(jax.random.split(randkey, self.species_size), s_idx)
|
||||
randkey_, randkey = jax.random.split(state.randkey)
|
||||
fas, mas, elites = jax.vmap(aux_func)(
|
||||
jax.random.split(randkey_, self.species_size), s_idx
|
||||
)
|
||||
|
||||
spawn_number_cum = jnp.cumsum(spawn_number)
|
||||
|
||||
@@ -261,7 +310,11 @@ class DefaultSpecies(BaseSpecies):
|
||||
|
||||
# elite genomes are at the beginning of the species
|
||||
idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
|
||||
return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]
|
||||
return (
|
||||
fas[loc, idx_in_species],
|
||||
mas[loc, idx_in_species],
|
||||
elites[loc, idx_in_species],
|
||||
)
|
||||
|
||||
part1, part2, elite_mask = jax.vmap(aux_func)(p_idx)
|
||||
|
||||
@@ -269,14 +322,18 @@ class DefaultSpecies(BaseSpecies):
|
||||
winner = jnp.where(is_part1_win, part1, part2)
|
||||
loser = jnp.where(is_part1_win, part2, part1)
|
||||
|
||||
return winner, loser, elite_mask
|
||||
return state(randkey=randkey), winner, loser, elite_mask
|
||||
|
||||
def speciate(self, state, generation):
|
||||
def speciate(self, state):
|
||||
# prepare distance functions
|
||||
o2p_distance_func = jax.vmap(self.distance, in_axes=(None, None, 0, 0)) # one to population
|
||||
o2p_distance_func = jax.vmap(
|
||||
self.distance, in_axes=(None, None, 0, 0)
|
||||
) # one to population
|
||||
|
||||
# idx to specie key
|
||||
idx2species = jnp.full((self.pop_size,), jnp.nan) # NaN means not assigned to any species
|
||||
idx2species = jnp.full(
|
||||
(self.pop_size,), jnp.nan
|
||||
) # NaN means not assigned to any species
|
||||
|
||||
# the distance between genomes to its center genomes
|
||||
o2c_distances = jnp.full((self.pop_size,), jnp.inf)
|
||||
@@ -286,15 +343,16 @@ class DefaultSpecies(BaseSpecies):
|
||||
# i, idx2species, center_nodes, center_conns, o2c_distances
|
||||
i, i2s, cns, ccs, o2c = carry
|
||||
|
||||
return (
|
||||
(i < self.species_size) &
|
||||
(~jnp.isnan(state.species_keys[i]))
|
||||
return (i < self.species_size) & (
|
||||
~jnp.isnan(state.species_keys[i])
|
||||
) # current species is existing
|
||||
|
||||
def body_func(carry):
|
||||
i, i2s, cns, ccs, o2c = carry
|
||||
|
||||
distances = o2p_distance_func(cns[i], ccs[i], state.pop_nodes, state.pop_conns)
|
||||
distances = o2p_distance_func(
|
||||
cns[i], ccs[i], state.pop_nodes, state.pop_conns
|
||||
)
|
||||
|
||||
# find the closest one
|
||||
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
|
||||
@@ -308,9 +366,11 @@ class DefaultSpecies(BaseSpecies):
|
||||
|
||||
return i + 1, i2s, cns, ccs, o2c
|
||||
|
||||
_, idx2species, center_nodes, center_conns, o2c_distances = \
|
||||
jax.lax.while_loop(cond_func, body_func,
|
||||
(0, idx2species, state.center_nodes, state.center_conns, o2c_distances))
|
||||
_, idx2species, center_nodes, center_conns, o2c_distances = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(0, idx2species, state.center_nodes, state.center_conns, o2c_distances),
|
||||
)
|
||||
|
||||
state = state.update(
|
||||
idx2species=idx2species,
|
||||
@@ -326,7 +386,9 @@ class DefaultSpecies(BaseSpecies):
|
||||
current_species_existed = ~jnp.isnan(sk[i])
|
||||
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
||||
not_reach_species_upper_bounds = i < self.species_size
|
||||
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
|
||||
return not_reach_species_upper_bounds & (
|
||||
current_species_existed | not_all_assigned
|
||||
)
|
||||
|
||||
def body_func(carry):
|
||||
i, i2s, cns, ccs, sk, o2c, nsk = carry
|
||||
@@ -335,7 +397,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
jnp.isnan(sk[i]), # whether the current species is existing or not
|
||||
create_new_species, # if not existing, create a new specie
|
||||
update_exist_specie, # if existing, update the specie
|
||||
(i, i2s, cns, ccs, sk, o2c, nsk)
|
||||
(i, i2s, cns, ccs, sk, o2c, nsk),
|
||||
)
|
||||
|
||||
return i + 1, i2s, cns, ccs, sk, o2c, nsk
|
||||
@@ -371,7 +433,9 @@ class DefaultSpecies(BaseSpecies):
|
||||
|
||||
def speciate_by_threshold(i, i2s, cns, ccs, sk, o2c):
|
||||
# distance between such center genome and ppo genomes
|
||||
o2p_distance = o2p_distance_func(cns[i], ccs[i], state.pop_nodes, state.pop_conns)
|
||||
o2p_distance = o2p_distance_func(
|
||||
cns[i], ccs[i], state.pop_nodes, state.pop_conns
|
||||
)
|
||||
|
||||
close_enough_mask = o2p_distance < self.compatibility_threshold
|
||||
# when a genome is not assigned or the distance between its current center is bigger than this center
|
||||
@@ -388,11 +452,26 @@ class DefaultSpecies(BaseSpecies):
|
||||
return i2s, o2c
|
||||
|
||||
# update idx2species
|
||||
_, idx2species, center_nodes, center_conns, species_keys, _, next_species_key = jax.lax.while_loop(
|
||||
(
|
||||
_,
|
||||
idx2species,
|
||||
center_nodes,
|
||||
center_conns,
|
||||
species_keys,
|
||||
_,
|
||||
next_species_key,
|
||||
) = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(0, state.idx2species, center_nodes, center_conns, state.species_keys, o2c_distances,
|
||||
state.next_species_key)
|
||||
(
|
||||
0,
|
||||
state.idx2species,
|
||||
center_nodes,
|
||||
center_conns,
|
||||
state.species_keys,
|
||||
o2c_distances,
|
||||
state.next_species_key,
|
||||
),
|
||||
)
|
||||
|
||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||
@@ -402,14 +481,18 @@ class DefaultSpecies(BaseSpecies):
|
||||
# complete info of species which is created in this generation
|
||||
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.best_fitness)
|
||||
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.best_fitness)
|
||||
last_improved = jnp.where(new_created_mask, generation, state.last_improved)
|
||||
last_improved = jnp.where(
|
||||
new_created_mask, state.generation, state.last_improved
|
||||
)
|
||||
|
||||
# update members count
|
||||
def count_members(idx):
|
||||
return jax.lax.cond(
|
||||
jnp.isnan(species_keys[idx]), # if the species is not existing
|
||||
lambda: jnp.nan, # nan
|
||||
lambda: jnp.sum(idx2species == species_keys[idx], dtype=jnp.float32) # count members
|
||||
lambda: jnp.sum(
|
||||
idx2species == species_keys[idx], dtype=jnp.float32
|
||||
), # count members
|
||||
)
|
||||
|
||||
member_count = jax.vmap(count_members)(self.species_arange)
|
||||
@@ -422,7 +505,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
idx2species=idx2species,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
next_species_key=next_species_key
|
||||
next_species_key=next_species_key,
|
||||
)
|
||||
|
||||
def distance(self, nodes1, conns1, nodes2, conns2):
|
||||
@@ -446,7 +529,9 @@ class DefaultSpecies(BaseSpecies):
|
||||
keys = nodes[:, 0]
|
||||
sorted_indices = jnp.argsort(keys, axis=0)
|
||||
nodes = nodes[sorted_indices]
|
||||
nodes = jnp.concatenate([nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
|
||||
nodes = jnp.concatenate(
|
||||
[nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0
|
||||
) # add a nan row to the end
|
||||
fr, sr = nodes[:-1], nodes[1:] # first row, second row
|
||||
|
||||
# flag location of homologous nodes
|
||||
@@ -460,7 +545,10 @@ class DefaultSpecies(BaseSpecies):
|
||||
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
|
||||
homologous_distance = jnp.sum(hnd * intersect_mask)
|
||||
|
||||
val = non_homologous_cnt * self.compatibility_disjoint + homologous_distance * self.compatibility_weight
|
||||
val = (
|
||||
non_homologous_cnt * self.compatibility_disjoint
|
||||
+ homologous_distance * self.compatibility_weight
|
||||
)
|
||||
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
|
||||
|
||||
@@ -476,7 +564,9 @@ class DefaultSpecies(BaseSpecies):
|
||||
keys = cons[:, :2]
|
||||
sorted_indices = jnp.lexsort(keys.T[::-1])
|
||||
cons = cons[sorted_indices]
|
||||
cons = jnp.concatenate([cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
|
||||
cons = jnp.concatenate(
|
||||
[cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0
|
||||
) # add a nan row to the end
|
||||
fr, sr = cons[:-1], cons[1:] # first row, second row
|
||||
|
||||
# both genome has such connection
|
||||
@@ -487,19 +577,22 @@ class DefaultSpecies(BaseSpecies):
|
||||
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
|
||||
homologous_distance = jnp.sum(hcd * intersect_mask)
|
||||
|
||||
val = non_homologous_cnt * self.compatibility_disjoint + homologous_distance * self.compatibility_weight
|
||||
val = (
|
||||
non_homologous_cnt * self.compatibility_disjoint
|
||||
+ homologous_distance * self.compatibility_weight
|
||||
)
|
||||
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
||||
|
||||
|
||||
def initialize_population(pop_size, genome, randkey, init_method='default'):
|
||||
def initialize_population(pop_size, genome, randkey, init_method="default"):
|
||||
rand_keys = jax.random.split(randkey, pop_size)
|
||||
|
||||
if init_method == 'one_hidden_node':
|
||||
if init_method == "one_hidden_node":
|
||||
init_func = init_one_hidden_node
|
||||
elif init_method == 'dense_hideen_layer':
|
||||
elif init_method == "dense_hideen_layer":
|
||||
init_func = init_dense_hideen_layer
|
||||
elif init_method == 'no_hidden_random':
|
||||
elif init_method == "no_hidden_random":
|
||||
init_func = init_no_hidden_random
|
||||
else:
|
||||
raise ValueError("Unknown initialization method: {}".format(init_method))
|
||||
@@ -521,12 +614,16 @@ def init_one_hidden_node(genome, randkey):
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
nodes = nodes.at[new_node_key, 0].set(new_node_key)
|
||||
|
||||
rand_keys_nodes = jax.random.split(randkey, num=len(input_idx) + len(output_idx) + 1)
|
||||
input_keys, output_keys, hidden_key = rand_keys_nodes[:len(input_idx)], rand_keys_nodes[
|
||||
len(input_idx):len(input_idx) + len(
|
||||
output_idx)], rand_keys_nodes[-1]
|
||||
rand_keys_nodes = jax.random.split(
|
||||
randkey, num=len(input_idx) + len(output_idx) + 1
|
||||
)
|
||||
input_keys, output_keys, hidden_key = (
|
||||
rand_keys_nodes[: len(input_idx)],
|
||||
rand_keys_nodes[len(input_idx) : len(input_idx) + len(output_idx)],
|
||||
rand_keys_nodes[-1],
|
||||
)
|
||||
|
||||
node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(None, 0))
|
||||
node_attr_func = jax.vmap(genome.node_gene.new_attrs, in_axes=(None, 0))
|
||||
input_attrs = node_attr_func(input_keys)
|
||||
output_attrs = node_attr_func(output_keys)
|
||||
hidden_attrs = genome.node_gene.new_custom_attrs(hidden_key)
|
||||
@@ -544,7 +641,10 @@ def init_one_hidden_node(genome, randkey):
|
||||
conns = conns.at[output_idx, 2].set(True)
|
||||
|
||||
rand_keys_conns = jax.random.split(randkey, num=len(input_idx) + len(output_idx))
|
||||
input_conn_keys, output_conn_keys = rand_keys_conns[:len(input_idx)], rand_keys_conns[len(input_idx):]
|
||||
input_conn_keys, output_conn_keys = (
|
||||
rand_keys_conns[: len(input_idx)],
|
||||
rand_keys_conns[len(input_idx) :],
|
||||
)
|
||||
|
||||
conn_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(None, 0))
|
||||
input_conn_attrs = conn_attr_func(input_conn_keys)
|
||||
@@ -563,8 +663,12 @@ def init_dense_hideen_layer(genome, randkey, hiddens=20):
|
||||
input_size = len(input_idx)
|
||||
output_size = len(output_idx)
|
||||
|
||||
hidden_idx = jnp.arange(input_size + output_size, input_size + output_size + hiddens)
|
||||
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan, dtype=jnp.float32)
|
||||
hidden_idx = jnp.arange(
|
||||
input_size + output_size, input_size + output_size + hiddens
|
||||
)
|
||||
nodes = jnp.full(
|
||||
(genome.max_nodes, genome.node_gene.length), jnp.nan, dtype=jnp.float32
|
||||
)
|
||||
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
nodes = nodes.at[hidden_idx, 0].set(hidden_idx)
|
||||
@@ -572,8 +676,8 @@ def init_dense_hideen_layer(genome, randkey, hiddens=20):
|
||||
total_idx = input_size + output_size + hiddens
|
||||
rand_keys_n = jax.random.split(k1, num=total_idx)
|
||||
input_keys = rand_keys_n[:input_size]
|
||||
output_keys = rand_keys_n[input_size:input_size + output_size]
|
||||
hidden_keys = rand_keys_n[input_size + output_size:]
|
||||
output_keys = rand_keys_n[input_size : input_size + output_size]
|
||||
hidden_keys = rand_keys_n[input_size + output_size :]
|
||||
|
||||
node_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0))
|
||||
input_attrs = node_attr_func(input_keys)
|
||||
@@ -585,21 +689,31 @@ def init_dense_hideen_layer(genome, randkey, hiddens=20):
|
||||
nodes = nodes.at[hidden_idx, 1:].set(hidden_attrs)
|
||||
|
||||
total_connections = input_size * hiddens + hiddens * output_size
|
||||
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan, dtype=jnp.float32)
|
||||
conns = jnp.full(
|
||||
(genome.max_conns, genome.conn_gene.length), jnp.nan, dtype=jnp.float32
|
||||
)
|
||||
|
||||
rand_keys_c = jax.random.split(k2, num=total_connections)
|
||||
conns_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0))
|
||||
conns_attrs = conns_attr_func(rand_keys_c)
|
||||
|
||||
input_to_hidden_ids, hidden_ids = jnp.meshgrid(input_idx, hidden_idx, indexing='ij')
|
||||
hidden_to_output_ids, output_ids = jnp.meshgrid(hidden_idx, output_idx, indexing='ij')
|
||||
input_to_hidden_ids, hidden_ids = jnp.meshgrid(input_idx, hidden_idx, indexing="ij")
|
||||
hidden_to_output_ids, output_ids = jnp.meshgrid(
|
||||
hidden_idx, output_idx, indexing="ij"
|
||||
)
|
||||
|
||||
conns = conns.at[:input_size * hiddens, 0].set(input_to_hidden_ids.flatten())
|
||||
conns = conns.at[:input_size * hiddens, 1].set(hidden_ids.flatten())
|
||||
conns = conns.at[input_size * hiddens: total_connections, 0].set(hidden_to_output_ids.flatten())
|
||||
conns = conns.at[input_size * hiddens: total_connections, 1].set(output_ids.flatten())
|
||||
conns = conns.at[:input_size * hiddens + hiddens * output_size, 2].set(True)
|
||||
conns = conns.at[:input_size * hiddens + hiddens * output_size, 3:].set(conns_attrs)
|
||||
conns = conns.at[: input_size * hiddens, 0].set(input_to_hidden_ids.flatten())
|
||||
conns = conns.at[: input_size * hiddens, 1].set(hidden_ids.flatten())
|
||||
conns = conns.at[input_size * hiddens : total_connections, 0].set(
|
||||
hidden_to_output_ids.flatten()
|
||||
)
|
||||
conns = conns.at[input_size * hiddens : total_connections, 1].set(
|
||||
output_ids.flatten()
|
||||
)
|
||||
conns = conns.at[: input_size * hiddens + hiddens * output_size, 2].set(True)
|
||||
conns = conns.at[: input_size * hiddens + hiddens * output_size, 3:].set(
|
||||
conns_attrs
|
||||
)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
@@ -615,8 +729,8 @@ def init_no_hidden_random(genome, randkey):
|
||||
|
||||
total_idx = len(input_idx) + len(output_idx)
|
||||
rand_keys_n = jax.random.split(k1, num=total_idx)
|
||||
input_keys = rand_keys_n[:len(input_idx)]
|
||||
output_keys = rand_keys_n[len(input_idx):]
|
||||
input_keys = rand_keys_n[: len(input_idx)]
|
||||
output_keys = rand_keys_n[len(input_idx) :]
|
||||
|
||||
node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0))
|
||||
input_attrs = node_attr_func(input_keys)
|
||||
|
||||
Reference in New Issue
Block a user