fix bugs
This commit is contained in:
@@ -10,6 +10,9 @@ class BaseSpecies(StatefulBaseClass):
|
||||
def ask(self, state: State):
|
||||
raise NotImplementedError
|
||||
|
||||
def tell(self, state: State, fitness):
|
||||
raise NotImplementedError
|
||||
|
||||
def update_species(self, state, fitness):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -113,12 +113,23 @@ class DefaultSpecies(BaseSpecies):
|
||||
idx2species=idx2species,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
next_species_key=jnp.array(1), # 0 is reserved for the first species
|
||||
next_species_key=jnp.float32(1), # 0 is reserved for the first species
|
||||
generation=jnp.float32(0),
|
||||
)
|
||||
|
||||
def ask(self, state):
|
||||
return state.pop_nodes, state.pop_conns
|
||||
|
||||
def tell(self, state, fitness):
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
|
||||
state = state.update(generation=state.generation + 1, randkey=randkey)
|
||||
state, winner, loser, elite_mask = self.update_species(state, fitness)
|
||||
state = self.create_next_generation(state, winner, loser, elite_mask)
|
||||
state = self.speciate(state)
|
||||
|
||||
return state
|
||||
|
||||
def update_species(self, state, fitness):
|
||||
# update the fitness of each species
|
||||
state, species_fitness = self.update_species_fitness(state, fitness)
|
||||
@@ -619,3 +630,43 @@ class DefaultSpecies(BaseSpecies):
|
||||
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
|
||||
|
||||
return val
|
||||
|
||||
def create_next_generation(self, state, winner, loser, elite_mask):
|
||||
|
||||
# find next node key
|
||||
all_nodes_keys = state.pop_nodes[:, :, 0]
|
||||
max_node_key = jnp.max(all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0)
|
||||
next_node_key = max_node_key + 1
|
||||
new_node_keys = jnp.arange(self.pop_size) + next_node_key
|
||||
|
||||
# prepare random keys
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
crossover_randkeys = jax.random.split(k1, self.pop_size)
|
||||
mutate_randkeys = jax.random.split(k2, self.pop_size)
|
||||
|
||||
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
|
||||
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
|
||||
|
||||
# batch crossover
|
||||
n_nodes, n_conns = jax.vmap(
|
||||
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
|
||||
)(
|
||||
state, crossover_randkeys, wpn, wpc, lpn, lpc
|
||||
) # new_nodes, new_conns
|
||||
|
||||
# batch mutation
|
||||
m_n_nodes, m_n_conns = jax.vmap(
|
||||
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
|
||||
)(
|
||||
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
|
||||
) # mutated_new_nodes, mutated_new_conns
|
||||
|
||||
# elitism don't mutate
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
|
||||
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
|
||||
|
||||
return state.update(
|
||||
randkey=randkey,
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
)
|
||||
Reference in New Issue
Block a user