modify NEAT package; successfully run xor example

This commit is contained in:
root
2024-07-11 10:10:16 +08:00
parent 52d5f046d3
commit 4a631f9464
14 changed files with 420 additions and 502 deletions

View File

@@ -14,13 +14,11 @@ class BaseAlgorithm(StatefulBaseClass):
"""transform the genome into a neural network"""
raise NotImplementedError
def restore(self, state, transformed):
raise NotImplementedError
def forward(self, state, transformed, inputs):
raise NotImplementedError
def update_by_batch(self, state, batch_input, transformed):
def show_details(self, state: State, fitness):
"""Visualize the running details of the algorithm"""
raise NotImplementedError
@property
@@ -30,15 +28,3 @@ class BaseAlgorithm(StatefulBaseClass):
@property
def num_outputs(self):
raise NotImplementedError
@property
def pop_size(self):
raise NotImplementedError
def member_count(self, state: State):
# to analysis the species
raise NotImplementedError
def generation(self, state: State):
# to analysis the algorithm
raise NotImplementedError

View File

@@ -1,40 +1,93 @@
from tensorneat.common import State
import jax
from jax import vmap, numpy as jnp
import numpy as np
from .species import SpeciesController
from .. import BaseAlgorithm
from .species import *
from tensorneat.common import State
from tensorneat.genome import BaseGenome
class NEAT(BaseAlgorithm):
def __init__(
self,
species: BaseSpecies,
genome: BaseGenome,
pop_size: int,
species_size: int = 10,
max_stagnation: int = 15,
species_elitism: int = 2,
spawn_number_change_rate: float = 0.5,
genome_elitism: int = 2,
survival_threshold: float = 0.2,
min_species_size: int = 1,
compatibility_threshold: float = 3.0,
species_fitness_func: callable = jnp.max,
):
self.species = species
self.genome = species.genome
self.genome = genome
self.pop_size = pop_size
self.species_controller = SpeciesController(
pop_size,
species_size,
max_stagnation,
species_elitism,
spawn_number_change_rate,
genome_elitism,
survival_threshold,
min_species_size,
compatibility_threshold,
species_fitness_func,
)
def setup(self, state=State()):
state = self.species.setup(state)
# setup state
state = self.genome.setup(state)
k1, randkey = jax.random.split(state.randkey, 2)
# initialize the population
initialize_keys = jax.random.split(k1, self.pop_size)
pop_nodes, pop_conns = vmap(self.genome.initialize, in_axes=(None, 0))(
state, initialize_keys
)
state = state.register(
pop_nodes=pop_nodes,
pop_conns=pop_conns,
generation=jnp.float32(0),
)
# initialize species state
state = self.species_controller.setup(state, pop_nodes[0], pop_conns[0])
return state.update(randkey=randkey)
def ask(self, state):
return state.pop_nodes, state.pop_conns
def tell(self, state, fitness):
state = state.update(generation=state.generation + 1)
# tell fitness to species controller
state, winner, loser, elite_mask = self.species_controller.update_species(
state,
fitness,
)
# create next population
state = self._create_next_generation(state, winner, loser, elite_mask)
# speciate the next population
state = self.species_controller.speciate(state, self.genome.execute_distance)
return state
def ask(self, state: State):
return self.species.ask(state)
def tell(self, state: State, fitness):
return self.species.tell(state, fitness)
def transform(self, state, individual):
"""transform the genome into a neural network"""
nodes, conns = individual
return self.genome.transform(state, nodes, conns)
def restore(self, state, transformed):
return self.genome.restore(state, transformed)
def forward(self, state, transformed, inputs):
return self.genome.forward(state, transformed, inputs)
def update_by_batch(self, state, batch_input, transformed):
return self.genome.update_by_batch(state, batch_input, transformed)
@property
def num_inputs(self):
return self.genome.num_inputs
@@ -43,13 +96,70 @@ class NEAT(BaseAlgorithm):
def num_outputs(self):
return self.genome.num_outputs
@property
def pop_size(self):
return self.species.pop_size
def _create_next_generation(self, state, winner, loser, elite_mask):
def member_count(self, state: State):
return state.member_count
# find next node key for mutation
all_nodes_keys = state.pop_nodes[:, :, 0]
max_node_key = jnp.max(
all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0
)
next_node_key = max_node_key + 1
new_node_keys = jnp.arange(self.pop_size) + next_node_key
def generation(self, state: State):
# to analysis the algorithm
return state.generation
# prepare random keys
k1, k2, randkey = jax.random.split(state.randkey, 3)
crossover_randkeys = jax.random.split(k1, self.pop_size)
mutate_randkeys = jax.random.split(k2, self.pop_size)
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
# batch crossover
n_nodes, n_conns = vmap(
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
)(
state, crossover_randkeys, wpn, wpc, lpn, lpc
) # new_nodes, new_conns
# batch mutation
m_n_nodes, m_n_conns = vmap(
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
)(
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
) # mutated_new_nodes, mutated_new_conns
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
return state.update(
randkey=randkey,
pop_nodes=pop_nodes,
pop_conns=pop_conns,
)
def show_details(self, state, fitness):
member_count = jax.device_get(state.species.member_count)
species_sizes = [int(i) for i in member_count if i > 0]
pop_nodes, pop_conns = jax.device_get([state.pop_nodes, state.pop_conns])
nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1) # (P,)
conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1) # (P,)
max_node_cnt, min_node_cnt, mean_node_cnt = (
max(nodes_cnt),
min(nodes_cnt),
np.mean(nodes_cnt),
)
max_conn_cnt, min_conn_cnt, mean_conn_cnt = (
max(conns_cnt),
min(conns_cnt),
np.mean(conns_cnt),
)
print(
f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n",
f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n",
f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
)

