complete fully stateful!
use black to format all files!
This commit is contained in:
@@ -1,10 +1,22 @@
|
||||
import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import State, rank_elements, argmin_with_mask, fetch_first
|
||||
from ..genome import BaseGenome
|
||||
from .base import BaseSpecies
|
||||
|
||||
|
||||
"""
|
||||
Core procedures of NEAT algorithm, contains the following steps:
|
||||
1. Update the fitness of each species;
|
||||
2. Decide which species will be stagnation;
|
||||
3. Decide the number of members of each species in the next generation;
|
||||
4. Choice the crossover pair for each species;
|
||||
5. Divided the whole new population into different species;
|
||||
|
||||
This class use tensor operation to imitate the behavior of NEAT algorithm which implemented in NEAT-python.
|
||||
The code may be hard to understand. Fortunately, we don't need to overwrite it in most cases.
|
||||
"""
|
||||
|
||||
|
||||
class DefaultSpecies(BaseSpecies):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -20,8 +32,6 @@ class DefaultSpecies(BaseSpecies):
|
||||
survival_threshold: float = 0.2,
|
||||
min_species_size: int = 1,
|
||||
compatibility_threshold: float = 3.0,
|
||||
initialize_method: str = "one_hidden_node",
|
||||
# {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'}
|
||||
):
|
||||
self.genome = genome
|
||||
self.pop_size = pop_size
|
||||
@@ -36,15 +46,17 @@ class DefaultSpecies(BaseSpecies):
|
||||
self.survival_threshold = survival_threshold
|
||||
self.min_species_size = min_species_size
|
||||
self.compatibility_threshold = compatibility_threshold
|
||||
self.initialize_method = initialize_method
|
||||
|
||||
self.species_arange = jnp.arange(self.species_size)
|
||||
|
||||
def setup(self, state=State()):
|
||||
state = self.genome.setup(state)
|
||||
k1, randkey = jax.random.split(state.randkey, 2)
|
||||
pop_nodes, pop_conns = initialize_population(
|
||||
self.pop_size, self.genome, k1, self.initialize_method
|
||||
|
||||
# initialize the population
|
||||
initialize_keys = jax.random.split(randkey, self.pop_size)
|
||||
pop_nodes, pop_conns = jax.vmap(self.genome.initialize, in_axes=(None, 0))(
|
||||
state, initialize_keys
|
||||
)
|
||||
|
||||
species_keys = jnp.full(
|
||||
@@ -82,8 +94,9 @@ class DefaultSpecies(BaseSpecies):
|
||||
|
||||
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
|
||||
|
||||
state = state.update(randkey=randkey)
|
||||
|
||||
return state.register(
|
||||
randkey=randkey,
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
species_keys=species_keys,
|
||||
@@ -97,7 +110,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
)
|
||||
|
||||
def ask(self, state):
|
||||
return state, state.pop_nodes, state.pop_conns
|
||||
return state.pop_nodes, state.pop_conns
|
||||
|
||||
def update_species(self, state, fitness):
|
||||
# update the fitness of each species
|
||||
@@ -122,8 +135,8 @@ class DefaultSpecies(BaseSpecies):
|
||||
|
||||
k1, k2 = jax.random.split(state.randkey)
|
||||
# crossover info
|
||||
winner, loser, elite_mask = self.create_crossover_pair(
|
||||
state, k1, spawn_number, fitness
|
||||
state, winner, loser, elite_mask = self.create_crossover_pair(
|
||||
state, spawn_number, fitness
|
||||
)
|
||||
|
||||
return state.update(randkey=k2), winner, loser, elite_mask
|
||||
@@ -322,12 +335,12 @@ class DefaultSpecies(BaseSpecies):
|
||||
winner = jnp.where(is_part1_win, part1, part2)
|
||||
loser = jnp.where(is_part1_win, part2, part1)
|
||||
|
||||
return state(randkey=randkey), winner, loser, elite_mask
|
||||
return state.update(randkey=randkey), winner, loser, elite_mask
|
||||
|
||||
def speciate(self, state):
|
||||
# prepare distance functions
|
||||
o2p_distance_func = jax.vmap(
|
||||
self.distance, in_axes=(None, None, 0, 0)
|
||||
self.distance, in_axes=(None, None, None, 0, 0)
|
||||
) # one to population
|
||||
|
||||
# idx to specie key
|
||||
@@ -351,7 +364,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
i, i2s, cns, ccs, o2c = carry
|
||||
|
||||
distances = o2p_distance_func(
|
||||
cns[i], ccs[i], state.pop_nodes, state.pop_conns
|
||||
state, cns[i], ccs[i], state.pop_nodes, state.pop_conns
|
||||
)
|
||||
|
||||
# find the closest one
|
||||
@@ -434,7 +447,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
def speciate_by_threshold(i, i2s, cns, ccs, sk, o2c):
|
||||
# distance between such center genome and ppo genomes
|
||||
o2p_distance = o2p_distance_func(
|
||||
cns[i], ccs[i], state.pop_nodes, state.pop_conns
|
||||
state, cns[i], ccs[i], state.pop_nodes, state.pop_conns
|
||||
)
|
||||
|
||||
close_enough_mask = o2p_distance < self.compatibility_threshold
|
||||
@@ -508,14 +521,16 @@ class DefaultSpecies(BaseSpecies):
|
||||
next_species_key=next_species_key,
|
||||
)
|
||||
|
||||
def distance(self, nodes1, conns1, nodes2, conns2):
|
||||
def distance(self, state, nodes1, conns1, nodes2, conns2):
|
||||
"""
|
||||
The distance between two genomes
|
||||
"""
|
||||
d = self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2)
|
||||
d = self.node_distance(state, nodes1, nodes2) + self.conn_distance(
|
||||
state, conns1, conns2
|
||||
)
|
||||
return d
|
||||
|
||||
def node_distance(self, nodes1, nodes2):
|
||||
def node_distance(self, state, nodes1, nodes2):
|
||||
"""
|
||||
The distance of the nodes part for two genomes
|
||||
"""
|
||||
@@ -541,7 +556,9 @@ class DefaultSpecies(BaseSpecies):
|
||||
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
|
||||
# calculate the distance of homologous nodes
|
||||
hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(0, 0))(fr, sr)
|
||||
hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(None, 0, 0))(
|
||||
state, fr, sr
|
||||
) # homologous node distance
|
||||
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
|
||||
homologous_distance = jnp.sum(hnd * intersect_mask)
|
||||
|
||||
@@ -550,9 +567,11 @@ class DefaultSpecies(BaseSpecies):
|
||||
+ homologous_distance * self.compatibility_weight
|
||||
)
|
||||
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
|
||||
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
|
||||
|
||||
def conn_distance(self, conns1, conns2):
|
||||
return val
|
||||
|
||||
def conn_distance(self, state, conns1, conns2):
|
||||
"""
|
||||
The distance of the conns part for two genomes
|
||||
"""
|
||||
@@ -573,7 +592,9 @@ class DefaultSpecies(BaseSpecies):
|
||||
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
|
||||
|
||||
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(0, 0))(fr, sr)
|
||||
hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(None, 0, 0))(
|
||||
state, fr, sr
|
||||
) # homologous connection distance
|
||||
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
|
||||
homologous_distance = jnp.sum(hcd * intersect_mask)
|
||||
|
||||
@@ -582,185 +603,6 @@ class DefaultSpecies(BaseSpecies):
|
||||
+ homologous_distance * self.compatibility_weight
|
||||
)
|
||||
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
||||
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
|
||||
|
||||
|
||||
def initialize_population(pop_size, genome, randkey, init_method="default"):
|
||||
rand_keys = jax.random.split(randkey, pop_size)
|
||||
|
||||
if init_method == "one_hidden_node":
|
||||
init_func = init_one_hidden_node
|
||||
elif init_method == "dense_hideen_layer":
|
||||
init_func = init_dense_hideen_layer
|
||||
elif init_method == "no_hidden_random":
|
||||
init_func = init_no_hidden_random
|
||||
else:
|
||||
raise ValueError("Unknown initialization method: {}".format(init_method))
|
||||
|
||||
pop_nodes, pop_conns = jax.vmap(init_func, in_axes=(None, 0))(genome, rand_keys)
|
||||
|
||||
return pop_nodes, pop_conns
|
||||
|
||||
|
||||
# one hidden node
|
||||
def init_one_hidden_node(genome, randkey):
|
||||
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||
new_node_key = max([*input_idx, *output_idx]) + 1
|
||||
|
||||
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan)
|
||||
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan)
|
||||
|
||||
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
nodes = nodes.at[new_node_key, 0].set(new_node_key)
|
||||
|
||||
rand_keys_nodes = jax.random.split(
|
||||
randkey, num=len(input_idx) + len(output_idx) + 1
|
||||
)
|
||||
input_keys, output_keys, hidden_key = (
|
||||
rand_keys_nodes[: len(input_idx)],
|
||||
rand_keys_nodes[len(input_idx) : len(input_idx) + len(output_idx)],
|
||||
rand_keys_nodes[-1],
|
||||
)
|
||||
|
||||
node_attr_func = jax.vmap(genome.node_gene.new_attrs, in_axes=(None, 0))
|
||||
input_attrs = node_attr_func(input_keys)
|
||||
output_attrs = node_attr_func(output_keys)
|
||||
hidden_attrs = genome.node_gene.new_custom_attrs(hidden_key)
|
||||
|
||||
nodes = nodes.at[input_idx, 1:].set(input_attrs)
|
||||
nodes = nodes.at[output_idx, 1:].set(output_attrs)
|
||||
nodes = nodes.at[new_node_key, 1:].set(hidden_attrs)
|
||||
|
||||
input_conns = jnp.c_[input_idx, jnp.full_like(input_idx, new_node_key)]
|
||||
conns = conns.at[input_idx, 0:2].set(input_conns)
|
||||
conns = conns.at[input_idx, 2].set(True)
|
||||
|
||||
output_conns = jnp.c_[jnp.full_like(output_idx, new_node_key), output_idx]
|
||||
conns = conns.at[output_idx, 0:2].set(output_conns)
|
||||
conns = conns.at[output_idx, 2].set(True)
|
||||
|
||||
rand_keys_conns = jax.random.split(randkey, num=len(input_idx) + len(output_idx))
|
||||
input_conn_keys, output_conn_keys = (
|
||||
rand_keys_conns[: len(input_idx)],
|
||||
rand_keys_conns[len(input_idx) :],
|
||||
)
|
||||
|
||||
conn_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(None, 0))
|
||||
input_conn_attrs = conn_attr_func(input_conn_keys)
|
||||
output_conn_attrs = conn_attr_func(output_conn_keys)
|
||||
|
||||
conns = conns.at[input_idx, 3:].set(input_conn_attrs)
|
||||
conns = conns.at[output_idx, 3:].set(output_conn_attrs)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
|
||||
# random dense connections with 1 hidden layer
|
||||
def init_dense_hideen_layer(genome, randkey, hiddens=20):
|
||||
k1, k2, k3 = jax.random.split(randkey, num=3)
|
||||
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||
input_size = len(input_idx)
|
||||
output_size = len(output_idx)
|
||||
|
||||
hidden_idx = jnp.arange(
|
||||
input_size + output_size, input_size + output_size + hiddens
|
||||
)
|
||||
nodes = jnp.full(
|
||||
(genome.max_nodes, genome.node_gene.length), jnp.nan, dtype=jnp.float32
|
||||
)
|
||||
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
nodes = nodes.at[hidden_idx, 0].set(hidden_idx)
|
||||
|
||||
total_idx = input_size + output_size + hiddens
|
||||
rand_keys_n = jax.random.split(k1, num=total_idx)
|
||||
input_keys = rand_keys_n[:input_size]
|
||||
output_keys = rand_keys_n[input_size : input_size + output_size]
|
||||
hidden_keys = rand_keys_n[input_size + output_size :]
|
||||
|
||||
node_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0))
|
||||
input_attrs = node_attr_func(input_keys)
|
||||
output_attrs = node_attr_func(output_keys)
|
||||
hidden_attrs = node_attr_func(hidden_keys)
|
||||
|
||||
nodes = nodes.at[input_idx, 1:].set(input_attrs)
|
||||
nodes = nodes.at[output_idx, 1:].set(output_attrs)
|
||||
nodes = nodes.at[hidden_idx, 1:].set(hidden_attrs)
|
||||
|
||||
total_connections = input_size * hiddens + hiddens * output_size
|
||||
conns = jnp.full(
|
||||
(genome.max_conns, genome.conn_gene.length), jnp.nan, dtype=jnp.float32
|
||||
)
|
||||
|
||||
rand_keys_c = jax.random.split(k2, num=total_connections)
|
||||
conns_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0))
|
||||
conns_attrs = conns_attr_func(rand_keys_c)
|
||||
|
||||
input_to_hidden_ids, hidden_ids = jnp.meshgrid(input_idx, hidden_idx, indexing="ij")
|
||||
hidden_to_output_ids, output_ids = jnp.meshgrid(
|
||||
hidden_idx, output_idx, indexing="ij"
|
||||
)
|
||||
|
||||
conns = conns.at[: input_size * hiddens, 0].set(input_to_hidden_ids.flatten())
|
||||
conns = conns.at[: input_size * hiddens, 1].set(hidden_ids.flatten())
|
||||
conns = conns.at[input_size * hiddens : total_connections, 0].set(
|
||||
hidden_to_output_ids.flatten()
|
||||
)
|
||||
conns = conns.at[input_size * hiddens : total_connections, 1].set(
|
||||
output_ids.flatten()
|
||||
)
|
||||
conns = conns.at[: input_size * hiddens + hiddens * output_size, 2].set(True)
|
||||
conns = conns.at[: input_size * hiddens + hiddens * output_size, 3:].set(
|
||||
conns_attrs
|
||||
)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
|
||||
# random sparse connections with no hidden nodes
|
||||
def init_no_hidden_random(genome, randkey):
|
||||
k1, k2, k3 = jax.random.split(randkey, num=3)
|
||||
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||
|
||||
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan)
|
||||
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
|
||||
total_idx = len(input_idx) + len(output_idx)
|
||||
rand_keys_n = jax.random.split(k1, num=total_idx)
|
||||
input_keys = rand_keys_n[: len(input_idx)]
|
||||
output_keys = rand_keys_n[len(input_idx) :]
|
||||
|
||||
node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0))
|
||||
input_attrs = node_attr_func(input_keys)
|
||||
output_attrs = node_attr_func(output_keys)
|
||||
nodes = nodes.at[input_idx, 1:].set(input_attrs)
|
||||
nodes = nodes.at[output_idx, 1:].set(output_attrs)
|
||||
|
||||
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan)
|
||||
|
||||
num_connections_per_output = 4
|
||||
total_connections = len(output_idx) * num_connections_per_output
|
||||
|
||||
def create_connections_for_output(key):
|
||||
permuted_inputs = jax.random.permutation(key, input_idx)
|
||||
selected_inputs = permuted_inputs[:num_connections_per_output]
|
||||
return selected_inputs
|
||||
|
||||
conn_keys = jax.random.split(k2, num=len(output_idx))
|
||||
connections = jax.vmap(create_connections_for_output)(conn_keys)
|
||||
connections = connections.flatten()
|
||||
|
||||
output_repeats = jnp.repeat(output_idx, num_connections_per_output)
|
||||
|
||||
rand_keys_c = jax.random.split(k3, num=total_connections)
|
||||
conns_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0))
|
||||
conns_attrs = conns_attr_func(rand_keys_c)
|
||||
|
||||
conns = conns.at[:total_connections, 0].set(connections)
|
||||
conns = conns.at[:total_connections, 1].set(output_repeats)
|
||||
conns = conns.at[:total_connections, 2].set(True) # enabled
|
||||
conns = conns.at[:total_connections, 3:].set(conns_attrs)
|
||||
|
||||
return nodes, conns
|
||||
return val
|
||||
|
||||
Reference in New Issue
Block a user