add gene type RNN

This commit is contained in:
wls2002
2023-07-19 15:43:49 +08:00
parent 0a2a9fd1be
commit a684e6584d
18 changed files with 248 additions and 129 deletions

View File

@@ -7,8 +7,8 @@ from .utils import rank_elements, fetch_first
from .genome import create_mutate, create_distance, crossover
from .gene import BaseGene
def create_tell(config, gene_type: Type[BaseGene]):
def create_tell(config, gene_type: Type[BaseGene]):
mutate = create_mutate(config, gene_type)
distance = create_distance(config, gene_type)
@@ -36,7 +36,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
return state, winner, loser, elite_mask
def update_species_fitness(state, fitness):
"""
obtain the fitness of the species by the fitness of each individual.
@@ -51,7 +50,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
return vmap(aux_func)(jnp.arange(state.species_info.shape[0]))
def stagnation(state, species_fitness):
"""
stagnation species.
@@ -88,7 +86,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
return state, species_fitness
def cal_spawn_numbers(state):
"""
decide the number of members of each species by their fitness rank.
@@ -106,7 +103,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member
# jax.debug.print("denominator: {}, spawn_number_rate: {}, target_spawn_number: {}", denominator, spawn_number_rate, target_spawn_number)
# Avoid too much variation of numbers in a species
previous_size = state.species_info[:, 3].astype(jnp.int32)
@@ -118,11 +114,11 @@ def create_tell(config, gene_type: Type[BaseGene]):
# must control the sum of spawn_number to be equal to pop_size
error = state.P - jnp.sum(spawn_number)
spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number
spawn_number = spawn_number.at[0].add(
error) # add error to the first species to control the sum of spawn_number
return spawn_number
def create_crossover_pair(state, randkey, spawn_number, fitness):
species_size = state.species_info.shape[0]
pop_size = fitness.shape[0]
@@ -238,8 +234,8 @@ def create_tell(config, gene_type: Type[BaseGene]):
return i + 1, i2s, cn, cc, o2c
_, idx2specie, center_nodes, center_conns, o2c_distances = \
jax.lax.while_loop(cond_func, body_func, (0, idx2specie, state.center_nodes, state.center_conns, o2c_distances))
jax.lax.while_loop(cond_func, body_func,
(0, idx2specie, state.center_nodes, state.center_conns, o2c_distances))
# part 2: assign members to each species
def cond_func(carry):
@@ -331,7 +327,7 @@ def create_tell(config, gene_type: Type[BaseGene]):
species_info = species_info.at[:, 3].set(species_member_counts)
return state.update(
idx2specie=idx2specie,
idx2species=idx2specie,
center_nodes=center_nodes,
center_conns=center_conns,
species_info=species_info,
@@ -358,11 +354,10 @@ def create_tell(config, gene_type: Type[BaseGene]):
return state
return tell
def argmin_with_mask(arr, mask):
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)
return min_idx
return min_idx