modify NEAT package; successfully run xor example
This commit is contained in:
@@ -1,40 +1,93 @@
|
||||
from tensorneat.common import State
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from .species import SpeciesController
|
||||
from .. import BaseAlgorithm
|
||||
from .species import *
|
||||
from tensorneat.common import State
|
||||
from tensorneat.genome import BaseGenome
|
||||
|
||||
|
||||
class NEAT(BaseAlgorithm):
|
||||
def __init__(
|
||||
self,
|
||||
species: BaseSpecies,
|
||||
genome: BaseGenome,
|
||||
pop_size: int,
|
||||
species_size: int = 10,
|
||||
max_stagnation: int = 15,
|
||||
species_elitism: int = 2,
|
||||
spawn_number_change_rate: float = 0.5,
|
||||
genome_elitism: int = 2,
|
||||
survival_threshold: float = 0.2,
|
||||
min_species_size: int = 1,
|
||||
compatibility_threshold: float = 3.0,
|
||||
species_fitness_func: callable = jnp.max,
|
||||
):
|
||||
self.species = species
|
||||
self.genome = species.genome
|
||||
self.genome = genome
|
||||
self.pop_size = pop_size
|
||||
self.species_controller = SpeciesController(
|
||||
pop_size,
|
||||
species_size,
|
||||
max_stagnation,
|
||||
species_elitism,
|
||||
spawn_number_change_rate,
|
||||
genome_elitism,
|
||||
survival_threshold,
|
||||
min_species_size,
|
||||
compatibility_threshold,
|
||||
species_fitness_func,
|
||||
)
|
||||
|
||||
def setup(self, state=State()):
|
||||
state = self.species.setup(state)
|
||||
# setup state
|
||||
state = self.genome.setup(state)
|
||||
|
||||
k1, randkey = jax.random.split(state.randkey, 2)
|
||||
|
||||
# initialize the population
|
||||
initialize_keys = jax.random.split(k1, self.pop_size)
|
||||
pop_nodes, pop_conns = vmap(self.genome.initialize, in_axes=(None, 0))(
|
||||
state, initialize_keys
|
||||
)
|
||||
|
||||
state = state.register(
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
generation=jnp.float32(0),
|
||||
)
|
||||
|
||||
# initialize species state
|
||||
state = self.species_controller.setup(state, pop_nodes[0], pop_conns[0])
|
||||
|
||||
return state.update(randkey=randkey)
|
||||
|
||||
def ask(self, state):
|
||||
return state.pop_nodes, state.pop_conns
|
||||
|
||||
def tell(self, state, fitness):
|
||||
state = state.update(generation=state.generation + 1)
|
||||
|
||||
# tell fitness to species controller
|
||||
state, winner, loser, elite_mask = self.species_controller.update_species(
|
||||
state,
|
||||
fitness,
|
||||
)
|
||||
|
||||
# create next population
|
||||
state = self._create_next_generation(state, winner, loser, elite_mask)
|
||||
|
||||
# speciate the next population
|
||||
state = self.species_controller.speciate(state, self.genome.execute_distance)
|
||||
|
||||
return state
|
||||
|
||||
def ask(self, state: State):
|
||||
return self.species.ask(state)
|
||||
|
||||
def tell(self, state: State, fitness):
|
||||
return self.species.tell(state, fitness)
|
||||
|
||||
def transform(self, state, individual):
|
||||
"""transform the genome into a neural network"""
|
||||
nodes, conns = individual
|
||||
return self.genome.transform(state, nodes, conns)
|
||||
|
||||
def restore(self, state, transformed):
|
||||
return self.genome.restore(state, transformed)
|
||||
|
||||
def forward(self, state, transformed, inputs):
|
||||
return self.genome.forward(state, transformed, inputs)
|
||||
|
||||
def update_by_batch(self, state, batch_input, transformed):
|
||||
return self.genome.update_by_batch(state, batch_input, transformed)
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
return self.genome.num_inputs
|
||||
@@ -43,13 +96,70 @@ class NEAT(BaseAlgorithm):
|
||||
def num_outputs(self):
|
||||
return self.genome.num_outputs
|
||||
|
||||
@property
|
||||
def pop_size(self):
|
||||
return self.species.pop_size
|
||||
def _create_next_generation(self, state, winner, loser, elite_mask):
|
||||
|
||||
def member_count(self, state: State):
|
||||
return state.member_count
|
||||
# find next node key for mutation
|
||||
all_nodes_keys = state.pop_nodes[:, :, 0]
|
||||
max_node_key = jnp.max(
|
||||
all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0
|
||||
)
|
||||
next_node_key = max_node_key + 1
|
||||
new_node_keys = jnp.arange(self.pop_size) + next_node_key
|
||||
|
||||
def generation(self, state: State):
|
||||
# to analysis the algorithm
|
||||
return state.generation
|
||||
# prepare random keys
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
crossover_randkeys = jax.random.split(k1, self.pop_size)
|
||||
mutate_randkeys = jax.random.split(k2, self.pop_size)
|
||||
|
||||
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
|
||||
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
|
||||
|
||||
# batch crossover
|
||||
n_nodes, n_conns = vmap(
|
||||
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
|
||||
)(
|
||||
state, crossover_randkeys, wpn, wpc, lpn, lpc
|
||||
) # new_nodes, new_conns
|
||||
|
||||
# batch mutation
|
||||
m_n_nodes, m_n_conns = vmap(
|
||||
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
|
||||
)(
|
||||
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
|
||||
) # mutated_new_nodes, mutated_new_conns
|
||||
|
||||
# elitism don't mutate
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
|
||||
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
|
||||
|
||||
return state.update(
|
||||
randkey=randkey,
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
)
|
||||
|
||||
def show_details(self, state, fitness):
|
||||
member_count = jax.device_get(state.species.member_count)
|
||||
species_sizes = [int(i) for i in member_count if i > 0]
|
||||
|
||||
pop_nodes, pop_conns = jax.device_get([state.pop_nodes, state.pop_conns])
|
||||
nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1) # (P,)
|
||||
conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1) # (P,)
|
||||
|
||||
max_node_cnt, min_node_cnt, mean_node_cnt = (
|
||||
max(nodes_cnt),
|
||||
min(nodes_cnt),
|
||||
np.mean(nodes_cnt),
|
||||
)
|
||||
|
||||
max_conn_cnt, min_conn_cnt, mean_conn_cnt = (
|
||||
max(conns_cnt),
|
||||
min(conns_cnt),
|
||||
np.mean(conns_cnt),
|
||||
)
|
||||
|
||||
print(
|
||||
f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n",
|
||||
f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n",
|
||||
f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user