finish all refactoring
This commit is contained in:
@@ -2,9 +2,10 @@ import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import State, rank_elements, argmin_with_mask, fetch_first
|
||||
from ..genome import BaseGenome
|
||||
from .base import BaseSpecies
|
||||
|
||||
|
||||
class DefaultSpecies:
|
||||
class DefaultSpecies(BaseSpecies):
|
||||
|
||||
def __init__(self,
|
||||
genome: BaseGenome,
|
||||
@@ -18,9 +19,8 @@ class DefaultSpecies:
|
||||
genome_elitism: int = 2,
|
||||
survival_threshold: float = 0.2,
|
||||
min_species_size: int = 1,
|
||||
compatibility_threshold: float = 3.5
|
||||
compatibility_threshold: float = 3.
|
||||
):
|
||||
|
||||
self.genome = genome
|
||||
self.pop_size = pop_size
|
||||
self.species_size = species_size
|
||||
@@ -59,8 +59,12 @@ class DefaultSpecies:
|
||||
center_nodes = center_nodes.at[0].set(pop_nodes[0])
|
||||
center_conns = center_conns.at[0].set(pop_conns[0])
|
||||
|
||||
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
|
||||
|
||||
return State(
|
||||
randkey=randkey,
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
species_keys=species_keys,
|
||||
best_fitness=best_fitness,
|
||||
last_improved=last_improved,
|
||||
@@ -68,7 +72,7 @@ class DefaultSpecies:
|
||||
idx2species=idx2species,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
next_species_key=1, # 0 is reserved for the first species
|
||||
next_species_key=jnp.array(1), # 0 is reserved for the first species
|
||||
)
|
||||
|
||||
def ask(self, state):
|
||||
@@ -99,7 +103,7 @@ class DefaultSpecies:
|
||||
# crossover info
|
||||
winner, loser, elite_mask = self.create_crossover_pair(state, k1, spawn_number, fitness)
|
||||
|
||||
return state(randkey=k2), winner, loser, elite_mask
|
||||
return state.update(randkey=k2), winner, loser, elite_mask
|
||||
|
||||
def update_species_fitness(self, state, fitness):
|
||||
"""
|
||||
@@ -156,17 +160,17 @@ class DefaultSpecies:
|
||||
jnp.nan, # last_improved
|
||||
jnp.nan, # member_count
|
||||
-jnp.inf, # species_fitness
|
||||
jnp.full_like(center_nodes[idx], jnp.nan), # center_nodes
|
||||
jnp.full_like(center_conns[idx], jnp.nan), # center_conns
|
||||
jnp.full_like(state.center_nodes[idx], jnp.nan), # center_nodes
|
||||
jnp.full_like(state.center_conns[idx], jnp.nan), # center_conns
|
||||
), # stagnation species
|
||||
lambda: (
|
||||
species_keys[idx],
|
||||
state.species_keys[idx],
|
||||
best_fitness[idx],
|
||||
last_improved[idx],
|
||||
state.member_count[idx],
|
||||
species_fitness[idx],
|
||||
center_nodes[idx],
|
||||
center_conns[idx]
|
||||
state.center_nodes[idx],
|
||||
state.center_conns[idx]
|
||||
) # not stagnation species
|
||||
)
|
||||
|
||||
@@ -216,7 +220,7 @@ class DefaultSpecies:
|
||||
spawn_number = spawn_number.astype(jnp.int32)
|
||||
|
||||
# must control the sum of spawn_number to be equal to pop_size
|
||||
error = state.P - jnp.sum(spawn_number)
|
||||
error = self.pop_size - jnp.sum(spawn_number)
|
||||
|
||||
# add error to the first species to control the sum of spawn_number
|
||||
spawn_number = spawn_number.at[0].add(error)
|
||||
@@ -287,14 +291,14 @@ class DefaultSpecies:
|
||||
def body_func(carry):
|
||||
i, i2s, cns, ccs, o2c = carry
|
||||
|
||||
distances = o2p_distance_func(cns, ccs, state.pop_nodes, state.pop_conns)
|
||||
distances = o2p_distance_func(cns[i], ccs[i], state.pop_nodes, state.pop_conns)
|
||||
|
||||
# find the closest one
|
||||
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
|
||||
|
||||
i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
|
||||
cns = cns.set(i, state.pop_nodes[closest_idx])
|
||||
ccs = ccs.set(i, state.pop_conns[closest_idx])
|
||||
i2s = i2s.at[closest_idx].set(state.species_keys[i])
|
||||
cns = cns.at[i].set(state.pop_nodes[closest_idx])
|
||||
ccs = ccs.at[i].set(state.pop_conns[closest_idx])
|
||||
|
||||
# the genome with closest_idx will become the new center, thus its distance to center is 0.
|
||||
o2c = o2c.at[closest_idx].set(0)
|
||||
@@ -346,8 +350,8 @@ class DefaultSpecies:
|
||||
o2c = o2c.at[idx].set(0)
|
||||
|
||||
# update center genomes
|
||||
cns = cns.set(i, state.pop_nodes[idx])
|
||||
ccs = ccs.set(i, state.pop_conns[idx])
|
||||
cns = cns.at[i].set(state.pop_nodes[idx])
|
||||
ccs = ccs.at[i].set(state.pop_conns[idx])
|
||||
|
||||
# find the members for the new species
|
||||
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
|
||||
@@ -384,7 +388,7 @@ class DefaultSpecies:
|
||||
_, idx2species, center_nodes, center_conns, species_keys, _, next_species_key = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(0, state.idx2species, state.center_nodes, center_conns, state.species_info.species_keys, o2c_distances,
|
||||
(0, state.idx2species, center_nodes, center_conns, state.species_keys, o2c_distances,
|
||||
state.next_species_key)
|
||||
)
|
||||
|
||||
@@ -401,8 +405,8 @@ class DefaultSpecies:
|
||||
def count_members(idx):
|
||||
return jax.lax.cond(
|
||||
jnp.isnan(species_keys[idx]), # if the species is not existing
|
||||
lambda _: jnp.nan, # nan
|
||||
lambda _: jnp.sum(idx2species == species_keys[idx], dtype=jnp.float32) # count members
|
||||
lambda: jnp.nan, # nan
|
||||
lambda: jnp.sum(idx2species == species_keys[idx], dtype=jnp.float32) # count members
|
||||
)
|
||||
|
||||
member_count = jax.vmap(count_members)(self.species_arange)
|
||||
@@ -422,7 +426,8 @@ class DefaultSpecies:
|
||||
"""
|
||||
The distance between two genomes
|
||||
"""
|
||||
return self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2)
|
||||
d = self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2)
|
||||
return d
|
||||
|
||||
def node_distance(self, nodes1, nodes2):
|
||||
"""
|
||||
@@ -494,18 +499,18 @@ def initialize_population(pop_size, genome):
|
||||
o_nodes[input_idx, 0] = genome.input_idx
|
||||
o_nodes[output_idx, 0] = genome.output_idx
|
||||
o_nodes[new_node_key, 0] = new_node_key # one hidden node
|
||||
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_attrs()
|
||||
o_nodes[new_node_key, 1:] = genome.node_gene.new_attrs() # one hidden node
|
||||
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_custom_attrs()
|
||||
o_nodes[new_node_key, 1:] = genome.node_gene.new_custom_attrs() # one hidden node
|
||||
|
||||
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] # input nodes to hidden
|
||||
o_conns[input_idx, 0:2] = input_conns # in key, out key
|
||||
o_conns[input_idx, 2] = True # enabled
|
||||
o_conns[input_idx, 3:] = genome.conn_gene.new_conn_attrs()
|
||||
o_conns[input_idx, 3:] = genome.conn_gene.new_custom_attrs()
|
||||
|
||||
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] # hidden to output nodes
|
||||
o_conns[output_idx, 0:2] = output_conns # in key, out key
|
||||
o_conns[output_idx, 2] = True # enabled
|
||||
o_conns[output_idx, 3:] = genome.conn_gene.new_conn_attrs()
|
||||
o_conns[output_idx, 3:] = genome.conn_gene.new_custom_attrs()
|
||||
|
||||
# repeat origin genome for P times to create population
|
||||
pop_nodes = np.tile(o_nodes, (pop_size, 1, 1))
|
||||
|
||||
Reference in New Issue
Block a user