View File

@@ -1,54 +1,35 @@
import jax, jax.numpy as jnp
from typing import Callable
import jax
from jax import vmap, numpy as jnp
import numpy as np
from .base import BaseSpecies
from tensorneat.common import (
State,
StatefulBaseClass,
rank_elements,
argmin_with_mask,
fetch_first,
)
from tensorneat.genome.utils import (
extract_conn_attrs,
extract_node_attrs,
)
from tensorneat.genome import BaseGenome
"""
Core procedures of NEAT algorithm, contains the following steps:
1. Update the fitness of each species;
2. Decide which species will be stagnation;
3. Decide the number of members of each species in the next generation;
4. Choice the crossover pair for each species;
5. Divided the whole new population into different species;
This class use tensor operation to imitate the behavior of NEAT algorithm which implemented in NEAT-python.
The code may be hard to understand. Fortunately, we don't need to overwrite it in most cases.
"""
class DefaultSpecies(BaseSpecies):
class SpeciesController(StatefulBaseClass):
def __init__(
self,
genome: BaseGenome,
pop_size,
species_size,
compatibility_disjoint: float = 1.0,
compatibility_weight: float = 0.4,
max_stagnation: int = 15,
species_elitism: int = 2,
spawn_number_change_rate: float = 0.5,
genome_elitism: int = 2,
survival_threshold: float = 0.2,
min_species_size: int = 1,
compatibility_threshold: float = 3.0,
max_stagnation,
species_elitism,
spawn_number_change_rate,
genome_elitism,
survival_threshold,
min_species_size,
compatibility_threshold,
species_fitness_func,
):
self.genome = genome
self.pop_size = pop_size
self.species_size = species_size
self.compatibility_disjoint = compatibility_disjoint
self.compatibility_weight = compatibility_weight
self.species_arange = np.arange(self.species_size)
self.max_stagnation = max_stagnation
self.species_elitism = species_elitism
self.spawn_number_change_rate = spawn_number_change_rate
@@ -56,42 +37,33 @@ class DefaultSpecies(BaseSpecies):
self.survival_threshold = survival_threshold
self.min_species_size = min_species_size
self.compatibility_threshold = compatibility_threshold
self.species_fitness_func = species_fitness_func
self.species_arange = jnp.arange(self.species_size)
def setup(self, state, first_nodes, first_conns):
# the unique index (primary key) for each species
species_keys = jnp.full((self.species_size,), jnp.nan)
def setup(self, state=State()):
state = self.genome.setup(state)
k1, randkey = jax.random.split(state.randkey, 2)
# the best fitness of each species
best_fitness = jnp.full((self.species_size,), jnp.nan)
# initialize the population
initialize_keys = jax.random.split(randkey, self.pop_size)
pop_nodes, pop_conns = jax.vmap(self.genome.initialize, in_axes=(None, 0))(
state, initialize_keys
)
# the last 1 that the species improved
last_improved = jnp.full((self.species_size,), jnp.nan)
species_keys = jnp.full(
(self.species_size,), jnp.nan
) # the unique index (primary key) for each species
best_fitness = jnp.full(
(self.species_size,), jnp.nan
) # the best fitness of each species
last_improved = jnp.full(
(self.species_size,), jnp.nan
) # the last 1 that the species improved
member_count = jnp.full(
(self.species_size,), jnp.nan
) # the number of members of each species
idx2species = jnp.zeros(self.pop_size) # the species index of each individual
# the number of members of each species
member_count = jnp.full((self.species_size,), jnp.nan)
# the species index of each individual
idx2species = jnp.zeros(self.pop_size)
# nodes for each center genome of each species
center_nodes = jnp.full(
(self.species_size, self.genome.max_nodes, self.genome.node_gene.length),
(self.species_size, *first_nodes.shape),
jnp.nan,
)
# connections for each center genome of each species
center_conns = jnp.full(
(self.species_size, self.genome.max_conns, self.genome.conn_gene.length),
(self.species_size, *first_conns.shape),
jnp.nan,
)
@@ -99,16 +71,10 @@ class DefaultSpecies(BaseSpecies):
best_fitness = best_fitness.at[0].set(-jnp.inf)
last_improved = last_improved.at[0].set(0)
member_count = member_count.at[0].set(self.pop_size)
center_nodes = center_nodes.at[0].set(pop_nodes[0])
center_conns = center_conns.at[0].set(pop_conns[0])
center_nodes = center_nodes.at[0].set(first_nodes)
center_conns = center_conns.at[0].set(first_conns)
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
state = state.update(randkey=randkey)
return state.register(
pop_nodes=pop_nodes,
pop_conns=pop_conns,
species_state = State(
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
@@ -117,53 +83,50 @@ class DefaultSpecies(BaseSpecies):
center_nodes=center_nodes,
center_conns=center_conns,
next_species_key=jnp.float32(1), # 0 is reserved for the first species
generation=jnp.float32(0),
)
def ask(self, state):
return state.pop_nodes, state.pop_conns
def tell(self, state, fitness):
k1, k2, randkey = jax.random.split(state.randkey, 3)
state = state.update(generation=state.generation + 1, randkey=randkey)
state, winner, loser, elite_mask = self.update_species(state, fitness)
state = self.create_next_generation(state, winner, loser, elite_mask)
state = self.speciate(state)
return state
return state.register(species=species_state)
def update_species(self, state, fitness):
species_state = state.species
# update the fitness of each species
state, species_fitness = self.update_species_fitness(state, fitness)
species_fitness = self._update_species_fitness(species_state, fitness)
# stagnation species
state, species_fitness = self.stagnation(state, species_fitness)
species_state, species_fitness = self._stagnation(
species_state, species_fitness, state.generation
)
# sort species_info by their fitness. (also push nan to the end)
sort_indices = jnp.argsort(species_fitness)[::-1]
sort_indices = jnp.argsort(species_fitness)[::-1] # fitness from high to low
state = state.update(
species_keys=state.species_keys[sort_indices],
best_fitness=state.best_fitness[sort_indices],
last_improved=state.last_improved[sort_indices],
member_count=state.member_count[sort_indices],
center_nodes=state.center_nodes[sort_indices],
center_conns=state.center_conns[sort_indices],
species_state = species_state.update(
species_keys=species_state.species_keys[sort_indices],
best_fitness=species_state.best_fitness[sort_indices],
last_improved=species_state.last_improved[sort_indices],
member_count=species_state.member_count[sort_indices],
center_nodes=species_state.center_nodes[sort_indices],
center_conns=species_state.center_conns[sort_indices],
)
# decide the number of members of each species by their fitness
state, spawn_number = self.cal_spawn_numbers(state)
spawn_number = self._cal_spawn_numbers(species_state)
k1, k2 = jax.random.split(state.randkey)
# crossover info
state, winner, loser, elite_mask = self.create_crossover_pair(
state, spawn_number, fitness
winner, loser, elite_mask = self._create_crossover_pair(
species_state, k1, spawn_number, fitness
)
return state.update(randkey=k2), winner, loser, elite_mask
return (
state.update(randkey=k2, species=species_state),
winner,
loser,
elite_mask,
)
def update_species_fitness(self, state, fitness):
def _update_species_fitness(self, species_state, fitness):
"""
obtain the fitness of the species by the fitness of each individual.
use max criterion.
@@ -171,14 +134,16 @@ class DefaultSpecies(BaseSpecies):
def aux_func(idx):
s_fitness = jnp.where(
state.idx2species == state.species_keys[idx], fitness, -jnp.inf
species_state.idx2species == species_state.species_keys[idx],
fitness,
-jnp.inf,
)
val = jnp.max(s_fitness)
val = self.species_fitness_func(s_fitness)
return val
return state, jax.vmap(aux_func)(self.species_arange)
return vmap(aux_func)(self.species_arange)
def stagnation(self, state, species_fitness):
def _stagnation(self, species_state, species_fitness, generation):
"""
stagnation species.
those species whose fitness is not better than the best fitness of the species for a long time will be stagnation.
@@ -187,28 +152,36 @@ class DefaultSpecies(BaseSpecies):
def check_stagnation(idx):
# determine whether the species stagnation
st = (
species_fitness[idx] <= state.best_fitness[idx]
) & ( # not better than the best fitness of the species
state.generation - state.last_improved[idx] > self.max_stagnation
) # for a long time
# not better than the best fitness of the species
# for a long time
st = (species_fitness[idx] <= species_state.best_fitness[idx]) & (
generation - species_state.last_improved[idx] > self.max_stagnation
)
# update last_improved and best_fitness
# whether better than the best fitness of the species
li, bf = jax.lax.cond(
species_fitness[idx] > state.best_fitness[idx],
lambda: (state.generation, species_fitness[idx]), # update
species_fitness[idx] > species_state.best_fitness[idx],
lambda: (generation, species_fitness[idx]), # update
lambda: (
state.last_improved[idx],
state.best_fitness[idx],
species_state.last_improved[idx],
species_state.best_fitness[idx],
), # not update
)
return st, bf, li
spe_st, best_fitness, last_improved = jax.vmap(check_stagnation)(
spe_st, best_fitness, last_improved = vmap(check_stagnation)(
self.species_arange
)
# update species state
species_state = species_state.update(
best_fitness=best_fitness,
last_improved=last_improved,
)
# elite species will not be stagnation
species_rank = rank_elements(species_fitness)
spe_st = jnp.where(
@@ -224,18 +197,18 @@ class DefaultSpecies(BaseSpecies):
jnp.nan, # best_fitness
jnp.nan, # last_improved
jnp.nan, # member_count
jnp.full_like(species_state.center_nodes[idx], jnp.nan),
jnp.full_like(species_state.center_conns[idx], jnp.nan),
-jnp.inf, # species_fitness
jnp.full_like(state.center_nodes[idx], jnp.nan), # center_nodes
jnp.full_like(state.center_conns[idx], jnp.nan), # center_conns
), # stagnation species
lambda: (
state.species_keys[idx],
best_fitness[idx],
last_improved[idx],
state.member_count[idx],
species_state.species_keys[idx],
species_state.best_fitness[idx],
species_state.last_improved[idx],
species_state.member_count[idx],
species_state.center_nodes[idx],
species_state.center_conns[idx],
species_fitness[idx],
state.center_nodes[idx],
state.center_conns[idx],
), # not stagnation species
)
@@ -244,13 +217,13 @@ class DefaultSpecies(BaseSpecies):
best_fitness,
last_improved,
member_count,
species_fitness,
center_nodes,
center_conns,
) = jax.vmap(update_func)(self.species_arange)
species_fitness,
) = vmap(update_func)(self.species_arange)
return (
state.update(
species_state.update(
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
@@ -261,7 +234,7 @@ class DefaultSpecies(BaseSpecies):
species_fitness,
)
def cal_spawn_numbers(self, state):
def _cal_spawn_numbers(self, species_state):
"""
decide the number of members of each species by their fitness rank.
the species with higher fitness will have more members
@@ -269,7 +242,7 @@ class DefaultSpecies(BaseSpecies):
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
"""
species_keys = state.species_keys
species_keys = species_state.species_keys
is_species_valid = ~jnp.isnan(species_keys)
valid_species_num = jnp.sum(is_species_valid)
@@ -288,7 +261,7 @@ class DefaultSpecies(BaseSpecies):
) # calculate member
# Avoid too much variation of numbers for a species
previous_size = state.member_count
previous_size = species_state.member_count
spawn_number = (
previous_size
+ (target_spawn_number - previous_size) * self.spawn_number_change_rate
@@ -301,14 +274,17 @@ class DefaultSpecies(BaseSpecies):
# add error to the first species to control the sum of spawn_number
spawn_number = spawn_number.at[0].add(error)
return state, spawn_number
return spawn_number
def create_crossover_pair(self, state, spawn_number, fitness):
def _create_crossover_pair(self, species_state, randkey, spawn_number, fitness):
s_idx = self.species_arange
p_idx = jnp.arange(self.pop_size)
def aux_func(key, idx):
members = state.idx2species == state.species_keys[idx]
# choose parents from the in the same species
# key -> randkey, idx -> the idx of current species
members = species_state.idx2species == species_state.species_keys[idx]
members_num = jnp.sum(members)
members_fitness = jnp.where(members, fitness, -jnp.inf)
@@ -333,11 +309,16 @@ class DefaultSpecies(BaseSpecies):
elite = jnp.where(p_idx < self.genome_elitism, True, False)
return fa, ma, elite
randkey_, randkey = jax.random.split(state.randkey)
fas, mas, elites = jax.vmap(aux_func)(
jax.random.split(randkey_, self.species_size), s_idx
# choose parents to crossover in each species
# fas, mas, elites: (self.species_size, self.pop_size)
# fas -> father indices, mas -> mother indices, elites -> whether elite or not
fas, mas, elites = vmap(aux_func)(
jax.random.split(randkey, self.species_size), s_idx
)
# merge choosen parents from each species into one array
# winner, loser, elite_mask: (self.pop_size)
# winner -> winner indices, loser -> loser indices, elite_mask -> whether elite or not
spawn_number_cum = jnp.cumsum(spawn_number)
def aux_func(idx):
@@ -351,18 +332,18 @@ class DefaultSpecies(BaseSpecies):
elites[loc, idx_in_species],
)
part1, part2, elite_mask = jax.vmap(aux_func)(p_idx)
part1, part2, elite_mask = vmap(aux_func)(p_idx)
is_part1_win = fitness[part1] >= fitness[part2]
winner = jnp.where(is_part1_win, part1, part2)
loser = jnp.where(is_part1_win, part2, part1)
return state.update(randkey=randkey), winner, loser, elite_mask
return winner, loser, elite_mask
def speciate(self, state):
def speciate(self, state, genome_distance_func: Callable):
# prepare distance functions
o2p_distance_func = jax.vmap(
self.distance, in_axes=(None, None, None, 0, 0)
o2p_distance_func = vmap(
genome_distance_func, in_axes=(None, None, None, 0, 0)
) # one to population
# idx to specie key
@@ -379,7 +360,7 @@ class DefaultSpecies(BaseSpecies):
i, i2s, cns, ccs, o2c = carry
return (i < self.species_size) & (
~jnp.isnan(state.species_keys[i])
~jnp.isnan(state.species.species_keys[i])
) # current species is existing
def body_func(carry):
@@ -392,7 +373,7 @@ class DefaultSpecies(BaseSpecies):
# find the closest one
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
i2s = i2s.at[closest_idx].set(state.species_keys[i])
i2s = i2s.at[closest_idx].set(state.species.species_keys[i])
cns = cns.at[i].set(state.pop_nodes[closest_idx])
ccs = ccs.at[i].set(state.pop_conns[closest_idx])
@@ -404,13 +385,21 @@ class DefaultSpecies(BaseSpecies):
_, idx2species, center_nodes, center_conns, o2c_distances = jax.lax.while_loop(
cond_func,
body_func,
(0, idx2species, state.center_nodes, state.center_conns, o2c_distances),
(
0,
idx2species,
state.species.center_nodes,
state.species.center_conns,
o2c_distances,
),
)
state = state.update(
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
species=state.species.update(
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
),
)
# part 2: assign members to each species
@@ -500,12 +489,12 @@ class DefaultSpecies(BaseSpecies):
body_func,
(
0,
state.idx2species,
state.species.idx2species,
center_nodes,
center_conns,
state.species_keys,
state.species.species_keys,
o2c_distances,
state.next_species_key,
state.species.next_species_key,
),
)
@@ -514,10 +503,10 @@ class DefaultSpecies(BaseSpecies):
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
# complete info of species which is created in this generation
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.best_fitness)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.best_fitness)
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species.best_fitness)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species.best_fitness)
last_improved = jnp.where(
new_created_mask, state.generation, state.last_improved
new_created_mask, state.generation, state.species.last_improved
)
# update members count
@@ -530,9 +519,9 @@ class DefaultSpecies(BaseSpecies):
), # count members
)
member_count = jax.vmap(count_members)(self.species_arange)
member_count = vmap(count_members)(self.species_arange)
return state.update(
species_state = state.species.update(
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
@@ -543,135 +532,6 @@ class DefaultSpecies(BaseSpecies):
next_species_key=next_species_key,
)
def distance(self, state, nodes1, conns1, nodes2, conns2):
"""
The distance between two genomes
"""
d = self.node_distance(state, nodes1, nodes2) + self.conn_distance(
state, conns1, conns2
)
return d
def node_distance(self, state, nodes1, nodes2):
"""
The distance of the nodes part for two genomes
"""
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
# align homologous nodes
# this process is similar to np.intersect1d.
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0]
sorted_indices = jnp.argsort(keys, axis=0)
nodes = nodes[sorted_indices]
nodes = jnp.concatenate(
[nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0
) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row
# flag location of homologous nodes
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
# calculate the count of non_homologous of two genomes
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
# calculate the distance of homologous nodes
fr_attrs = jax.vmap(extract_node_attrs)(fr)
sr_attrs = jax.vmap(extract_node_attrs)(sr)
hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(None, 0, 0))(
state, fr_attrs, sr_attrs
) # homologous node distance
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
homologous_distance = jnp.sum(hnd * intersect_mask)
val = (
non_homologous_cnt * self.compatibility_disjoint
+ homologous_distance * self.compatibility_weight
)
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
return val
def conn_distance(self, state, conns1, conns2):
"""
The distance of the conns part for two genomes
"""
con_cnt1 = jnp.sum(~jnp.isnan(conns1[:, 0]))
con_cnt2 = jnp.sum(~jnp.isnan(conns2[:, 0]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
cons = jnp.concatenate((conns1, conns2), axis=0)
keys = cons[:, :2]
sorted_indices = jnp.lexsort(keys.T[::-1])
cons = cons[sorted_indices]
cons = jnp.concatenate(
[cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0
) # add a nan row to the end
fr, sr = cons[:-1], cons[1:] # first row, second row
# both genome has such connection
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
fr_attrs = jax.vmap(extract_conn_attrs)(fr)
sr_attrs = jax.vmap(extract_conn_attrs)(sr)
hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(None, 0, 0))(
state, fr_attrs, sr_attrs
) # homologous connection distance
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
homologous_distance = jnp.sum(hcd * intersect_mask)
val = (
non_homologous_cnt * self.compatibility_disjoint
+ homologous_distance * self.compatibility_weight
)
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
return val
def create_next_generation(self, state, winner, loser, elite_mask):
# find next node key
all_nodes_keys = state.pop_nodes[:, :, 0]
max_node_key = jnp.max(
all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0
)
next_node_key = max_node_key + 1
new_node_keys = jnp.arange(self.pop_size) + next_node_key
# prepare random keys
k1, k2, randkey = jax.random.split(state.randkey, 3)
crossover_randkeys = jax.random.split(k1, self.pop_size)
mutate_randkeys = jax.random.split(k2, self.pop_size)
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
# batch crossover
n_nodes, n_conns = jax.vmap(
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
)(
state, crossover_randkeys, wpn, wpc, lpn, lpc
) # new_nodes, new_conns
# batch mutation
m_n_nodes, m_n_conns = jax.vmap(
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
)(
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
) # mutated_new_nodes, mutated_new_conns
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
return state.update(
randkey=randkey,
pop_nodes=pop_nodes,
pop_conns=pop_conns,
species=species_state,
)

View File

@@ -1,2 +0,0 @@
from .base import BaseSpecies
from .default import DefaultSpecies

View File

@@ -1,20 +0,0 @@
from tensorneat.common import State, StatefulBaseClass
from tensorneat.genome import BaseGenome
class BaseSpecies(StatefulBaseClass):
genome: BaseGenome
pop_size: int
species_size: int
def ask(self, state: State):
raise NotImplementedError
def tell(self, state: State, fitness):
raise NotImplementedError
def update_species(self, state, fitness):
raise NotImplementedError
def speciate(self, state):
raise NotImplementedError