This commit is contained in:
wls2002
2024-06-20 16:32:52 +08:00
parent 9f72813c35
commit 075460f896
17 changed files with 224 additions and 140 deletions

View File

@@ -10,6 +10,9 @@ class BaseSpecies(StatefulBaseClass):
def ask(self, state: State):
raise NotImplementedError
def tell(self, state: State, fitness):
raise NotImplementedError
def update_species(self, state, fitness):
raise NotImplementedError

View File

@@ -113,12 +113,23 @@ class DefaultSpecies(BaseSpecies):
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
next_species_key=jnp.array(1), # 0 is reserved for the first species
next_species_key=jnp.float32(1), # 0 is reserved for the first species
generation=jnp.float32(0),
)
def ask(self, state):
return state.pop_nodes, state.pop_conns
def tell(self, state, fitness):
k1, k2, randkey = jax.random.split(state.randkey, 3)
state = state.update(generation=state.generation + 1, randkey=randkey)
state, winner, loser, elite_mask = self.update_species(state, fitness)
state = self.create_next_generation(state, winner, loser, elite_mask)
state = self.speciate(state)
return state
def update_species(self, state, fitness):
# update the fitness of each species
state, species_fitness = self.update_species_fitness(state, fitness)
@@ -619,3 +630,43 @@ class DefaultSpecies(BaseSpecies):
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
return val
def create_next_generation(self, state, winner, loser, elite_mask):
# find next node key
all_nodes_keys = state.pop_nodes[:, :, 0]
max_node_key = jnp.max(all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0)
next_node_key = max_node_key + 1
new_node_keys = jnp.arange(self.pop_size) + next_node_key
# prepare random keys
k1, k2, randkey = jax.random.split(state.randkey, 3)
crossover_randkeys = jax.random.split(k1, self.pop_size)
mutate_randkeys = jax.random.split(k2, self.pop_size)
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
# batch crossover
n_nodes, n_conns = jax.vmap(
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
)(
state, crossover_randkeys, wpn, wpc, lpn, lpc
) # new_nodes, new_conns
# batch mutation
m_n_nodes, m_n_conns = jax.vmap(
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
)(
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
) # mutated_new_nodes, mutated_new_conns
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
return state.update(
randkey=randkey,
pop_nodes=pop_nodes,
pop_conns=pop_conns,
)