add gene type RNN
This commit is contained in:
@@ -7,8 +7,8 @@ from .utils import rank_elements, fetch_first
|
||||
from .genome import create_mutate, create_distance, crossover
|
||||
from .gene import BaseGene
|
||||
|
||||
def create_tell(config, gene_type: Type[BaseGene]):
|
||||
|
||||
def create_tell(config, gene_type: Type[BaseGene]):
|
||||
mutate = create_mutate(config, gene_type)
|
||||
distance = create_distance(config, gene_type)
|
||||
|
||||
@@ -36,7 +36,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
||||
|
||||
return state, winner, loser, elite_mask
|
||||
|
||||
|
||||
def update_species_fitness(state, fitness):
|
||||
"""
|
||||
obtain the fitness of the species by the fitness of each individual.
|
||||
@@ -51,7 +50,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
||||
|
||||
return vmap(aux_func)(jnp.arange(state.species_info.shape[0]))
|
||||
|
||||
|
||||
def stagnation(state, species_fitness):
|
||||
"""
|
||||
stagnation species.
|
||||
@@ -88,7 +86,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
||||
|
||||
return state, species_fitness
|
||||
|
||||
|
||||
def cal_spawn_numbers(state):
|
||||
"""
|
||||
decide the number of members of each species by their fitness rank.
|
||||
@@ -106,7 +103,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
||||
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
|
||||
|
||||
target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member
|
||||
# jax.debug.print("denominator: {}, spawn_number_rate: {}, target_spawn_number: {}", denominator, spawn_number_rate, target_spawn_number)
|
||||
|
||||
# Avoid too much variation of numbers in a species
|
||||
previous_size = state.species_info[:, 3].astype(jnp.int32)
|
||||
@@ -118,11 +114,11 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
||||
|
||||
# must control the sum of spawn_number to be equal to pop_size
|
||||
error = state.P - jnp.sum(spawn_number)
|
||||
spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number
|
||||
spawn_number = spawn_number.at[0].add(
|
||||
error) # add error to the first species to control the sum of spawn_number
|
||||
|
||||
return spawn_number
|
||||
|
||||
|
||||
def create_crossover_pair(state, randkey, spawn_number, fitness):
|
||||
species_size = state.species_info.shape[0]
|
||||
pop_size = fitness.shape[0]
|
||||
@@ -238,8 +234,8 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
||||
return i + 1, i2s, cn, cc, o2c
|
||||
|
||||
_, idx2specie, center_nodes, center_conns, o2c_distances = \
|
||||
jax.lax.while_loop(cond_func, body_func, (0, idx2specie, state.center_nodes, state.center_conns, o2c_distances))
|
||||
|
||||
jax.lax.while_loop(cond_func, body_func,
|
||||
(0, idx2specie, state.center_nodes, state.center_conns, o2c_distances))
|
||||
|
||||
# part 2: assign members to each species
|
||||
def cond_func(carry):
|
||||
@@ -331,7 +327,7 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
||||
species_info = species_info.at[:, 3].set(species_member_counts)
|
||||
|
||||
return state.update(
|
||||
idx2specie=idx2specie,
|
||||
idx2species=idx2specie,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
species_info=species_info,
|
||||
@@ -358,11 +354,10 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
||||
|
||||
return state
|
||||
|
||||
|
||||
return tell
|
||||
|
||||
|
||||
def argmin_with_mask(arr, mask):
|
||||
masked_arr = jnp.where(mask, arr, jnp.inf)
|
||||
min_idx = jnp.argmin(masked_arr)
|
||||
return min_idx
|
||||
return min_idx
|
||||
|
||||
Reference in New Issue
Block a user