new architecture

This commit is contained in:
wls2002
2024-01-27 00:52:39 +08:00
parent 4efe9a53c1
commit aac41a089d
65 changed files with 1651 additions and 1783 deletions

View File

@@ -1,2 +1,2 @@
from .neat import NEAT
from .hyperneat import HyperNEAT
from .base import BaseAlgorithm
from .neat import NEAT

24
algorithm/base.py Normal file
View File

@@ -0,0 +1,24 @@
from utils import State
class BaseAlgorithm:
    """Common interface for the algorithms in this package.

    Every method below is an unimplemented hook that raises
    ``NotImplementedError``; concrete algorithms override them.
    """

    def setup(self, randkey):
        """Initialize the state of the algorithm."""
        raise NotImplementedError()

    def ask(self, state: State):
        """Require the population to be evaluated."""
        raise NotImplementedError()

    def tell(self, state: State, fitness):
        """Update the state of the algorithm."""
        raise NotImplementedError()

    def transform(self, state: State):
        """Transform the genome into a neural network."""
        raise NotImplementedError()

    def forward(self, inputs, transformed):
        """Hook receiving ``inputs`` and a ``transformed`` value; not implemented here."""
        raise NotImplementedError()

View File

@@ -1,2 +0,0 @@
from .hyperneat import HyperNEAT
from .substrate import NormalSubstrate, NormalSubstrateConfig

View File

@@ -1,113 +0,0 @@
from typing import Type
import jax
from jax import numpy as jnp, Array, vmap
import numpy as np
from config import Config, HyperNeatConfig
from core import Algorithm, Substrate, State, Genome, Gene
from .substrate import analysis_substrate
from algorithm import NEAT
class HyperNEAT(Algorithm):
    """HyperNEAT built on top of an inner :class:`NEAT` instance.

    The inner NEAT networks are queried with pairs of substrate coordinates;
    the query results become the connection weights of a second network
    that is evaluated by :class:`HyperNEATGene`.
    """

    def __init__(self, config: Config, gene: Type[Gene], substrate: Type[Substrate]):
        self.config = config
        self.neat = NEAT(config, gene)
        self.substrate = substrate

    def setup(self, randkey, state=State()):
        """Populate ``state`` with HyperNEAT settings, NEAT state and substrate layout.

        :param randkey: random key; one half of its split is handed to NEAT.
        :param state: state to extend.
        :return: the extended state.
        """
        neat_key, randkey = jax.random.split(randkey)
        hyper_cfg = self.config.hyperneat

        state = state.update(
            below_threshold=hyper_cfg.below_threshold,
            max_weight=hyper_cfg.max_weight,
        )
        state = self.neat.setup(neat_key, state)
        state = self.substrate.setup(self.config.substrate, state)

        # the substrate must provide one extra input coordinate for the bias
        assert hyper_cfg.inputs + 1 == state.input_coors.shape[0]  # +1 for bias
        assert hyper_cfg.outputs == state.output_coors.shape[0]

        in_idx, out_idx, hid_idx, query_coors, conn_keys = analysis_substrate(state)

        # substrate nodes: a single column holding the node key
        h_nodes = np.concatenate((in_idx, out_idx, hid_idx))[..., np.newaxis]

        # substrate connections: (in key, out key, weight); weight is filled in later
        h_conns = np.zeros((conn_keys.shape[0], 3), dtype=np.float32)
        h_conns[:, 0:2] = conn_keys

        return state.update(
            h_input_idx=in_idx,
            h_output_idx=out_idx,
            h_hidden_idx=hid_idx,
            h_nodes=h_nodes,
            h_conns=h_conns,
            query_coors=query_coors,
        )

    def ask_algorithm(self, state: State):
        """Return the population genomes stored in ``state``."""
        return state.pop_genomes

    def tell_algorithm(self, state: State, fitness):
        """Delegate the update to the inner NEAT algorithm."""
        return self.neat.tell(state, fitness)

    def forward(self, state, inputs: Array, transformed: Array):
        """Evaluate the substrate network via :meth:`HyperNEATGene.forward`."""
        return HyperNEATGene.forward(self.config.hyperneat, state, inputs, transformed)

    def forward_transform(self, state: State, genome: Genome):
        """Query the inner network for every substrate connection and build the substrate network."""
        inner_net = self.neat.forward_transform(state, genome)
        query = vmap(self.neat.forward, in_axes=(None, 0, None))
        weights = query(state, state.query_coors, inner_net)

        threshold = state.below_threshold

        # mute the connections whose weight lies strictly inside (-threshold, threshold)
        muted = (-threshold < weights) & (weights < threshold)
        weights = jnp.where(muted, 0., weights)

        # shift the surviving weights towards zero by the threshold, then rescale by
        # max_weight (this lands in [-max_weight, max_weight] if the query output is
        # within [-1, 1] -- NOTE(review): that range is assumed, not checked here)
        weights = jnp.where(weights > 0, weights - threshold, weights)
        weights = jnp.where(weights < 0, weights + threshold, weights)
        weights = weights / (1 - threshold) * state.max_weight

        h_conns = state.h_conns.at[:, 2:].set(weights)
        return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns))
class HyperNEATGene:
    """Evaluation helpers for the substrate network (weights only, no node attributes)."""

    node_attrs = []  # no node attributes
    conn_attrs = ['weight']

    @staticmethod
    def forward_transform(genome: Genome):
        """Turn the connection list into a dense ``(N, N)`` weight matrix.

        Entry ``[in_key, out_key]`` holds the weight of that connection;
        entries without a connection stay 0.

        :return: ``(genome.nodes, dense_weights)``
        """
        node_count = genome.nodes.shape[0]
        conns = genome.conns

        src = jnp.asarray(conns[:, 0], jnp.int32)
        dst = jnp.asarray(conns[:, 1], jnp.int32)

        dense = jnp.zeros((node_count, node_count), dtype=jnp.float32)
        dense = dense.at[src, dst].set(conns[:, 2])
        return genome.nodes, dense

    @staticmethod
    def forward(config: HyperNeatConfig, state: State, inputs, transformed):
        """Run ``config.activate_times`` update steps and return the output node values.

        :param config: provides ``activation``, ``aggregation`` and ``activate_times``.
        :param state: provides ``h_input_idx`` and ``h_output_idx``.
        :param inputs: 1-D input values; a constant 1 is appended as the bias input.
        :param transformed: the ``(nodes, weights)`` pair from :meth:`forward_transform`.
        """
        nodes, weights = transformed
        activate = jax.vmap(config.activation)
        aggregate = jax.vmap(config.aggregation)

        biased_inputs = jnp.concatenate((inputs, jnp.ones((1,))), axis=0)
        in_idx, out_idx = state.h_input_idx, state.h_output_idx

        def step(_, values):
            # input nodes are re-clamped to the external inputs on every step
            values = values.at[in_idx].set(biased_inputs)
            # row i holds the weighted contributions flowing into node i
            incoming = values * weights.T
            aggregated = aggregate(incoming)  # z = agg(ins)
            return activate(aggregated)  # z = act(z)

        initial = jnp.full((nodes.shape[0],), 0.)
        final = jax.lax.fori_loop(0, config.activate_times, step, initial)
        return final[out_idx]

View File

@@ -1,2 +0,0 @@
from .normal import NormalSubstrate, NormalSubstrateConfig
from .tools import analysis_substrate

View File

@@ -1,25 +0,0 @@
from dataclasses import dataclass
from typing import Tuple
import numpy as np
from core import Substrate, State
from config import SubstrateConfig
@dataclass(frozen=True)
class NormalSubstrateConfig(SubstrateConfig):
    """Default 2-D substrate coordinates: three inputs, three hidden nodes, one output."""

    # inputs sit on the line y = -1
    input_coors: Tuple = (
        (-1, -1),
        (0, -1),
        (1, -1),
    )
    # hidden nodes sit on the line y = 0
    hidden_coors: Tuple = (
        (-1, 0),
        (0, 0),
        (1, 0),
    )
    # the single output sits at y = 1
    output_coors: Tuple = (
        (0, 1),
    )
class NormalSubstrate(Substrate):
    """Substrate whose coordinates are copied straight from its config."""

    @staticmethod
    def setup(config: NormalSubstrateConfig, state: State = State()):
        """Store the configured coordinates on ``state`` as float32 numpy arrays.

        :param config: supplies ``input_coors``, ``output_coors`` and ``hidden_coors``.
        :param state: state to extend.
        :return: the result of ``state.update`` with the three coordinate arrays.
        """

        def as_float32(coors):
            return np.asarray(coors, dtype=np.float32)

        return state.update(
            input_coors=as_float32(config.input_coors),
            output_coors=as_float32(config.output_coors),
            hidden_coors=as_float32(config.hidden_coors),
        )

View File

@@ -1,49 +0,0 @@
import numpy as np
def analysis_substrate(state):
    """Enumerate the substrate nodes and every connection to be queried.

    Node indices are assigned in the order inputs, outputs, hidden. Three groups of
    connections are generated, in this order: input -> hidden, hidden -> hidden,
    hidden -> output.

    :param state: provides ``input_coors``, ``output_coors`` and ``hidden_coors``.
    :return: ``(input_idx, output_idx, hidden_idx, query_coors, correspond_keys)``;
        row ``k`` of ``query_coors`` is the source coordinate followed by the target
        coordinate of connection ``k``, and row ``k`` of ``correspond_keys`` is the
        matching ``(source index, target index)`` pair.
    """
    coor_dims = state.input_coors.shape[1]  # coordinate dimensions
    n_in = state.input_coors.shape[0]  # input coordinate size
    n_out = state.output_coors.shape[0]  # output coordinate size
    n_hid = state.hidden_coors.shape[0]  # hidden coordinate size

    input_idx = np.arange(n_in)
    output_idx = np.arange(n_in, n_in + n_out)
    hidden_idx = np.arange(n_in + n_out, n_in + n_out + n_hid)

    total_conns = n_in * n_hid + n_hid * n_hid + n_hid * n_out
    query_coors = np.zeros((total_conns, coor_dims * 2))
    correspond_keys = np.zeros((total_conns, 2))

    # (source keys, target keys, source coordinates, target coordinates, row count)
    sections = (
        (input_idx, hidden_idx, state.input_coors, state.hidden_coors, n_in * n_hid),
        (hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors, n_hid * n_hid),
        (hidden_idx, output_idx, state.hidden_coors, state.output_coors, n_hid * n_out),
    )

    start = 0
    for src_keys, dst_keys, src_coors, dst_coors, rows in sections:
        coors, keys = cartesian_product(src_keys, dst_keys, src_coors, dst_coors)
        stop = start + rows
        query_coors[start:stop, :] = coors
        correspond_keys[start:stop, :] = keys
        start = stop

    return input_idx, output_idx, hidden_idx, query_coors, correspond_keys
def cartesian_product(keys1, keys2, coors1, coors2):
    """Pair every entry of the first group with every entry of the second group.

    The pairs are ordered with the first group varying slowest.

    :param keys1: 1-D keys of the first group.
    :param keys2: 1-D keys of the second group.
    :param coors1: coordinates of the first group, one row per key.
    :param coors2: coordinates of the second group, one row per key.
    :return: ``(new_coors, correspond_keys)`` -- the concatenated coordinates of each
        pair and the matching ``(key1, key2)`` rows.
    """
    n_first, n_second = keys1.shape[0], keys2.shape[0]

    # first group: every entry is repeated n_second times in a row
    first_keys = np.repeat(keys1, n_second)
    first_coors = np.repeat(coors1, n_second, axis=0)

    # second group: the whole group is cycled n_first times
    second_keys = np.tile(keys2, n_first)
    second_coors = np.tile(coors2, (n_first, 1))

    paired_coors = np.concatenate((first_coors, second_coors), axis=1)
    paired_keys = np.column_stack((first_keys, second_keys))
    return paired_coors, paired_keys

View File

@@ -1,2 +1,3 @@
from .neat import NEAT
from .gene import *
from .genome import *
from .neat import NEAT

View File

@@ -1,3 +1,2 @@
from .crossover import crossover
from .mutate import mutate
from .operation import create_next_generation
from .crossover import BaseCrossover, DefaultCrossover
from .mutation import BaseMutation, DefaultMutation

View File

@@ -1,70 +0,0 @@
import jax
from jax import Array, numpy as jnp
from core import Genome
def crossover(randkey, genome1: Genome, genome2: Genome):
    """Create a child genome from two parents.

    ``genome1`` must be the fitter parent (the winner): genes that exist in only one
    parent are always taken from ``genome1``.

    :param randkey: random key for the per-value parent choice.
    :param genome1: winner genome.
    :param genome2: loser genome.
    :return: ``genome1.update(new_nodes, new_conns)``.
    """
    # the third key of the split is not used
    node_key, conn_key, _ = jax.random.split(randkey, 3)

    def merge(key, winner, loser_aligned):
        # non-homologous entries (NaN on either side) keep the winner's value,
        # homologous entries take the crossover result of both parents
        unmatched = jnp.isnan(winner) | jnp.isnan(loser_aligned)
        return jnp.where(unmatched, winner, crossover_gene(key, winner, loser_aligned))

    # nodes: move the loser's homologous genes to the winner's row positions
    winner_nodes = genome1.nodes
    loser_nodes = align_array(winner_nodes[:, 0], genome2.nodes[:, 0], genome2.nodes, False)
    new_nodes = merge(node_key, winner_nodes, loser_nodes)

    # connections: same procedure, keyed by the (in, out) pair
    winner_conns = genome1.conns
    loser_conns = align_array(winner_conns[:, :2], genome2.conns[:, :2], genome2.conns, True)
    new_conns = merge(conn_key, winner_conns, loser_conns)

    return genome1.update(new_nodes, new_conns)
def align_array(seq1: Array, seq2: Array, ar2: Array, is_conn: bool) -> Array:
    """Reorder the rows of ``ar2`` so that they line up with the keys in ``seq1``.

    For every key in ``seq1`` that also occurs in ``seq2``, the corresponding row of
    ``ar2`` is placed at that key's position in ``seq1``; every other row of the result
    is NaN.

    :param seq1: keys to align to (node keys, or ``(in, out)`` rows for connections).
    :param seq2: keys of the rows of ``ar2``.
    :param ar2: array to be aligned.
    :param is_conn: True when the keys are two-column connection keys.
    :return: the aligned copy of ``ar2``.
    """
    keys_a = seq1[:, jnp.newaxis]
    keys_b = seq2[jnp.newaxis, :]

    # match[i, j]: key i of seq1 equals key j of seq2, and key i is not NaN
    match = (keys_a == keys_b) & (~jnp.isnan(keys_a))
    if is_conn:
        # a connection matches only when both of its key columns match
        match = jnp.all(match, axis=2)

    has_partner = match.any(axis=1)

    # Summing the matching column indices gives the partner's position when a row
    # has exactly one match (rows without a match give 0 and are masked below).
    # NOTE(review): the index range is built from the length of seq1, so seq1 and
    # seq2 are presumably the same length -- confirm against the callers.
    positions = jnp.arange(0, len(keys_a))
    partner_pos = jnp.dot(match, positions)

    return jnp.where(has_partner[:, jnp.newaxis], ar2[partner_pos], jnp.nan)
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
    """Mix two gene arrays element by element.

    Each element is taken from ``g1`` when its uniform draw is above 0.5 and from
    ``g2`` otherwise. Only genes with the same key are crossed, so the key columns
    are identical in both inputs and need no special handling.

    :param rand_key: random key for the uniform draws.
    :param g1: first gene array.
    :param g2: second gene array, same shape as ``g1``.
    :return: the mixed array.
    """
    draws = jax.random.uniform(rand_key, shape=g1.shape)
    take_first = draws > 0.5
    return jnp.where(take_first, g1, g2)

View File

@@ -0,0 +1,2 @@
from .base import BaseCrossover
from .default import DefaultCrossover

View File

@@ -0,0 +1,3 @@
class BaseCrossover:
    """Interface of a crossover operator; calling it is not implemented here."""

    def __call__(self, randkey, genome, nodes1, nodes2, conns1, conns2):
        """Hook for combining two parents' nodes and connections; subclasses override it."""
        raise NotImplementedError()

View File

@@ -0,0 +1,66 @@
import jax, jax.numpy as jnp
from .base import BaseCrossover
class DefaultCrossover(BaseCrossover):
    """Crossover that keeps the winner's structure and mixes homologous gene values."""

    def __call__(self, randkey, genome, nodes1, nodes2, conns1, conns2):
        """Create child nodes and connections from two parents.

        The first parent (``nodes1`` / ``conns1``) must be the fitter one (the winner):
        genes that exist in only one parent are always taken from it.

        :return: ``(new_nodes, new_conns)``
        """
        # the third key of the split is not used
        node_key, conn_key, _ = jax.random.split(randkey, 3)

        # nodes: move the loser's homologous genes to the winner's row positions
        aligned_nodes2 = self.align_array(nodes1[:, 0], nodes2[:, 0], nodes2, False)
        new_nodes = self._merge(node_key, nodes1, aligned_nodes2)

        # connections: same procedure, keyed by the (in, out) pair
        aligned_conns2 = self.align_array(conns1[:, :2], conns2[:, :2], conns2, True)
        new_conns = self._merge(conn_key, conns1, aligned_conns2)

        return new_nodes, new_conns

    def _merge(self, key, winner, loser_aligned):
        """Keep the winner's value where either side is NaN, otherwise cross both parents."""
        unmatched = jnp.isnan(winner) | jnp.isnan(loser_aligned)
        return jnp.where(unmatched, winner, self.crossover_gene(key, winner, loser_aligned))

    def align_array(self, seq1, seq2, ar2, is_conn: bool):
        """Reorder the rows of ``ar2`` so that they line up with the keys in ``seq1``.

        For every key in ``seq1`` that also occurs in ``seq2``, the corresponding row
        of ``ar2`` is placed at that key's position in ``seq1``; every other row of the
        result is NaN.

        :param seq1: keys to align to (node keys, or ``(in, out)`` rows for connections).
        :param seq2: keys of the rows of ``ar2``.
        :param ar2: array to be aligned.
        :param is_conn: True when the keys are two-column connection keys.
        :return: the aligned copy of ``ar2``.
        """
        keys_a = seq1[:, jnp.newaxis]
        keys_b = seq2[jnp.newaxis, :]

        # match[i, j]: key i of seq1 equals key j of seq2, and key i is not NaN
        match = (keys_a == keys_b) & (~jnp.isnan(keys_a))
        if is_conn:
            # a connection matches only when both of its key columns match
            match = jnp.all(match, axis=2)

        has_partner = match.any(axis=1)

        # Summing the matching column indices gives the partner's position when a row
        # has exactly one match (rows without a match give 0 and are masked below).
        # NOTE(review): the index range is built from the length of seq1, so seq1 and
        # seq2 are presumably the same length -- confirm against the callers.
        positions = jnp.arange(0, len(keys_a))
        partner_pos = jnp.dot(match, positions)

        return jnp.where(has_partner[:, jnp.newaxis], ar2[partner_pos], jnp.nan)

    def crossover_gene(self, rand_key, g1, g2):
        """Mix two gene arrays element by element.

        Each element is taken from ``g1`` when its uniform draw is above 0.5 and from
        ``g2`` otherwise. Only genes with the same key are crossed, so the key columns
        need no special handling.

        :param rand_key: random key for the uniform draws.
        :param g1: first gene array.
        :param g2: second gene array, same shape as ``g1``.
        :return: the mixed array.
        """
        draws = jax.random.uniform(rand_key, shape=g1.shape)
        take_first = draws > 0.5
        return jnp.where(take_first, g1, g2)

View File

@@ -1,186 +0,0 @@
from typing import Tuple
import jax
from jax import Array, numpy as jnp, vmap
from config import NeatConfig
from core import State, Gene, Genome
from utils import check_cycles, fetch_random, fetch_first, I_INT, unflatten_conns
def mutate(config: NeatConfig, gene: Gene, state: State, randkey, genome: Genome, new_node_key):
    """Mutate one genome: first its structure, then its attribute values.

    :param config: NEAT configuration forwarded to :func:`mutate_structure`.
    :param gene: gene implementation used to create and mutate attributes.
    :param state: algorithm state forwarded to both mutation steps.
    :param randkey: random key; it is split so each step gets its own sub-key.
    :param genome: genome to mutate.
    :param new_node_key: key assigned to a node created by an add-node mutation.
    :return: the mutated genome.
    """
    # Each step gets its own sub-key. Passing the parent key on after splitting it
    # would reuse randomness that the structural step has already consumed.
    structure_key, values_key = jax.random.split(randkey)
    genome = mutate_structure(config, gene, state, structure_key, genome, new_node_key)
    genome = mutate_values(gene, state, values_key, genome)
    return genome
def mutate_structure(config: NeatConfig, gene: Gene, state: State, randkey, genome: Genome, new_node_key):
    """Apply the four structural mutations to one genome, each with its own probability.

    The mutations run in the order add node, delete node, add connection, delete
    connection; each one fires when its uniform draw is below the matching rate in
    ``config`` (``node_add``, ``node_delete``, ``conn_add``, ``conn_delete``).

    :param config: NEAT configuration (mutation rates and ``network_type``).
    :param gene: gene implementation providing attributes for new nodes/connections.
    :param state: algorithm state (provides ``input_idx`` and ``output_idx``).
    :param randkey: random key for this call.
    :param genome: genome to mutate.
    :param new_node_key: key assigned to a node created by the add-node mutation.
    :return: the mutated genome.
    :raises ValueError: if ``config.network_type`` is neither ``'feedforward'`` nor
        ``'recurrent'``.
    """

    def mutate_add_node(key_, genome_: Genome):
        i_key, o_key, idx = choice_connection_key(key_, genome_.conns)

        def nothing():
            return genome_

        def successful_add_node():
            # disable the connection that is being split
            new_genome = genome_.update_conns(genome_.conns.at[idx, 2].set(False))

            # add a new node
            new_genome = new_genome.add_node(new_node_key, gene.new_node_attrs(state))

            # add two new connections: in -> new node and new node -> out
            new_genome = new_genome.add_conn(i_key, new_node_key, True, gene.new_conn_attrs(state))
            new_genome = new_genome.add_conn(new_node_key, o_key, True, gene.new_conn_attrs(state))
            return new_genome

        # idx == I_INT means no connection could be chosen: do nothing
        return jax.lax.cond(idx == I_INT, nothing, successful_add_node)

    def mutate_delete_node(key_, genome_: Genome):
        # TODO: Do we really need to delete a node?
        # randomly choose a node that is neither an input nor an output
        key, idx = choice_node_key(key_, genome_.nodes, state.input_idx, state.output_idx,
                                   allow_input_keys=False, allow_output_keys=False)

        def nothing():
            return genome_

        def successful_delete_node():
            # delete the node
            new_genome = genome_.delete_node_by_pos(idx)

            # delete every connection that starts or ends at the node
            touches_node = (new_genome.conns[:, 0] == key) | (new_genome.conns[:, 1] == key)
            new_conns = jnp.where(touches_node[:, None], jnp.nan, new_genome.conns)
            return new_genome.update_conns(new_conns)

        return jax.lax.cond(idx == I_INT, nothing, successful_delete_node)

    def mutate_add_conn(key_, genome_: Genome):
        # randomly choose two nodes
        k1_, k2_ = jax.random.split(key_, num=2)
        # the source may be any node
        i_key, from_idx = choice_node_key(k1_, genome_.nodes, state.input_idx, state.output_idx,
                                          allow_input_keys=True, allow_output_keys=True)
        # the target may be any node except an input node
        o_key, to_idx = choice_node_key(k2_, genome_.nodes, state.input_idx, state.output_idx,
                                        allow_input_keys=False, allow_output_keys=True)

        conn_pos = fetch_first((genome_.conns[:, 0] == i_key) & (genome_.conns[:, 1] == o_key))

        def nothing():
            return genome_

        def successful():
            return genome_.add_conn(i_key, o_key, True, gene.new_conn_attrs(state))

        def already_exist():
            # the connection is already present: just (re-)enable it
            return genome_.update_conns(genome_.conns.at[conn_pos, 2].set(True))

        is_already_exist = conn_pos != I_INT

        if config.network_type == 'feedforward':
            u_cons = unflatten_conns(genome_.nodes, genome_.conns)
            cons_exist = ~jnp.isnan(u_cons[0, :, :])
            is_cycle = check_cycles(genome_.nodes, cons_exist, from_idx, to_idx)

            # 0: re-enable existing, 1: would create a cycle -> skip, 2: add it
            choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
            return jax.lax.switch(choice, [already_exist, nothing, successful])

        elif config.network_type == 'recurrent':
            return jax.lax.cond(is_already_exist, already_exist, successful)

        else:
            raise ValueError(f"Invalid network type: {config.network_type}")

    def mutate_delete_conn(key_, genome_: Genome):
        # randomly choose a connection
        i_key, o_key, idx = choice_connection_key(key_, genome_.conns)

        def nothing():
            return genome_

        def successfully_delete_connection():
            return genome_.delete_conn_by_pos(idx)

        return jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection)

    # One dedicated key for the probability draws plus one per mutation, so no key is
    # consumed twice (JAX keys must not be reused for independent draws).
    k_draw, k1, k2, k3, k4 = jax.random.split(randkey, num=5)
    r1, r2, r3, r4 = jax.random.uniform(k_draw, shape=(4,))

    def no(k, g):
        return g

    genome = jax.lax.cond(r1 < config.node_add, mutate_add_node, no, k1, genome)
    genome = jax.lax.cond(r2 < config.node_delete, mutate_delete_node, no, k2, genome)
    genome = jax.lax.cond(r3 < config.conn_add, mutate_add_conn, no, k3, genome)
    genome = jax.lax.cond(r4 < config.conn_delete, mutate_delete_conn, no, k4, genome)

    return genome
def mutate_values(gene: Gene, state: State, randkey, genome: Genome):
    """Mutate the attribute columns of every node and connection of one genome.

    Node attributes are all columns after the key (column 0); connection attributes
    are all columns after the first three. Entries that are NaN stay NaN.

    :param gene: provides ``mutate_node(state, key, attrs)`` and
        ``mutate_conn(state, key, attrs)``, applied row by row.
    :param state: passed unchanged to the gene's mutate functions.
    :param randkey: random key; every row gets its own sub-key.
    :param genome: genome to mutate.
    :return: ``genome.update(new_nodes, new_conns)``.
    """
    node_root, conn_root = jax.random.split(randkey, num=2)
    per_node_keys = jax.random.split(node_root, num=genome.nodes.shape[0])
    per_conn_keys = jax.random.split(conn_root, num=genome.conns.shape[0])

    old_node_attrs = genome.nodes[:, 1:]
    old_conn_attrs = genome.conns[:, 3:]

    mutate_nodes = vmap(gene.mutate_node, in_axes=(None, 0, 0))
    mutate_conns = vmap(gene.mutate_conn, in_axes=(None, 0, 0))
    node_attrs = mutate_nodes(state, per_node_keys, old_node_attrs)
    conn_attrs = mutate_conns(state, per_conn_keys, old_conn_attrs)

    # empty (NaN) slots must remain empty after mutation
    node_attrs = jnp.where(jnp.isnan(old_node_attrs), jnp.nan, node_attrs)
    conn_attrs = jnp.where(jnp.isnan(old_conn_attrs), jnp.nan, conn_attrs)

    return genome.update(
        genome.nodes.at[:, 1:].set(node_attrs),
        genome.conns.at[:, 3:].set(conn_attrs),
    )
def choice_node_key(rand_key: Array, nodes: Array,
                    input_keys: Array, output_keys: Array,
                    allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
    """Randomly choose a node key from the given nodes.

    Empty (NaN) rows are never chosen. Input and output nodes are excluded unless the
    matching ``allow_*`` flag is set.

    :param rand_key: random key for the choice.
    :param nodes: node array; column 0 holds the node key.
    :param input_keys: keys of the input nodes.
    :param output_keys: keys of the output nodes.
    :param allow_input_keys: whether an input node may be chosen.
    :param allow_output_keys: whether an output node may be chosen.
    :return: the chosen key and its position (idx); the key is NaN when the position
        equals ``I_INT``.
    """
    node_keys = nodes[:, 0]
    selectable = ~jnp.isnan(node_keys)

    exclusions = (
        (allow_input_keys, input_keys),
        (allow_output_keys, output_keys),
    )
    for allowed, excluded_keys in exclusions:
        if not allowed:
            selectable = jnp.logical_and(selectable, ~jnp.isin(node_keys, excluded_keys))

    idx = fetch_random(rand_key, selectable)
    key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
    return key, idx
def choice_connection_key(rand_key: Array, conns: Array):
    """Randomly choose a connection from the given connections.

    Rows whose first column is NaN are never chosen.

    :param rand_key: random key for the choice.
    :param conns: connection array; columns 0 and 1 hold the in and out keys.
    :return: i_key, o_key, idx -- both keys are NaN when ``idx`` equals ``I_INT``.
    """
    occupied = ~jnp.isnan(conns[:, 0])
    idx = fetch_random(rand_key, occupied)

    found = idx != I_INT
    i_key = jnp.where(found, conns[idx, 0], jnp.nan)
    o_key = jnp.where(found, conns[idx, 1], jnp.nan)
    return i_key, o_key, idx

View File

@@ -0,0 +1,2 @@
from .base import BaseMutation
from .default import DefaultMutation

View File

@@ -0,0 +1,3 @@
class BaseMutation:
    """Interface of a mutation operator; calling it is not implemented here."""

    def __call__(self, key, genome, nodes, conns, new_node_key):
        """Hook for mutating one genome's nodes and connections; subclasses override it."""
        raise NotImplementedError()

View File

@@ -0,0 +1,201 @@
import jax, jax.numpy as jnp
from . import BaseMutation
from utils import fetch_first, fetch_random, I_INT, unflatten_conns, check_cycles
class DefaultMutation(BaseMutation):
    """Default mutation operator: structural mutations followed by value mutations.

    The ``genome`` argument of the methods is the genome definition object. It is
    used here for its helpers (``add_node``, ``add_conn``, ``delete_node_by_pos``,
    ``delete_conn_by_pos``), its gene definitions (``node_gene``, ``conn_gene``) and
    its settings (``input_idx``, ``output_idx``, ``network_type``). The actual genome
    data is passed separately as the ``nodes`` and ``conns`` arrays.
    """

    def __init__(
            self,
            conn_add: float = 0.4,
            conn_delete: float = 0,
            node_add: float = 0.2,
            node_delete: float = 0,
    ):
        """
        :param conn_add: probability of the add-connection mutation.
        :param conn_delete: probability of the delete-connection mutation.
        :param node_add: probability of the add-node mutation.
        :param node_delete: probability of the delete-node mutation.
        """
        self.conn_add = conn_add
        self.conn_delete = conn_delete
        self.node_add = node_add
        self.node_delete = node_delete

    def __call__(self, randkey, genome, nodes, conns, new_node_key):
        """Mutate one genome: first its structure, then its values.

        :param randkey: random key; split so each step gets its own sub-key.
        :param genome: genome definition object (see the class docstring).
        :param nodes: node array of the genome to mutate.
        :param conns: connection array of the genome to mutate.
        :param new_node_key: key assigned to a node created by the add-node mutation.
        :return: ``(nodes, conns)`` after mutation.
        """
        k1, k2 = jax.random.split(randkey)

        nodes, conns = self.mutate_structure(k1, genome, nodes, conns, new_node_key)
        nodes, conns = self.mutate_values(k2, genome, nodes, conns)

        return nodes, conns

    def mutate_structure(self, randkey, genome, nodes, conns, new_node_key):
        """Apply the four structural mutations in sequence, each with its own probability.

        Order: add node, delete node, add connection, delete connection. Every step
        works on the result of the previous one.

        :return: ``(nodes, conns)`` after the structural mutations.
        :raises ValueError: if ``genome.network_type`` is neither ``'feedforward'``
            nor ``'recurrent'``.
        """

        def mutate_add_node(key_, nodes_, conns_):
            i_key, o_key, idx = self.choice_connection_key(key_, conns_)

            def successful_add_node():
                # disable the connection that is being split
                new_conns = conns_.at[idx, 2].set(False)

                # add a new node
                new_nodes = genome.add_node(nodes_, new_node_key, genome.node_gene.new_custom_attrs())

                # add two new connections: in -> new node and new node -> out
                new_conns = genome.add_conn(new_conns, i_key, new_node_key, True,
                                            genome.conn_gene.new_custom_attrs())
                new_conns = genome.add_conn(new_conns, new_node_key, o_key, True,
                                            genome.conn_gene.new_custom_attrs())

                return new_nodes, new_conns

            # idx == I_INT means no connection could be chosen
            return jax.lax.cond(
                idx == I_INT,
                lambda: (nodes_, conns_),  # do nothing
                successful_add_node
            )

        def mutate_delete_node(key_, nodes_, conns_):
            # randomly choose a node that is neither an input nor an output
            key, idx = self.choice_node_key(key_, nodes_, genome.input_idx, genome.output_idx,
                                            allow_input_keys=False, allow_output_keys=False)

            def successful_delete_node():
                # delete the node
                new_nodes = genome.delete_node_by_pos(nodes_, idx)

                # delete every connection that starts or ends at the node
                touches_node = (conns_[:, 0] == key) | (conns_[:, 1] == key)
                new_conns = jnp.where(touches_node[:, None], jnp.nan, conns_)

                return new_nodes, new_conns

            return jax.lax.cond(
                idx == I_INT,
                lambda: (nodes_, conns_),  # do nothing
                successful_delete_node
            )

        def mutate_add_conn(key_, nodes_, conns_):
            # randomly choose two nodes
            k1_, k2_ = jax.random.split(key_, num=2)

            # input node of the connection can be any node
            i_key, from_idx = self.choice_node_key(k1_, nodes_, genome.input_idx, genome.output_idx,
                                                   allow_input_keys=True, allow_output_keys=True)

            # output node of the connection can be any node except input node
            o_key, to_idx = self.choice_node_key(k2_, nodes_, genome.input_idx, genome.output_idx,
                                                 allow_input_keys=False, allow_output_keys=True)

            conn_pos = fetch_first((conns_[:, 0] == i_key) & (conns_[:, 1] == o_key))
            is_already_exist = conn_pos != I_INT

            def nothing():
                return nodes_, conns_

            def successful():
                # new connections take their attributes from the connection gene,
                # exactly like the connections created by the add-node mutation
                return nodes_, genome.add_conn(conns_, i_key, o_key, True,
                                               genome.conn_gene.new_custom_attrs())

            def already_exist():
                # the connection is already present: just (re-)enable it
                return nodes_, conns_.at[conn_pos, 2].set(True)

            if genome.network_type == 'feedforward':
                u_cons = unflatten_conns(nodes_, conns_)
                cons_exist = ~jnp.isnan(u_cons[0, :, :])
                is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx)

                # lax.cond needs callables for both branches, so the inner decision
                # is wrapped in a function instead of being evaluated eagerly
                return jax.lax.cond(
                    is_already_exist,
                    already_exist,
                    lambda: jax.lax.cond(
                        is_cycle,
                        nothing,
                        successful
                    )
                )

            elif genome.network_type == 'recurrent':
                return jax.lax.cond(
                    is_already_exist,
                    already_exist,
                    successful
                )

            else:
                raise ValueError(f"Invalid network type: {genome.network_type}")

        def mutate_delete_conn(key_, nodes_, conns_):
            # randomly choose a connection
            i_key, o_key, idx = self.choice_connection_key(key_, conns_)

            def successfully_delete_connection():
                return nodes_, genome.delete_conn_by_pos(conns_, idx)

            return jax.lax.cond(
                idx == I_INT,
                lambda: (nodes_, conns_),  # nothing
                successfully_delete_connection
            )

        # One dedicated key for the probability draws plus one per mutation, so no
        # key is consumed twice (JAX keys must not be reused for independent draws).
        k_draw, k1, k2, k3, k4 = jax.random.split(randkey, num=5)
        r1, r2, r3, r4 = jax.random.uniform(k_draw, shape=(4,))

        def no(key_, nodes_, conns_):
            # same signature and return structure as the mutation branches
            return nodes_, conns_

        # thread (nodes, conns) through the four steps so every mutation sees the
        # result of the previous one
        nodes, conns = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns)
        nodes, conns = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns)
        nodes, conns = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns)
        nodes, conns = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns)

        return nodes, conns

    def mutate_values(self, randkey, genome, nodes, conns):
        """Mutate every node and connection row with the genome's gene definitions.

        Entries that are NaN in the input stay NaN in the output.

        :return: ``(new_nodes, new_conns)``
        """
        k1, k2 = jax.random.split(randkey, num=2)
        # one sub-key per row of the arrays that are actually being mutated
        nodes_keys = jax.random.split(k1, num=nodes.shape[0])
        conns_keys = jax.random.split(k2, num=conns.shape[0])

        new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes)
        new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns)

        # nan nodes not changed
        new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
        new_conns = jnp.where(jnp.isnan(conns), jnp.nan, new_conns)

        return new_nodes, new_conns

    def choice_node_key(self, rand_key, nodes, input_idx, output_idx,
                        allow_input_keys: bool = False, allow_output_keys: bool = False):
        """Randomly choose a node key from the given nodes.

        Empty (NaN) rows are never chosen. Input and output nodes are excluded unless
        the matching ``allow_*`` flag is set.

        :param rand_key: random key for the choice.
        :param nodes: node array; column 0 holds the node key.
        :param input_idx: keys of the input nodes.
        :param output_idx: keys of the output nodes.
        :param allow_input_keys: whether an input node may be chosen.
        :param allow_output_keys: whether an output node may be chosen.
        :return: the chosen key and its position (idx); the key is NaN when the
            position equals ``I_INT``.
        """
        node_keys = nodes[:, 0]
        mask = ~jnp.isnan(node_keys)

        if not allow_input_keys:
            mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_idx))

        if not allow_output_keys:
            mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_idx))

        idx = fetch_random(rand_key, mask)
        key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
        return key, idx

    def choice_connection_key(self, rand_key, conns):
        """Randomly choose a connection from the given connections.

        Rows whose first column is NaN are never chosen.

        :return: i_key, o_key, idx -- both keys are NaN when ``idx`` equals ``I_INT``.
        """
        idx = fetch_random(rand_key, ~jnp.isnan(conns[:, 0]))
        i_key = jnp.where(idx != I_INT, conns[idx, 0], jnp.nan)
        o_key = jnp.where(idx != I_INT, conns[idx, 1], jnp.nan)

        return i_key, o_key, idx

View File

@@ -1,40 +0,0 @@
import jax
from jax import numpy as jnp, vmap
from config import NeatConfig
from core import Genome, State, Gene
from .mutate import mutate
from .crossover import crossover
def create_next_generation(config: NeatConfig, gene: Gene, state: State, randkey, winner, loser, elite_mask):
    """Build the next population by crossover and mutation.

    :param config: NEAT configuration forwarded to ``mutate``.
    :param gene: gene implementation forwarded to ``mutate``.
    :param state: current state; provides ``pop_genomes``, ``idx2species`` and
        ``next_node_key``.
    :param randkey: random key for this generation.
    :param winner: per-child indices of the fitter parent.
    :param loser: per-child indices of the other parent.
    :param elite_mask: per-child flags; flagged children skip mutation.
    :return: state with updated ``pop_genomes`` and ``next_node_key``.
    """
    pop_size = state.idx2species.shape[0]

    # one random key per child for each of the two steps
    crossover_root, mutate_root = jax.random.split(randkey, 2)
    crossover_keys = jax.random.split(crossover_root, pop_size)
    mutate_keys = jax.random.split(mutate_root, pop_size)

    # every child gets its own candidate key for a node it might add
    fresh_node_keys = jnp.arange(pop_size) + state.next_node_key

    # batch crossover
    population = state.pop_genomes
    winners = Genome(population.nodes[winner], population.conns[winner])
    losers = Genome(population.nodes[loser], population.conns[loser])
    children = vmap(crossover)(crossover_keys, winners, losers)

    # batch mutation
    batch_mutate = vmap(mutate, in_axes=(None, None, None, 0, 0, 0))
    mutated = batch_mutate(config, gene, state, mutate_keys, children, fresh_node_keys)

    # elites keep the un-mutated crossover result
    is_elite = elite_mask[:, None, None]
    pop_nodes = jnp.where(is_elite, children.nodes, mutated.nodes)
    pop_conns = jnp.where(is_elite, children.conns, mutated.conns)

    # the next free node key is one above the largest key in the population
    node_keys = pop_nodes[:, :, 0]
    highest_key = jnp.max(jnp.where(jnp.isnan(node_keys), -jnp.inf, node_keys))

    return state.update(
        pop_genomes=Genome(pop_nodes, pop_conns),
        next_node_key=highest_key + 1,
    )

View File

@@ -1,3 +1,3 @@
from .normal import NormalGene, NormalGeneConfig
from .recurrent import RecurrentGene, RecurrentGeneConfig
from .base import BaseGene
from .conn import *
from .node import *

View File

@@ -0,0 +1,23 @@
class BaseGene:
    """Base class for node genes or connection genes."""

    # names of the attributes every gene of this kind carries
    fixed_attrs = []
    # names of the attributes added by a concrete gene implementation
    custom_attrs = []

    def __init__(self):
        pass

    def new_custom_attrs(self):
        """Hook returning the custom attributes of a new gene; not implemented here."""
        raise NotImplementedError()

    def mutate(self, randkey, gene):
        """Hook mutating one gene; not implemented here."""
        raise NotImplementedError()

    def distance(self, gene1, gene2):
        """Hook measuring the distance between two genes; not implemented here."""
        raise NotImplementedError()

    def forward(self, attrs, inputs):
        """Hook applying the gene to ``inputs``; not implemented here."""
        raise NotImplementedError()

    @property
    def length(self):
        """Total number of attributes: fixed plus custom."""
        fixed_count = len(self.fixed_attrs)
        custom_count = len(self.custom_attrs)
        return fixed_count + custom_count

View File

@@ -0,0 +1,2 @@
from .base import BaseConnGene
from .default import DefaultConnGene

View File

@@ -0,0 +1,12 @@
from .. import BaseGene
class BaseConnGene(BaseGene):
    """Base class for connection genes."""

    # every connection gene starts with these three columns
    fixed_attrs = [
        'input_index',
        'output_index',
        'enabled',
    ]

    def __init__(self):
        super().__init__()

    def forward(self, attrs, inputs):
        """Hook applying the connection to ``inputs``; not implemented here."""
        raise NotImplementedError()

View File

@@ -0,0 +1,51 @@
import jax.numpy as jnp
from utils import mutate_float
from . import BaseConnGene
class DefaultConnGene(BaseConnGene):
    """Default connection gene, with the same behavior as in NEAT-python.

    A connection row is laid out as ``[input_index, output_index, enabled, weight]``.
    """

    fixed_attrs = ['input_index', 'output_index', 'enabled']
    # Must be named ``custom_attrs``: ``BaseGene.length`` counts ``fixed_attrs`` plus
    # ``custom_attrs``, so any other name leaves the weight column uncounted.
    custom_attrs = ['weight']
    # kept for callers that still read the previous attribute name
    attrs = custom_attrs

    def __init__(
            self,
            weight_init_mean: float = 0.0,
            weight_init_std: float = 1.0,
            weight_mutate_power: float = 0.5,
            weight_mutate_rate: float = 0.8,
            weight_replace_rate: float = 0.1,
    ):
        """Store the weight initialisation and mutation settings.

        The five values are passed, in this order, to ``mutate_float`` when a
        connection is mutated; ``weight_init_mean`` is also the weight of a new
        connection.
        """
        super().__init__()
        self.weight_init_mean = weight_init_mean
        self.weight_init_std = weight_init_std
        self.weight_mutate_power = weight_mutate_power
        self.weight_mutate_rate = weight_mutate_rate
        self.weight_replace_rate = weight_replace_rate

    def new_custom_attrs(self):
        """Return the custom attributes of a new connection: ``[weight_init_mean]``."""
        return jnp.array([self.weight_init_mean])

    def mutate(self, key, conn):
        """Mutate the weight of one connection row; the fixed columns are kept.

        :param key: random key for the mutation.
        :param conn: full connection row ``[input_index, output_index, enabled, weight]``.
        :return: the row with the mutated weight.
        """
        input_index = conn[0]
        output_index = conn[1]
        enabled = conn[2]
        weight = mutate_float(key,
                              conn[3],
                              self.weight_init_mean,
                              self.weight_init_std,
                              self.weight_mutate_power,
                              self.weight_mutate_rate,
                              self.weight_replace_rate
                              )

        return jnp.array([input_index, output_index, enabled, weight])

    def distance(self, attrs1, attrs2):
        """Distance of two full connection rows: enabled mismatch plus absolute weight difference."""
        return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3])  # enable + weight

    def forward(self, attrs, inputs):
        """Scale ``inputs`` by the weight; ``attrs`` holds the custom attributes only."""
        weight = attrs[0]
        return inputs * weight

View File

@@ -0,0 +1,2 @@
from .base import BaseNodeGene
from .default import DefaultNodeGene

View File

@@ -0,0 +1,12 @@
from .. import BaseGene
class BaseNodeGene(BaseGene):
    """Base class for node genes."""

    # every node gene starts with its index column
    fixed_attrs = [
        "index",
    ]

    def __init__(self):
        super().__init__()

    def forward(self, attrs, inputs):
        """Hook computing the node output from ``inputs``; not implemented here."""
        raise NotImplementedError()

View File

@@ -0,0 +1,96 @@
from typing import Tuple
import jax, jax.numpy as jnp
from utils import Act, Agg, act, agg, mutate_int, mutate_float
from . import BaseNodeGene
class DefaultNodeGene(BaseNodeGene):
    """Default node gene, with the same behavior as in NEAT-python.

    A node row is laid out as ``[index, bias, response, activation, aggregation]``.
    Activation and aggregation are stored as indices into ``activation_options``
    and ``aggregation_options``.
    """

    fixed_attrs = ['index']
    # listed in the column order used by new_custom_attrs, mutate and forward
    custom_attrs = ['bias', 'response', 'activation', 'aggregation']

    def __init__(
            self,
            bias_init_mean: float = 0.0,
            bias_init_std: float = 1.0,
            bias_mutate_power: float = 0.5,
            bias_mutate_rate: float = 0.7,
            bias_replace_rate: float = 0.1,

            response_init_mean: float = 1.0,
            response_init_std: float = 0.0,
            response_mutate_power: float = 0.5,
            response_mutate_rate: float = 0.7,
            response_replace_rate: float = 0.1,

            activation_default: callable = Act.sigmoid,
            activation_options: Tuple = (Act.sigmoid,),
            activation_replace_rate: float = 0.1,

            aggregation_default: callable = Agg.sum,
            aggregation_options: Tuple = (Agg.sum,),
            aggregation_replace_rate: float = 0.1,
    ):
        """Store the initialisation and mutation settings of every node attribute.

        ``activation_default`` and ``aggregation_default`` must be members of their
        option tuples; they are stored as their position in the tuple.

        :raises ValueError: from ``tuple.index`` if a default is not among its options.
        """
        super().__init__()
        self.bias_init_mean = bias_init_mean
        self.bias_init_std = bias_init_std
        self.bias_mutate_power = bias_mutate_power
        self.bias_mutate_rate = bias_mutate_rate
        self.bias_replace_rate = bias_replace_rate

        self.response_init_mean = response_init_mean
        self.response_init_std = response_init_std
        self.response_mutate_power = response_mutate_power
        self.response_mutate_rate = response_mutate_rate
        self.response_replace_rate = response_replace_rate

        self.activation_default = activation_options.index(activation_default)
        self.activation_options = activation_options
        self.activation_indices = jnp.arange(len(activation_options))
        self.activation_replace_rate = activation_replace_rate

        self.aggregation_default = aggregation_options.index(aggregation_default)
        self.aggregation_options = aggregation_options
        self.aggregation_indices = jnp.arange(len(aggregation_options))
        self.aggregation_replace_rate = aggregation_replace_rate

    def new_custom_attrs(self):
        """Return the custom attributes of a new node: the initial means and default indices."""
        return jnp.array(
            [self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
        )

    def mutate(self, key, node):
        """Mutate every custom attribute of one node row; the index is kept.

        :param key: random key; split so each attribute gets its own sub-key.
        :param node: full node row ``[index, bias, response, activation, aggregation]``.
        :return: the mutated row.
        """
        k1, k2, k3, k4 = jax.random.split(key, num=4)
        index = node[0]

        bias = mutate_float(k1, node[1], self.bias_init_mean, self.bias_init_std,
                            self.bias_mutate_power, self.bias_mutate_rate, self.bias_replace_rate)

        res = mutate_float(k2, node[2], self.response_init_mean, self.response_init_std,
                           self.response_mutate_power, self.response_mutate_rate, self.response_replace_rate)

        # local names chosen so they do not shadow the imported act / agg helpers
        act_idx = mutate_int(k3, node[3], self.activation_indices, self.activation_replace_rate)

        agg_idx = mutate_int(k4, node[4], self.aggregation_indices, self.aggregation_replace_rate)

        return jnp.array([index, bias, res, act_idx, agg_idx])

    def distance(self, node1, node2):
        """Distance of two full node rows.

        Sum of the absolute bias and response differences, plus 1 for a differing
        activation and 1 for a differing aggregation.
        """
        # The comparisons need their own parentheses: ``+`` binds tighter than ``!=``,
        # so without them the expression becomes one chained comparison.
        return (
                jnp.abs(node1[1] - node2[1]) +
                jnp.abs(node1[2] - node2[2]) +
                (node1[3] != node2[3]) +
                (node1[4] != node2[4])
        )

    def forward(self, attrs, inputs):
        """Compute the node output ``act(bias + response * agg(inputs))``.

        :param attrs: custom attributes ``[bias, response, activation, aggregation]``.
        :param inputs: values flowing into the node.
        """
        bias, res, act_idx, agg_idx = attrs

        z = agg(agg_idx, inputs, self.aggregation_options)
        z = bias + res * z
        z = act(act_idx, z, self.activation_options)

        return z

View File

@@ -1,210 +0,0 @@
from dataclasses import dataclass
from typing import Tuple
import jax
from jax import Array, numpy as jnp
from config import GeneConfig
from core import Gene, Genome, State
from utils import Act, Agg, unflatten_conns, topological_sort, I_INT, act, agg
@dataclass(frozen=True)
class NormalGeneConfig(GeneConfig):
    """Settings for the bias, response, activation, aggregation and weight attributes.

    The values are checked on construction: the listed standard deviations, mutation
    powers and rates must not be negative, and each default must be the first entry
    of its options tuple.
    """

    # bias
    bias_init_mean: float = 0.0
    bias_init_std: float = 1.0
    bias_mutate_power: float = 0.5
    bias_mutate_rate: float = 0.7
    bias_replace_rate: float = 0.1

    # response
    response_init_mean: float = 1.0
    response_init_std: float = 0.0
    response_mutate_power: float = 0.5
    response_mutate_rate: float = 0.7
    response_replace_rate: float = 0.1

    # activation
    activation_default: callable = Act.sigmoid
    activation_options: Tuple = (Act.sigmoid,)
    activation_replace_rate: float = 0.1

    # aggregation
    aggregation_default: callable = Agg.sum
    aggregation_options: Tuple = (Agg.sum,)
    aggregation_replace_rate: float = 0.1

    # connection weight
    weight_init_mean: float = 0.0
    weight_init_std: float = 1.0
    weight_mutate_power: float = 0.5
    weight_mutate_rate: float = 0.8
    weight_replace_rate: float = 0.1

    def __post_init__(self):
        # checked in this order; the first violated condition raises AssertionError
        non_negative_fields = (
            'bias_init_std',
            'bias_mutate_power',
            'bias_mutate_rate',
            'bias_replace_rate',
            'response_init_std',
            'response_mutate_power',
            'response_mutate_rate',
            'response_replace_rate',
        )
        for field_name in non_negative_fields:
            assert getattr(self, field_name) >= 0.0

        # the default must be the first option
        assert self.activation_default == self.activation_options[0]
        assert self.aggregation_default == self.aggregation_options[0]
class NormalGene(Gene):
    """Feed-forward gene with bias, response, activation, aggregation and weight.

    Node rows are laid out as ``[key, bias, response, activation, aggregation]``
    and connection rows as ``[in_key, out_key, enabled, weight]``.
    """

    # Listed in the same order as the attribute columns used by
    # new_node_attrs / mutate_node / forward (activation before aggregation).
    node_attrs = ['bias', 'response', 'activation', 'aggregation']
    conn_attrs = ['weight']

    def __init__(self, config: NormalGeneConfig = NormalGeneConfig()):
        self.config = config

    def setup(self, state: State = State()):
        """Copy the gene hyper-parameters into the algorithm state."""
        return state.update(
            bias_init_mean=self.config.bias_init_mean,
            bias_init_std=self.config.bias_init_std,
            bias_mutate_power=self.config.bias_mutate_power,
            bias_mutate_rate=self.config.bias_mutate_rate,
            bias_replace_rate=self.config.bias_replace_rate,

            response_init_mean=self.config.response_init_mean,
            response_init_std=self.config.response_init_std,
            response_mutate_power=self.config.response_mutate_power,
            response_mutate_rate=self.config.response_mutate_rate,
            response_replace_rate=self.config.response_replace_rate,

            # activation / aggregation are stored as indices into the option tuples
            activation_replace_rate=self.config.activation_replace_rate,
            activation_default=0,
            activation_options=jnp.arange(len(self.config.activation_options)),

            aggregation_replace_rate=self.config.aggregation_replace_rate,
            aggregation_default=0,
            aggregation_options=jnp.arange(len(self.config.aggregation_options)),

            weight_init_mean=self.config.weight_init_mean,
            weight_init_std=self.config.weight_init_std,
            weight_mutate_power=self.config.weight_mutate_power,
            weight_mutate_rate=self.config.weight_mutate_rate,
            weight_replace_rate=self.config.weight_replace_rate,
        )

    def update(self, state):
        """Per-generation hook; this gene has nothing to update."""
        return state

    def new_node_attrs(self, state):
        """Attributes of a freshly created node: (bias, response, activation, aggregation)."""
        return jnp.array([state.bias_init_mean, state.response_init_mean,
                          state.activation_default, state.aggregation_default])

    def new_conn_attrs(self, state):
        """Attributes of a freshly created connection: (weight,)."""
        return jnp.array([state.weight_init_mean])

    def mutate_node(self, state, key, attrs: Array):
        """Mutate the four node attributes, each with its own random sub-key."""
        k1, k2, k3, k4 = jax.random.split(key, num=4)

        bias = NormalGene._mutate_float(k1, attrs[0], state.bias_init_mean, state.bias_init_std,
                                        state.bias_mutate_power, state.bias_mutate_rate,
                                        state.bias_replace_rate)
        res = NormalGene._mutate_float(k2, attrs[1], state.response_init_mean, state.response_init_std,
                                       state.response_mutate_power, state.response_mutate_rate,
                                       state.response_replace_rate)
        act = NormalGene._mutate_int(k3, attrs[2], state.activation_options,
                                     state.activation_replace_rate)
        agg = NormalGene._mutate_int(k4, attrs[3], state.aggregation_options,
                                     state.aggregation_replace_rate)

        return jnp.array([bias, res, act, agg])

    def mutate_conn(self, state, key, attrs: Array):
        """Mutate the single connection attribute (the weight)."""
        weight = NormalGene._mutate_float(key, attrs[0], state.weight_init_mean, state.weight_init_std,
                                          state.weight_mutate_power, state.weight_mutate_rate,
                                          state.weight_replace_rate)
        return jnp.array([weight])

    def distance_node(self, state, node1: Array, node2: Array):
        """Distance of two full node rows: |bias| + |response| + activation/aggregation mismatches."""
        return (
            jnp.abs(node1[1] - node2[1])
            + jnp.abs(node1[2] - node2[2])
            + (node1[3] != node2[3])
            + (node1[4] != node2[4])
        )

    def distance_conn(self, state, con1: Array, con2: Array):
        """Distance of two full connection rows: enabled mismatch + |weight|."""
        return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3])

    def forward_transform(self, state: State, genome: Genome):
        """Turn a genome into (calculation order, nodes, dense connection attrs)."""
        u_conns = unflatten_conns(genome.nodes, genome.conns)

        # A connection takes part in the network only when its enabled flag is
        # exactly 1.  Testing for "not NaN" alone would also keep connections
        # that exist but are disabled (flag 0).
        conn_enable = u_conns[0] == 1

        # drop the enabled row, blank out everything that is not enabled
        u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
        seqs = topological_sort(genome.nodes, conn_enable)

        return seqs, genome.nodes, u_conns

    def forward(self, state: State, inputs, transformed):
        """Evaluate the network on ``inputs`` and return the output node values."""
        cal_seqs, nodes, cons = transformed

        input_idx = state.input_idx
        output_idx = state.output_idx

        N = nodes.shape[0]
        ini_vals = jnp.full((N,), jnp.nan)
        ini_vals = ini_vals.at[input_idx].set(inputs)

        weights = cons[0, :]

        def cond_fun(carry):
            values, idx = carry
            # stop at the end of the sequence or at the first padding entry
            return (idx < N) & (cal_seqs[idx] != I_INT)

        def body_func(carry):
            values, idx = carry
            i = cal_seqs[idx]

            def hit():
                ins = values * weights[:, i]
                z = agg(nodes[i, 4], ins, self.config.aggregation_options)  # z = agg(ins)
                z = z * nodes[i, 2] + nodes[i, 1]  # z = z * response + bias
                z = act(nodes[i, 3], z, self.config.activation_options)  # z = act(z)
                return values.at[i].set(z)

            def miss():
                return values

            # the value of input nodes is given by the task, not calculated
            values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)

            return values, idx + 1

        vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))

        return vals[output_idx]

    @staticmethod
    def _mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate):
        """Perturb, replace or keep a float attribute.

        One uniform draw ``r`` picks the outcome: ``r < mutate_rate`` adds
        Gaussian noise scaled by ``mutate_power``; ``r`` strictly between
        ``mutate_rate`` and ``mutate_rate + replace_rate`` redraws the value
        from ``N(init_mean, init_std)``; otherwise the value is kept.
        """
        k1, k2, k3 = jax.random.split(key, num=3)
        noise = jax.random.normal(k1, ()) * mutate_power
        replace = jax.random.normal(k2, ()) * init_std + init_mean
        r = jax.random.uniform(k3, ())

        val = jnp.where(
            r < mutate_rate,
            val + noise,
            jnp.where(
                (mutate_rate < r) & (r < mutate_rate + replace_rate),
                replace,
                val
            )
        )

        return val

    @staticmethod
    def _mutate_int(key, val, options, replace_rate):
        """With probability ``replace_rate`` replace ``val`` by a random element of ``options``."""
        k1, k2 = jax.random.split(key, num=2)
        r = jax.random.uniform(k1, ())

        val = jnp.where(
            r < replace_rate,
            jax.random.choice(k2, options),
            val
        )

        return val

View File

@@ -1,57 +0,0 @@
from dataclasses import dataclass
import jax
from jax import numpy as jnp, vmap
from .normal import NormalGene, NormalGeneConfig
from core import State, Genome
from utils import unflatten_conns, act, agg
@dataclass(frozen=True)
class RecurrentGeneConfig(NormalGeneConfig):
    """NormalGeneConfig extended with the iteration count of the recurrent network."""

    # number of update sweeps the recurrent forward pass performs
    activate_times: int = 10

    def __post_init__(self):
        """Run the inherited checks, then require a positive iteration count."""
        super().__post_init__()
        assert 0 < self.activate_times
class RecurrentGene(NormalGene):
    """NormalGene evaluated as a recurrent network for a fixed number of sweeps."""

    def __init__(self, config: RecurrentGeneConfig = RecurrentGeneConfig()):
        # NormalGene.__init__ stores the config on self.config
        super().__init__(config)

    def forward_transform(self, state: State, genome: Genome):
        """Turn a genome into (nodes, dense connection attrs); no ordering is needed."""
        u_conns = unflatten_conns(genome.nodes, genome.conns)

        # Keep only connections whose enabled flag is exactly 1.  A "not NaN"
        # test would also keep connections that exist but are disabled (flag 0).
        conn_enable = u_conns[0] == 1

        # drop the enabled row, blank out everything that is not enabled
        u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)

        return genome.nodes, u_conns

    def forward(self, state: State, inputs, transformed):
        """Run ``activate_times`` synchronous updates and return the output values."""
        nodes, conns = transformed

        batch_act = vmap(act, in_axes=(0, 0, None))
        batch_agg = vmap(agg, in_axes=(0, 0, None))

        input_idx = state.input_idx
        output_idx = state.output_idx

        N = nodes.shape[0]
        vals = jnp.full((N,), 0.)

        weights = conns[0, :]

        def body_func(i, values):
            # inputs are re-imposed before every sweep
            values = values.at[input_idx].set(inputs)
            nodes_ins = values * weights.T
            values = batch_agg(nodes[:, 4], nodes_ins, self.config.aggregation_options)  # z = agg(ins)
            values = values * nodes[:, 2] + nodes[:, 1]  # z = z * response + bias
            values = batch_act(nodes[:, 3], values, self.config.activation_options)  # z = act(z)
            return values

        vals = jax.lax.fori_loop(0, self.config.activate_times, body_func, vals)

        return vals[output_idx]

View File

@@ -0,0 +1,3 @@
from .base import BaseGenome
from .default import DefaultGenome
from .recurrent import RecurrentGenome

View File

@@ -0,0 +1,66 @@
import jax.numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
from utils import fetch_first
class BaseGenome:
    """Common storage helpers for genomes held as NaN-padded node/connection arrays.

    Subclasses supply ``transform`` and ``forward``; this base class only
    knows how to add and delete rows.
    """

    network_type = None

    def __init__(
            self,
            num_inputs: int,
            num_outputs: int,
            max_nodes: int,
            max_conns: int,
            node_gene: BaseNodeGene = DefaultNodeGene(),
            conn_gene: BaseConnGene = DefaultConnGene(),
    ):
        self.num_inputs, self.num_outputs = num_inputs, num_outputs
        self.max_nodes, self.max_conns = max_nodes, max_conns
        self.node_gene, self.conn_gene = node_gene, conn_gene

        # inputs use keys [0, num_inputs); outputs use the next num_outputs keys
        self.input_idx = jnp.arange(num_inputs)
        self.output_idx = jnp.arange(num_inputs, num_inputs + num_outputs)

    def transform(self, nodes, conns):
        """Convert (nodes, conns) into the form consumed by ``forward``."""
        raise NotImplementedError

    def forward(self, inputs, transformed):
        """Evaluate the transformed network on ``inputs``."""
        raise NotImplementedError

    def add_node(self, nodes, new_key: int, attrs):
        """
        Insert a node with key ``new_key`` and attributes ``attrs``.
        The row used is the first one whose key column is NaN.
        """
        free_row = fetch_first(jnp.isnan(nodes[:, 0]))
        with_key = nodes.at[free_row, 0].set(new_key)
        return with_key.at[free_row, 1:].set(attrs)

    def delete_node_by_pos(self, nodes, pos):
        """
        Remove the node stored at row ``pos`` by overwriting the row with NaN.
        """
        return nodes.at[pos].set(jnp.nan)

    def add_conn(self, conns, i_key, o_key, enable: bool, attrs):
        """
        Insert a connection ``i_key -> o_key`` with the given enabled flag and attributes.
        The row used is the first one whose first column is NaN.
        """
        free_row = fetch_first(jnp.isnan(conns[:, 0]))
        header = jnp.array([i_key, o_key, enable])
        with_header = conns.at[free_row, 0:3].set(header)
        return with_header.at[free_row, 3:].set(attrs)

    def delete_conn_by_pos(self, conns, pos):
        """
        Remove the connection stored at row ``pos`` by overwriting the row with NaN.
        """
        return conns.at[pos].set(jnp.nan)

View File

@@ -0,0 +1,75 @@
import jax, jax.numpy as jnp
from utils import unflatten_conns, topological_sort, I_INT
from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
class DefaultGenome(BaseGenome):
    """Default genome class, with the same behavior as the NEAT-Python"""

    network_type = 'feedforward'

    def __init__(self,
                 num_inputs: int,
                 num_outputs: int,
                 node_gene: BaseNodeGene = DefaultNodeGene(),
                 conn_gene: BaseConnGene = DefaultConnGene(),
                 max_nodes: int = 50,
                 max_conns: int = 100,
                 ):
        """
        ``max_nodes`` / ``max_conns`` are the fixed row capacities of the node
        and connection arrays.  They are appended after the existing
        parameters so current positional callers keep working.
        NOTE(review): the defaults 50 / 100 are placeholders, confirm
        project-appropriate values.
        """
        # BaseGenome expects (num_inputs, num_outputs, max_nodes, max_conns,
        # node_gene, conn_gene).  Passing the genes as the 3rd and 4th
        # positional arguments would store them as the capacities.
        super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)

    def transform(self, nodes, conns):
        """Return (calculation order, nodes, dense connection attrs) for ``forward``."""
        u_conns = unflatten_conns(nodes, conns)

        # a connection is active when and only when its enabled flag is 1
        # (missing connections are NaN, disabled ones are not 1)
        conn_enable = u_conns[0] == 1

        # remove enable attr
        u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
        seqs = topological_sort(nodes, conn_enable)

        return seqs, nodes, u_conns

    def forward(self, inputs, transformed):
        """Evaluate the feed-forward network and return the output node values."""
        cal_seqs, nodes, conns = transformed

        N = nodes.shape[0]
        ini_vals = jnp.full((N,), jnp.nan)
        ini_vals = ini_vals.at[self.input_idx].set(inputs)
        nodes_attrs = nodes[:, 1:]

        def cond_fun(carry):
            values, idx = carry
            # stop at the end of the sequence or at the first padding entry
            return (idx < N) & (cal_seqs[idx] != I_INT)

        def body_func(carry):
            values, idx = carry
            i = cal_seqs[idx]

            def hit():
                # apply the connection gene to every incoming edge of node i
                ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values)
                z = self.node_gene.forward(nodes_attrs[i], ins)
                return values.at[i].set(z)

            def miss():
                return values

            # the val of input nodes is obtained by the task, not by calculation
            values = jax.lax.cond(jnp.isin(i, self.input_idx), miss, hit)

            return values, idx + 1

        vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))

        return vals[self.output_idx]

View File

@@ -0,0 +1,58 @@
import jax, jax.numpy as jnp
from utils import unflatten_conns
from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
class RecurrentGenome(BaseGenome):
    """Recurrent genome: every node is updated synchronously ``activate_time`` times."""

    network_type = 'recurrent'

    def __init__(self,
                 num_inputs: int,
                 num_outputs: int,
                 node_gene: BaseNodeGene = DefaultNodeGene(),
                 conn_gene: BaseConnGene = DefaultConnGene(),
                 activate_time: int = 10,
                 max_nodes: int = 50,
                 max_conns: int = 100,
                 ):
        """
        ``max_nodes`` / ``max_conns`` are the fixed row capacities of the node
        and connection arrays.  They are appended after the existing
        parameters so current positional callers keep working.
        NOTE(review): the defaults 50 / 100 are placeholders, confirm
        project-appropriate values.
        """
        # BaseGenome expects (num_inputs, num_outputs, max_nodes, max_conns,
        # node_gene, conn_gene).  Passing the genes as the 3rd and 4th
        # positional arguments would store them as the capacities.
        super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
        self.activate_time = activate_time

    def transform(self, nodes, conns):
        """Return (nodes, dense connection attrs) with non-enabled connections set to NaN."""
        u_conns = unflatten_conns(nodes, conns)

        # remove un-enable connections and remove enable attr
        conn_enable = u_conns[0] == 1
        u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)

        return nodes, u_conns

    def forward(self, inputs, transformed):
        """Run ``activate_time`` update sweeps and return the output node values."""
        nodes, conns = transformed

        N = nodes.shape[0]
        vals = jnp.full((N,), jnp.nan)
        nodes_attrs = nodes[:, 1:]

        def body_func(_, values):
            # set input values (re-imposed before every sweep)
            values = values.at[self.input_idx].set(inputs)

            # calculate connections: the outer vmap pairs axis 1 of conns with
            # the matching entry of values, the inner vmap runs over the
            # remaining node axis
            node_ins = jax.vmap(
                jax.vmap(
                    self.conn_gene.forward,
                    in_axes=(1, None)
                ),
                in_axes=(1, 0)
            )(conns, values)

            # calculate nodes: after the transpose each row holds one node's incoming values
            values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T)
            return values

        vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)

        return vals[self.output_idx]

View File

@@ -1,87 +1,38 @@
from typing import Type
import jax, jax.numpy as jnp
from utils import State
from .. import BaseAlgorithm
from .genome import *
from .species import *
from .ga import *
import jax
from jax import numpy as jnp
import numpy as np
class NEAT(BaseAlgorithm):
from config import Config
from core import Algorithm, State, Gene, Genome
from .ga import create_next_generation
from .species import SpeciesInfo, update_species, speciate
def __init__(
self,
genome: BaseGenome,
species: BaseSpecies,
mutation: BaseMutation = DefaultMutation(),
crossover: BaseCrossover = DefaultCrossover(),
):
self.genome = genome
self.species = species
self.mutation = mutation
self.crossover = crossover
class NEAT(Algorithm):
def __init__(self, config: Config, gene_type: Type[Gene]):
self.config = config
self.gene = gene_type(config.gene)
self.forward_func = None
self.tell_func = None
def setup(self, randkey, state: State = State()):
"""initialize the state of the algorithm"""
input_idx = np.arange(self.config.neat.inputs)
output_idx = np.arange(self.config.neat.inputs,
self.config.neat.inputs + self.config.neat.outputs)
state = state.update(
P=self.config.basic.pop_size,
N=self.config.neat.max_nodes,
C=self.config.neat.max_conns,
S=self.config.neat.max_species,
NL=1 + len(self.gene.node_attrs), # node length = (key) + attributes
CL=3 + len(self.gene.conn_attrs), # conn length = (in, out, key) + attributes
max_stagnation=self.config.neat.max_stagnation,
species_elitism=self.config.neat.species_elitism,
spawn_number_change_rate=self.config.neat.spawn_number_change_rate,
genome_elitism=self.config.neat.genome_elitism,
survival_threshold=self.config.neat.survival_threshold,
compatibility_threshold=self.config.neat.compatibility_threshold,
compatibility_disjoint=self.config.neat.compatibility_disjoint,
compatibility_weight=self.config.neat.compatibility_weight,
input_idx=input_idx,
output_idx=output_idx,
def setup(self, randkey):
k1, k2 = jax.random.split(randkey, 2)
return State(
randkey=k1,
generation=0,
next_node_key=max(*self.genome.input_idx, *self.genome.output_idx) + 2,
# inputs nodes, output nodes, 1 hidden node
species=self.species.setup(k2),
)
state = self.gene.setup(state)
pop_genomes = self._initialize_genomes(state)
species_info = SpeciesInfo.initialize(state)
idx2species = jnp.zeros(state.P, dtype=jnp.float32)
center_nodes = jnp.full((state.S, state.N, state.NL), jnp.nan, dtype=jnp.float32)
center_conns = jnp.full((state.S, state.C, state.CL), jnp.nan, dtype=jnp.float32)
center_genomes = Genome(center_nodes, center_conns)
center_genomes = center_genomes.set(0, pop_genomes[0])
generation = 0
next_node_key = max(*state.input_idx, *state.output_idx) + 2
next_species_key = 1
state = state.update(
randkey=randkey,
pop_genomes=pop_genomes,
species_info=species_info,
idx2species=idx2species,
center_genomes=center_genomes,
# avoid jax auto cast from int to float. that would cause re-compilation.
generation=jnp.asarray(generation, dtype=jnp.int32),
next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32),
next_species_key=jnp.asarray(next_species_key, dtype=jnp.float32),
)
return jax.device_put(state)
def ask_algorithm(self, state: State):
return state.pop_genomes
def tell_algorithm(self, state: State, fitness):
state = self.gene.update(state)
def ask(self, state: State):
return self.species.ask(state)
def tell(self, state: State, fitness):
k1, k2, randkey = jax.random.split(state.randkey, 3)
state = state.update(
@@ -89,46 +40,55 @@ class NEAT(Algorithm):
randkey=randkey
)
state, winner, loser, elite_mask = update_species(state, k1, fitness)
state, winner, loser, elite_mask = self.species.update_species(state, fitness, state.generation)
state = create_next_generation(self.config.neat, self.gene, state, k2, winner, loser, elite_mask)
state = self.create_next_generation(k2, state, winner, loser, elite_mask)
state = speciate(self.gene, state)
state = self.species.speciate(state, state.generation)
return state
def forward_transform(self, state: State, genome: Genome):
return self.gene.forward_transform(state, genome)
def transform(self, state: State):
"""transform the genome into a neural network"""
raise NotImplementedError
def forward(self, state: State, inputs, genome: Genome):
return self.gene.forward(state, inputs, genome)
def forward(self, inputs, transformed):
raise NotImplementedError
def _initialize_genomes(self, state):
o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes
o_conns = np.full((state.C, state.CL), np.nan, dtype=np.float32) # original connections
def create_next_generation(self, randkey, state, winner, loser, elite_mask):
# prepare random keys
pop_size = self.species.pop_size
new_node_keys = jnp.arange(pop_size) + state.species.next_node_key
input_idx = state.input_idx
output_idx = state.output_idx
new_node_key = max([*input_idx, *output_idx]) + 1
k1, k2 = jax.random.split(randkey, 2)
crossover_rand_keys = jax.random.split(k1, pop_size)
mutate_rand_keys = jax.random.split(k2, pop_size)
o_nodes[input_idx, 0] = input_idx
o_nodes[output_idx, 0] = output_idx
o_nodes[new_node_key, 0] = new_node_key
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = self.gene.new_node_attrs(state)
o_nodes[new_node_key, 1:] = self.gene.new_node_attrs(state)
wpn, wpc = state.species.pop_nodes[winner], state.species.pop_conns[winner]
lpn, lpc = state.species.pop_nodes[loser], state.species.pop_conns[loser]
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)]
o_conns[input_idx, 0:2] = input_conns # in key, out key
o_conns[input_idx, 2] = True # enabled
o_conns[input_idx, 3:] = self.gene.new_conn_attrs(state)
# batch crossover
n_nodes, n_conns = (jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0))
(crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc))
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx]
o_conns[output_idx, 0:2] = output_conns # in key, out key
o_conns[output_idx, 2] = True # enabled
o_conns[output_idx, 3:] = self.gene.new_conn_attrs(state)
# batch mutation
m_n_nodes, m_n_conns = (jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0))
(mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys))
# repeat origin genome for P times to create population
pop_nodes = np.tile(o_nodes, (state.P, 1, 1))
pop_conns = np.tile(o_conns, (state.P, 1, 1))
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
# update next node key
all_nodes_keys = pop_nodes[:, :, 0]
max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
next_node_key = max_node_key + 1
return state.update(
species=state.species.update(
pop_nodes=pop_nodes,
pop_conns=pop_conns,
),
next_node_key=next_node_key,
)
return Genome(pop_nodes, pop_conns)

View File

@@ -1,2 +1,2 @@
from .species_info import SpeciesInfo
from .operations import update_species, speciate
from .base import BaseSpecies
from .default import DefaultSpecies

View File

@@ -0,0 +1,14 @@
from utils import State
class BaseSpecies:
    """Interface of a species controller; every method must be overridden."""

    def setup(self, randkey):
        """Build the initial species state from a random key."""
        raise NotImplementedError

    def ask(self, state: State):
        """Return the population that should be evaluated."""
        raise NotImplementedError

    def update_species(self, state, fitness, generation):
        """Update the species bookkeeping from the latest fitness values."""
        raise NotImplementedError

    def speciate(self, state, generation):
        """Assign every individual of the population to a species."""
        raise NotImplementedError

View File

@@ -0,0 +1,514 @@
import numpy as np
import jax, jax.numpy as jnp
from utils import State, rank_elements, argmin_with_mask, fetch_first
from ..genome import BaseGenome
class DefaultSpecies:
def __init__(self,
genome: BaseGenome,
pop_size,
species_size,
compatibility_disjoint: float = 1.0,
compatibility_weight: float = 0.4,
max_stagnation: int = 15,
species_elitism: int = 2,
spawn_number_change_rate: float = 0.5,
genome_elitism: int = 2,
survival_threshold: float = 0.2,
min_species_size: int = 1,
compatibility_threshold: float = 3.5
):
self.genome = genome
self.pop_size = pop_size
self.species_size = species_size
self.compatibility_disjoint = compatibility_disjoint
self.compatibility_weight = compatibility_weight
self.max_stagnation = max_stagnation
self.species_elitism = species_elitism
self.spawn_number_change_rate = spawn_number_change_rate
self.genome_elitism = genome_elitism
self.survival_threshold = survival_threshold
self.min_species_size = min_species_size
self.compatibility_threshold = compatibility_threshold
self.species_arange = jnp.arange(self.species_size)
def setup(self, randkey):
    """Create the initial species state.

    The whole population starts in species 0, whose center is the first
    genome of the initial population.  All other species slots are empty,
    which is encoded as NaN.
    """
    pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome)

    species_keys = jnp.full((self.species_size,), jnp.nan)  # the unique index (primary key) for each species
    best_fitness = jnp.full((self.species_size,), jnp.nan)  # the best fitness of each species
    last_improved = jnp.full((self.species_size,), jnp.nan)  # the last generation that the species improved
    member_count = jnp.full((self.species_size,), jnp.nan)  # the number of members of each species
    idx2species = jnp.zeros(self.pop_size)  # the species index of each individual

    # nodes for each center genome of each species
    center_nodes = jnp.full((self.species_size, self.genome.max_nodes, self.genome.node_gene.length), jnp.nan)

    # connections for each center genome of each species
    center_conns = jnp.full((self.species_size, self.genome.max_conns, self.genome.conn_gene.length), jnp.nan)

    # fill slot 0 with the initial species
    species_keys = species_keys.at[0].set(0)
    best_fitness = best_fitness.at[0].set(-jnp.inf)
    last_improved = last_improved.at[0].set(0)
    member_count = member_count.at[0].set(self.pop_size)
    center_nodes = center_nodes.at[0].set(pop_nodes[0])
    center_conns = center_conns.at[0].set(pop_conns[0])

    return State(
        randkey=randkey,
        # the population itself must live in the state: ask() and speciate()
        # read state.pop_nodes / state.pop_conns
        pop_nodes=pop_nodes,
        pop_conns=pop_conns,
        species_keys=species_keys,
        best_fitness=best_fitness,
        last_improved=last_improved,
        member_count=member_count,
        idx2species=idx2species,
        center_nodes=center_nodes,
        center_conns=center_conns,
        next_species_key=1,  # 0 is reserved for the first species
    )
def ask(self, state):
    """Return the current population as a ``(pop_nodes, pop_conns)`` pair."""
    population = (state.pop_nodes, state.pop_conns)
    return population
def update_species(self, state, fitness, generation):
    """Update species bookkeeping and pick the parents of the next generation.

    Returns ``(state, winner, loser, elite_mask)`` where ``state`` has its
    species arrays sorted by species fitness (best first) and carries a
    fresh ``randkey``.
    """
    # update the fitness of each species
    species_fitness = self.update_species_fitness(state, fitness)

    # stagnation species
    state, species_fitness = self.stagnation(state, generation, species_fitness)

    # sort the per-species arrays by species fitness, best first
    sort_indices = jnp.argsort(species_fitness)[::-1]

    state = state.update(
        species_keys=state.species_keys[sort_indices],
        best_fitness=state.best_fitness[sort_indices],
        last_improved=state.last_improved[sort_indices],
        member_count=state.member_count[sort_indices],
        center_nodes=state.center_nodes[sort_indices],
        center_conns=state.center_conns[sort_indices],
    )

    # decide the number of members of each species by their fitness
    spawn_number = self.cal_spawn_numbers(state)

    k1, k2 = jax.random.split(state.randkey)
    # crossover info
    winner, loser, elite_mask = self.create_crossover_pair(state, k1, spawn_number, fitness)

    # state objects are modified through .update(); they are not called
    return state.update(randkey=k2), winner, loser, elite_mask
def update_species_fitness(self, state, fitness):
    """
    Fitness of each species slot = the maximum fitness among its members.
    A slot without members gets -inf.
    """
    slot_keys = state.species_keys[self.species_arange]

    # membership[s, p] is True when individual p belongs to species slot s
    membership = state.idx2species[None, :] == slot_keys[:, None]

    masked_fitness = jnp.where(membership, fitness, -jnp.inf)
    return jnp.max(masked_fitness, axis=1)
def stagnation(self, state, generation, species_fitness):
    """
    Remove stagnant species.

    A species is stagnant when its fitness is not better than its recorded
    best and more than ``max_stagnation`` generations have passed since it
    last improved.  Species whose rank is below ``species_elitism`` are never
    removed.  Removed slots are set to NaN and their fitness to -inf.

    generation: the current generation
    Returns the updated state and the updated species fitness.
    """
    improved = species_fitness > state.best_fitness
    too_long = generation - state.last_improved > self.max_stagnation
    spe_st = (species_fitness <= state.best_fitness) & too_long

    # Update the records of species that improved.  jnp.where promotes
    # dtypes, so an integer `generation` can be mixed with the float
    # `last_improved` array.
    best_fitness = jnp.where(improved, species_fitness, state.best_fitness)
    last_improved = jnp.where(improved, generation, state.last_improved)

    # elite species will not be stagnation
    # NOTE(review): assumes rank_elements gives rank 0 to the fittest species, confirm
    species_rank = rank_elements(species_fitness)
    spe_st = jnp.where(species_rank < self.species_elitism, False, spe_st)

    # Blank out stagnant slots.  The surviving values are read from `state`
    # (or from the arrays computed above), never from names that are only
    # assigned further down.
    node_mask = spe_st[:, None, None]
    species_keys = jnp.where(spe_st, jnp.nan, state.species_keys)
    best_fitness = jnp.where(spe_st, jnp.nan, best_fitness)
    last_improved = jnp.where(spe_st, jnp.nan, last_improved)
    member_count = jnp.where(spe_st, jnp.nan, state.member_count)
    species_fitness = jnp.where(spe_st, -jnp.inf, species_fitness)
    center_nodes = jnp.where(node_mask, jnp.nan, state.center_nodes)
    center_conns = jnp.where(node_mask, jnp.nan, state.center_conns)

    return state.update(
        species_keys=species_keys,
        best_fitness=best_fitness,
        last_improved=last_improved,
        member_count=member_count,
        center_nodes=center_nodes,
        center_conns=center_conns,
    ), species_fitness
def cal_spawn_numbers(self, state):
    """
    Decide the number of members of each species by their fitness rank.
    The species arrays are expected to be sorted best-first already.
    Linear ranking selection
        e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
    """
    is_species_valid = ~jnp.isnan(state.species_keys)
    valid_species_num = jnp.sum(is_species_valid)
    denominator = (valid_species_num + 1) * valid_species_num / 2  # obtain 3 + 2 + 1 = 6

    rank_score = valid_species_num - self.species_arange  # obtain [3, 2, 1]
    spawn_number_rate = rank_score / denominator  # obtain [0.5, 0.33, 0.17]
    spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0)  # set invalid species to 0

    target_spawn_number = jnp.floor(spawn_number_rate * self.pop_size)  # calculate member

    # Avoid too much variation of numbers for a species
    previous_size = state.member_count
    spawn_number = previous_size + (target_spawn_number - previous_size) * self.spawn_number_change_rate

    # empty slots have a NaN member count; give them 0 before the integer cast
    spawn_number = jnp.where(is_species_valid, spawn_number, 0)
    spawn_number = spawn_number.astype(jnp.int32)

    # must control the sum of spawn_number to be equal to pop_size
    error = self.pop_size - jnp.sum(spawn_number)

    # add error to the first species to control the sum of spawn_number
    spawn_number = spawn_number.at[0].add(error)

    return spawn_number
def create_crossover_pair(self, state, randkey, spawn_number, fitness):
    """Choose two parents for every individual of the next generation.

    For each species, candidate parents are drawn from its best
    ``survival_threshold`` share of members.  The first ``genome_elitism``
    candidates of a species are its best members themselves and are flagged
    as elite.  Returns ``(winner, loser, elite_mask)`` where ``winner`` is the
    parent with the higher or equal fitness of each pair.
    """
    species_slots = self.species_arange
    pop_positions = jnp.arange(self.pop_size)
    is_elite_position = pop_positions < self.genome_elitism

    def parents_of_species(key, slot):
        in_species = state.idx2species == state.species_keys[slot]
        member_total = jnp.sum(in_species)
        # members first, best first; non-members sort to the end
        ranked = jnp.argsort(jnp.where(in_species, fitness, -jnp.inf))[::-1]

        survive_size = jnp.floor(self.survival_threshold * member_total).astype(jnp.int32)
        # uniform probability over the surviving head of the ranking
        select_pro = (pop_positions < survive_size) / survive_size

        fa, ma = jax.random.choice(
            key, ranked, shape=(2, self.pop_size), replace=True, p=select_pro
        )

        # elite positions keep the ranked member itself for both parents
        fa = jnp.where(is_elite_position, ranked, fa)
        ma = jnp.where(is_elite_position, ranked, ma)
        elite = jnp.where(is_elite_position, True, False)
        return fa, ma, elite

    species_keys = jax.random.split(randkey, self.species_size)
    fas, mas, elites = jax.vmap(parents_of_species)(species_keys, species_slots)

    spawn_number_cum = jnp.cumsum(spawn_number)

    def pick_for_position(idx):
        # first species whose cumulative spawn count exceeds idx
        loc = jnp.argmax(idx < spawn_number_cum)

        # elite genomes are at the beginning of the species
        idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
        return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]

    part1, part2, elite_mask = jax.vmap(pick_for_position)(pop_positions)

    part1_wins = fitness[part1] >= fitness[part2]
    winner = jnp.where(part1_wins, part1, part2)
    loser = jnp.where(part1_wins, part2, part1)

    return winner, loser, elite_mask
def speciate(self, state, generation):
# prepare distance functions
o2p_distance_func = jax.vmap(self.distance, in_axes=(None, None, 0, 0)) # one to population
# idx to specie key
idx2species = jnp.full((self.pop_size,), jnp.nan) # NaN means not assigned to any species
# the distance between genomes to its center genomes
o2c_distances = jnp.full((self.pop_size,), jnp.inf)
# step 1: find new centers
def cond_func(carry):
# i, idx2species, center_nodes, center_conns, o2c_distances
i, i2s, cns, ccs, o2c = carry
return (
(i < self.species_size) &
(~jnp.isnan(state.species_keys[i]))
) # current species is existing
def body_func(carry):
i, i2s, cns, ccs, o2c = carry
distances = o2p_distance_func(cns, ccs, state.pop_nodes, state.pop_conns)
# find the closest one
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
cns = cns.set(i, state.pop_nodes[closest_idx])
ccs = ccs.set(i, state.pop_conns[closest_idx])
# the genome with closest_idx will become the new center, thus its distance to center is 0.
o2c = o2c.at[closest_idx].set(0)
return i + 1, i2s, cns, ccs, o2c
_, idx2species, center_nodes, center_conns, o2c_distances = \
jax.lax.while_loop(cond_func, body_func,
(0, idx2species, state.center_nodes, state.center_conns, o2c_distances))
state = state.update(
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
)
# part 2: assign members to each species
def cond_func(carry):
# i, idx2species, center_nodes, center_conns, species_keys, o2c_distances, next_species_key
i, i2s, cns, ccs, sk, o2c, nsk = carry
current_species_existed = ~jnp.isnan(sk[i])
not_all_assigned = jnp.any(jnp.isnan(i2s))
not_reach_species_upper_bounds = i < self.species_size
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
def body_func(carry):
i, i2s, cns, ccs, sk, o2c, nsk = carry
_, i2s, cns, ccs, sk, o2c, nsk = jax.lax.cond(
jnp.isnan(sk[i]), # whether the current species is existing or not
create_new_species, # if not existing, create a new specie
update_exist_specie, # if existing, update the specie
(i, i2s, cns, ccs, sk, o2c, nsk)
)
return i + 1, i2s, cns, ccs, sk, o2c, nsk
def create_new_species(carry):
i, i2s, cns, ccs, sk, o2c, nsk = carry
# pick the first one who has not been assigned to any species
idx = fetch_first(jnp.isnan(i2s))
# assign it to the new species
# [key, best score, last update generation, member_count]
sk = sk.at[i].set(nsk) # nsk -> next species key
i2s = i2s.at[idx].set(nsk)
o2c = o2c.at[idx].set(0)
# update center genomes
cns = cns.set(i, state.pop_nodes[idx])
ccs = ccs.set(i, state.pop_conns[idx])
# find the members for the new species
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
return i, i2s, cns, ccs, sk, o2c, nsk + 1 # change to next new speciate key
def update_exist_specie(carry):
i, i2s, cns, ccs, sk, o2c, nsk = carry
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
# turn to next species
return i + 1, i2s, cns, ccs, sk, o2c, nsk
def speciate_by_threshold(i, i2s, cns, ccs, sk, o2c):
# distance between such center genome and ppo genomes
o2p_distance = o2p_distance_func(cns[i], ccs[i], state.pop_nodes, state.pop_conns)
close_enough_mask = o2p_distance < self.compatibility_threshold
# when a genome is not assigned or the distance between its current center is bigger than this center
catchable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
mask = close_enough_mask & catchable_mask
# update species info
i2s = jnp.where(mask, sk[i], i2s)
# update distance between centers
o2c = jnp.where(mask, o2p_distance, o2c)
return i2s, o2c
# update idx2species
_, idx2species, center_nodes, center_conns, species_keys, _, next_species_key = jax.lax.while_loop(
cond_func,
body_func,
(0, state.idx2species, state.center_nodes, center_conns, state.species_info.species_keys, o2c_distances,
state.next_species_key)
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition can only happen when the number of species is reached species upper bounds
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
# complete info of species which is created in this generation
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.best_fitness)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.best_fitness)
last_improved = jnp.where(new_created_mask, generation, state.last_improved)
# update members count
def count_members(idx):
    """
    Count the genomes assigned to the species stored in slot ``idx``.

    Returns NaN when the slot holds no species (its key is NaN), otherwise the
    number of entries of ``idx2species`` equal to that key, as a float32.
    """
    key = species_keys[idx]
    # jax.lax.cond is called without operands here, so both branches must be
    # zero-argument callables; a one-argument lambda raises a TypeError as soon
    # as the branches are traced.
    return jax.lax.cond(
        jnp.isnan(key),  # the slot is empty
        lambda: jnp.float32(jnp.nan),  # keep the slot marked as empty
        lambda: jnp.sum(idx2species == key, dtype=jnp.float32),  # count members
    )
member_count = jax.vmap(count_members)(self.species_arange)
return state.update(
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
next_species_key=next_species_key
)
def distance(self, nodes1, conns1, nodes2, conns2):
    """
    Total distance between two genomes.

    The result is the node-part distance of (nodes1, nodes2) plus the
    connection-part distance of (conns1, conns2).
    """
    node_part = self.node_distance(nodes1, nodes2)
    conn_part = self.conn_distance(conns1, conns2)
    return node_part + conn_part
def node_distance(self, nodes1, nodes2):
    """
    Distance contributed by the node genes of two genomes.

    Rows whose first column (the node key) is NaN are treated as empty slots.
    Nodes sharing a key in both genomes are homologous and contribute their
    gene distance weighted by ``compatibility_weight``; every other node
    contributes ``compatibility_disjoint``. The sum is divided by the larger
    node count, and 0 is returned when both genomes are empty.
    """
    count_a = jnp.sum(~jnp.isnan(nodes1[:, 0]))
    count_b = jnp.sum(~jnp.isnan(nodes2[:, 0]))
    largest = jnp.maximum(count_a, count_b)

    # Stack both genomes and sort by key so homologous nodes become adjacent
    # rows (the same idea as np.intersect1d).
    merged = jnp.concatenate((nodes1, nodes2), axis=0)
    merged = merged[jnp.argsort(merged[:, 0], axis=0)]

    # A trailing NaN row lets every row be paired with its successor.
    sentinel = jnp.full((1, merged.shape[1]), jnp.nan)
    padded = jnp.concatenate([merged, sentinel], axis=0)
    upper, lower = padded[:-1], padded[1:]

    # A row is homologous when it is non-empty and the next row has the same key.
    homologous = (upper[:, 0] == lower[:, 0]) & ~jnp.isnan(upper[:, 0])
    unmatched_total = count_a + count_b - 2 * jnp.sum(homologous)

    # Gene distance of every adjacent pair; only homologous pairs are kept.
    pair_distance = jax.vmap(self.genome.node_gene.distance, in_axes=(0, 0))(upper, lower)
    pair_distance = jnp.where(jnp.isnan(pair_distance), 0, pair_distance)
    matched_total = jnp.sum(pair_distance * homologous)

    raw = unmatched_total * self.compatibility_disjoint + matched_total * self.compatibility_weight
    return jnp.where(largest == 0, 0, raw / largest)  # avoid zero division
def conn_distance(self, conns1, conns2):
    """
    Distance contributed by the connection genes of two genomes.

    A connection is identified by its first two columns; rows whose first
    column is NaN are empty slots. Connections present in both genomes
    contribute their gene distance weighted by ``compatibility_weight``, all
    others contribute ``compatibility_disjoint``. The sum is divided by the
    larger connection count, and 0 is returned when both genomes are empty.
    """
    count_a = jnp.sum(~jnp.isnan(conns1[:, 0]))
    count_b = jnp.sum(~jnp.isnan(conns2[:, 0]))
    largest = jnp.maximum(count_a, count_b)

    # Sort the stacked connections by (column 0, column 1) so identical
    # connections end up on adjacent rows. lexsort treats the LAST key as
    # the primary one.
    merged = jnp.concatenate((conns1, conns2), axis=0)
    order = jnp.lexsort((merged[:, 1], merged[:, 0]))
    merged = merged[order]

    # A trailing NaN row lets every row be paired with its successor.
    sentinel = jnp.full((1, merged.shape[1]), jnp.nan)
    padded = jnp.concatenate([merged, sentinel], axis=0)
    upper, lower = padded[:-1], padded[1:]

    # Both genomes own the connection when two adjacent rows share both keys.
    homologous = jnp.all(upper[:, :2] == lower[:, :2], axis=1) & ~jnp.isnan(upper[:, 0])
    unmatched_total = count_a + count_b - 2 * jnp.sum(homologous)

    pair_distance = jax.vmap(self.genome.conn_gene.distance, in_axes=(0, 0))(upper, lower)
    pair_distance = jnp.where(jnp.isnan(pair_distance), 0, pair_distance)
    matched_total = jnp.sum(pair_distance * homologous)

    raw = unmatched_total * self.compatibility_disjoint + matched_total * self.compatibility_weight
    return jnp.where(largest == 0, 0, raw / largest)  # avoid zero division
def initialize_population(pop_size, genome):
    """
    Build the initial population as ``pop_size`` copies of one template genome.

    The template holds every input and output node plus a single hidden node
    whose key is one above the largest input/output key. Each input is
    connected to the hidden node and the hidden node to each output; all
    connections are enabled. Unused rows are filled with NaN.

    Returns:
        (pop_nodes, pop_conns) with shapes
        (pop_size, max_nodes, node_gene.length) and
        (pop_size, max_conns, conn_gene.length).
    """
    input_idx, output_idx = genome.input_idx, genome.output_idx
    io_idx = np.concatenate([input_idx, output_idx])
    hidden_key = io_idx.max() + 1

    template_nodes = np.full((genome.max_nodes, genome.node_gene.length), np.nan)
    template_conns = np.full((genome.max_conns, genome.conn_gene.length), np.nan)

    # Column 0 of a node row is its key; the row position equals the key.
    template_nodes[input_idx, 0] = input_idx
    template_nodes[output_idx, 0] = output_idx
    template_nodes[hidden_key, 0] = hidden_key

    # One attribute draw is shared by the input/output nodes, a second one
    # is used for the hidden node.
    template_nodes[io_idx, 1:] = genome.node_gene.new_attrs()
    template_nodes[hidden_key, 1:] = genome.node_gene.new_attrs()

    # input -> hidden connections live on the rows given by input_idx
    template_conns[input_idx, 0] = input_idx  # in key
    template_conns[input_idx, 1] = hidden_key  # out key
    template_conns[input_idx, 2] = True  # enabled
    template_conns[input_idx, 3:] = genome.conn_gene.new_conn_attrs()

    # hidden -> output connections live on the rows given by output_idx
    template_conns[output_idx, 0] = hidden_key  # in key
    template_conns[output_idx, 1] = output_idx  # out key
    template_conns[output_idx, 2] = True  # enabled
    template_conns[output_idx, 3:] = genome.conn_gene.new_conn_attrs()

    # replicate the template genome along a new leading population axis
    pop_nodes = np.repeat(template_nodes[np.newaxis], pop_size, axis=0)
    pop_conns = np.repeat(template_conns[np.newaxis], pop_size, axis=0)
    return pop_nodes, pop_conns

View File

@@ -1,71 +0,0 @@
from jax import Array, numpy as jnp, vmap
from core import Gene
def distance(gene: Gene, state, genome1, genome2):
    """
    Total distance between two genomes: the distance of their node parts
    plus the distance of their connection parts.
    """
    node_term = node_distance(gene, state, genome1.nodes, genome2.nodes)
    conn_term = connection_distance(gene, state, genome1.conns, genome2.conns)
    return node_term + conn_term
def node_distance(gene: Gene, state, nodes1: Array, nodes2: Array):
    """
    Calculate the distance between the nodes of two genomes.

    Rows whose first column (the node key) is NaN are empty slots. Nodes with
    the same key in both genomes are homologous and contribute
    ``gene.distance_node`` weighted by ``state.compatibility_weight``; every
    other node contributes ``state.compatibility_disjoint``. The sum is
    normalised by the larger node count (0 when both genomes are empty).
    """
    # how many real nodes each genome holds
    count_a = jnp.sum(~jnp.isnan(nodes1[:, 0]))
    count_b = jnp.sum(~jnp.isnan(nodes2[:, 0]))
    largest = jnp.maximum(count_a, count_b)

    # Sort the stacked nodes by key so homologous nodes become neighbours
    # (the same idea as np.intersect1d).
    merged = jnp.concatenate((nodes1, nodes2), axis=0)
    merged = merged[jnp.argsort(merged[:, 0], axis=0)]

    # A trailing NaN row lets every row be paired with its successor.
    sentinel = jnp.full((1, merged.shape[1]), jnp.nan)
    padded = jnp.concatenate([merged, sentinel], axis=0)
    upper, lower = padded[:-1], padded[1:]

    # non-empty rows whose successor carries the same key
    homologous = (upper[:, 0] == lower[:, 0]) & ~jnp.isnan(upper[:, 0])
    unmatched_total = count_a + count_b - 2 * jnp.sum(homologous)

    # gene distance of each adjacent pair; only homologous pairs are summed
    pair_distance = vmap(gene.distance_node, in_axes=(None, 0, 0))(state, upper, lower)
    pair_distance = jnp.where(jnp.isnan(pair_distance), 0, pair_distance)
    matched_total = jnp.sum(pair_distance * homologous)

    raw = unmatched_total * state.compatibility_disjoint + matched_total * state.compatibility_weight
    return jnp.where(largest == 0, 0, raw / largest)  # avoid zero division
def connection_distance(gene: Gene, state, cons1: Array, cons2: Array):
    """
    Calculate the distance between the connections of two genomes.

    Works like the node distance, except that a connection is identified by
    its first two columns. Connections present in both genomes contribute
    ``gene.distance_conn`` weighted by ``state.compatibility_weight``; the
    others contribute ``state.compatibility_disjoint``. The sum is normalised
    by the larger connection count (0 when both genomes are empty).
    """
    count_a = jnp.sum(~jnp.isnan(cons1[:, 0]))
    count_b = jnp.sum(~jnp.isnan(cons2[:, 0]))
    largest = jnp.maximum(count_a, count_b)

    # Sort by (column 0, column 1); lexsort treats the LAST key as primary.
    merged = jnp.concatenate((cons1, cons2), axis=0)
    order = jnp.lexsort((merged[:, 1], merged[:, 0]))
    merged = merged[order]

    # A trailing NaN row lets every row be paired with its successor.
    sentinel = jnp.full((1, merged.shape[1]), jnp.nan)
    padded = jnp.concatenate([merged, sentinel], axis=0)
    upper, lower = padded[:-1], padded[1:]

    # both genomes have the connection when adjacent rows share both keys
    homologous = jnp.all(upper[:, :2] == lower[:, :2], axis=1) & ~jnp.isnan(upper[:, 0])
    unmatched_total = count_a + count_b - 2 * jnp.sum(homologous)

    pair_distance = vmap(gene.distance_conn, in_axes=(None, 0, 0))(state, upper, lower)
    pair_distance = jnp.where(jnp.isnan(pair_distance), 0, pair_distance)
    matched_total = jnp.sum(pair_distance * homologous)

    raw = unmatched_total * state.compatibility_disjoint + matched_total * state.compatibility_weight
    return jnp.where(largest == 0, 0, raw / largest)  # avoid zero division

View File

@@ -1,319 +0,0 @@
import jax
from jax import numpy as jnp, vmap
from core import Gene, Genome, State
from utils import rank_elements, fetch_first
from .distance import distance
from .species_info import SpeciesInfo
def update_species(state, randkey, fitness):
    """
    Refresh the species bookkeeping for the current generation and choose
    the parents of the next one.

    Returns:
        (state, winner, loser, elite_mask): the updated state followed by the
        three arrays produced by ``create_crossover_pair``.
    """
    # fitness of each species, derived from its members
    per_species_fitness = update_species_fitness(state, fitness)

    # drop the species that have stagnated
    state, per_species_fitness = stagnation(state, per_species_fitness)

    # Reorder the species from best to worst. Empty and stagnant slots carry
    # a fitness of -inf and therefore end up last.
    order = jnp.argsort(per_species_fitness)[::-1]
    ranked_info = state.species_info[order]
    ranked_centers = state.center_genomes[order]
    state = state.update(
        species_info=ranked_info,
        center_genomes=ranked_centers,
    )

    # how many offspring each species is allowed to produce
    spawn_number = cal_spawn_numbers(state)

    # parents for every slot of the next population
    winner, loser, elite_mask = create_crossover_pair(state, randkey, spawn_number, fitness)
    return state, winner, loser, elite_mask
def update_species_fitness(state, fitness):
    """
    Fitness of every species slot, using the max criterion: the best fitness
    among the individuals assigned to that species.

    A slot without any member (including an empty slot whose key is NaN)
    gets -inf.
    """
    species_slots = jnp.arange(state.species_info.size())

    def best_member_fitness(slot):
        key = state.species_info.species_keys[slot]
        in_species = state.idx2species == key
        # individuals outside the species must never win the max
        return jnp.max(jnp.where(in_species, fitness, -jnp.inf))

    return vmap(best_member_fitness)(species_slots)
def stagnation(state, species_fitness):
    """
    Remove stagnant species.

    A species is stagnant when its current fitness does not exceed its
    recorded best fitness and more than ``state.max_stagnation`` generations
    have passed since it last improved. The ``state.species_elitism``
    top-ranked species are never removed.

    Returns:
        (state, species_fitness): the state with the stagnant species' info
        and center genomes set to NaN, and the species fitness with the
        stagnant entries set to -inf.
    """

    def judge(slot):
        current = species_fitness[slot]
        key, best, improved_at, _ = state.species_info.get(slot)
        # Staleness is judged against the record as it was BEFORE this
        # generation's improvement is written into it.
        stale = (current <= best) & (state.generation - improved_at > state.max_stagnation)
        improved = current > best
        improved_at = jnp.where(improved, state.generation, improved_at)
        best = jnp.where(improved, current, best)
        return stale, key, best, improved_at

    slots = jnp.arange(species_fitness.shape[0])
    stale, species_keys, best_fitness, last_improved = vmap(judge)(slots)

    # elite species never stagnate
    protected = rank_elements(species_fitness) < state.species_elitism
    stale = jnp.where(protected, False, stale)

    def erase(values):
        # NaN marks an empty species slot
        return jnp.where(stale, jnp.nan, values)

    species_info = SpeciesInfo(
        erase(species_keys),
        erase(best_fitness),
        erase(last_improved),
        erase(state.species_info.member_count),
    )
    species_fitness = jnp.where(stale, -jnp.inf, species_fitness)

    # broadcast the per-species flag over the trailing two axes of the centers
    stale_rows = stale[:, None, None]
    center_nodes = jnp.where(stale_rows, jnp.nan, state.center_genomes.nodes)
    center_conns = jnp.where(stale_rows, jnp.nan, state.center_genomes.conns)

    state = state.update(
        species_info=species_info,
        center_genomes=Genome(center_nodes, center_conns),
    )
    return state, species_fitness
def cal_spawn_numbers(state):
    """
    Decide how many offspring each species produces, by linear ranking.

    Species are expected to be ordered from best to worst; a better rank
    receives a larger share of the ``state.P`` population slots.
    e.g. 3 species, P = 10 -> share = [0.5, 0.33, 0.17], target = [5, 3, 1]

    The size of a species moves from its previous member count toward the
    target by the fraction ``state.spawn_number_change_rate``. Whatever is
    lost or gained by truncation is added to the first species so that the
    result always sums to ``state.P``.
    """
    keys = state.species_info.species_keys
    alive = ~jnp.isnan(keys)
    alive_total = jnp.sum(alive)

    # rank weights N, N-1, ..., 1 for N valid species; their sum is N(N+1)/2
    rank_weight = alive_total - jnp.arange(keys.shape[0])
    weight_sum = (alive_total + 1) * alive_total / 2
    share = jnp.where(alive, rank_weight / weight_sum, 0)  # empty slots get nothing
    target = jnp.floor(share * state.P)

    # avoid too much variation of the size of a species between generations
    current = state.species_info.member_count
    smoothed = current + (target - current) * state.spawn_number_change_rate
    smoothed = smoothed.astype(jnp.int32)

    # the spawn numbers must sum to the population size
    shortfall = state.P - jnp.sum(smoothed)
    return smoothed.at[0].add(shortfall)
def create_crossover_pair(state, randkey, spawn_number, fitness):
    """
    Choose the two parents for every slot of the next population.

    Args:
        state: provides ``species_info``, ``idx2species``, ``genome_elitism``
            and ``survival_threshold``.
        randkey: JAX PRNG key; it is split into one key per species slot.
        spawn_number: number of offspring granted to each species slot; the
            population slots are handed out to the species in this order.
        fitness: fitness of every individual of the current population.

    Returns:
        (winner, loser, elite_mask): for each population slot, the index of
        the fitter parent, the index of the other parent, and whether the
        slot is reserved for an elite of its species.
    """
    species_size = state.species_info.size()
    pop_size = fitness.shape[0]
    s_idx = jnp.arange(species_size)
    p_idx = jnp.arange(pop_size)

    def pick_parents(key, idx):
        # candidate parents for every possible slot of species `idx`
        members = state.idx2species == state.species_info.species_keys[idx]
        members_num = jnp.sum(members)
        members_fitness = jnp.where(members, fitness, -jnp.inf)
        # members first, best to worst; non-members (fitness -inf) afterwards
        sorted_member_indices = jnp.argsort(members_fitness)[::-1]

        elite_size = state.genome_elitism
        survive_size = jnp.floor(state.survival_threshold * members_num).astype(jnp.int32)
        # A small species can yield floor(threshold * members) == 0, which
        # would make every probability below 0 / 0 = NaN. Always keep at least
        # the best member as a possible parent.
        survive_size = jnp.maximum(survive_size, 1)

        # uniform choice among the surviving (top-ranked) members
        select_pro = (p_idx < survive_size) / survive_size
        fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro)

        # the first `elite_size` slots of a species reproduce its best members as-is
        is_elite = p_idx < elite_size
        fa = jnp.where(is_elite, sorted_member_indices, fa)
        ma = jnp.where(is_elite, sorted_member_indices, ma)
        return fa, ma, is_elite

    fas, mas, elites = vmap(pick_parents)(jax.random.split(randkey, species_size), s_idx)

    spawn_number_cum = jnp.cumsum(spawn_number)

    def locate(idx):
        # species owning population slot `idx`: first cumulative bound above idx
        loc = jnp.argmax(idx < spawn_number_cum)
        # position inside that species; elite genomes are at the beginning of the species
        idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
        return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]

    part1, part2, elite_mask = vmap(locate)(p_idx)

    # order each pair so that the fitter parent comes first
    is_part1_win = fitness[part1] >= fitness[part2]
    winner = jnp.where(is_part1_win, part1, part2)
    loser = jnp.where(is_part1_win, part2, part1)

    return winner, loser, elite_mask
def speciate(gene: Gene, state: State):
    """
    Assign every genome of the population to a species.

    Step 1 re-centers the existing species: for each one, the still
    unassigned genome closest to its current center becomes the new center.
    Step 2 walks over the species slots; an existing species absorbs the
    genomes within ``state.compatibility_threshold`` of its center, and an
    empty slot is turned into a new species as long as unassigned genomes
    remain.

    Returns the state updated with the new ``species_info``, ``idx2species``,
    ``center_genomes`` and ``next_species_key``.
    """
    pop_size, species_size = state.idx2species.shape[0], state.species_info.size()

    # prepare distance functions
    o2p_distance_func = vmap(distance, in_axes=(None, None, None, 0))  # one to population

    # idx to species key
    idx2species = jnp.full((pop_size,), jnp.nan)  # NaN means not assigned to any species

    # the distance between each genome and the center genome of its species
    o2c_distances = jnp.full((pop_size,), jnp.inf)

    # step 1: find new centers
    # carry: (species slot, idx2species, center genomes, distance to center)
    def cond_func(carry):
        i, i2s, cgs, o2c = carry
        # stop at the first empty slot or once every slot has been visited
        return (i < species_size) & (~jnp.isnan(state.species_info.species_keys[i]))  # current species is existing

    def body_func(carry):
        i, i2s, cgs, o2c = carry
        distances = o2p_distance_func(gene, state, cgs[i], state.pop_genomes)

        # find the closest one among the genomes that are not assigned yet
        closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))

        i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
        cgs = cgs.set(i, state.pop_genomes[closest_idx])

        # the genome with closest_idx will become the new center, thus its distance to center is 0.
        o2c = o2c.at[closest_idx].set(0)

        return i + 1, i2s, cgs, o2c

    _, idx2species, center_genomes, o2c_distances = \
        jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances))

    state = state.update(
        idx2species=idx2species,
        center_genomes=center_genomes,
    )

    # step 2: assign members to each species
    # carry: (species slot, idx2species, center genomes, species keys,
    #         distance to center, next species key)
    def cond_func(carry):
        i, i2s, cgs, sk, o2c, nsk = carry
        current_species_existed = ~jnp.isnan(sk[i])
        not_all_assigned = jnp.any(jnp.isnan(i2s))
        not_reach_species_upper_bounds = i < species_size
        # keep going while slots remain and there is either an existing
        # species to update or an unassigned genome that needs a new species
        return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)

    def body_func(carry):
        i, i2s, cgs, sk, o2c, nsk = carry

        # the index returned by the branch is discarded; i is advanced below
        _, i2s, cgs, sk, o2c, nsk = jax.lax.cond(
            jnp.isnan(sk[i]),  # whether the current species is existing or not
            create_new_species,  # if not existing, create a new species
            update_exist_specie,  # if existing, update the species
            (i, i2s, cgs, sk, o2c, nsk)
        )

        return i + 1, i2s, cgs, sk, o2c, nsk

    def create_new_species(carry):
        i, i2s, cgs, sk, o2c, nsk = carry

        # pick the first one who has not been assigned to any species
        # (presumably fetch_first returns the index of the first True entry -- TODO confirm in utils)
        idx = fetch_first(jnp.isnan(i2s))

        # the new species takes the key nsk and the picked genome becomes
        # both its first member and its center (distance to center 0)
        sk = sk.at[i].set(nsk)
        i2s = i2s.at[idx].set(nsk)
        o2c = o2c.at[idx].set(0)

        # update center genomes
        cgs = cgs.set(i, state.pop_genomes[idx])

        # gather the members of the new species right away
        i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)

        # the returned index is ignored by body_func, which advances i itself
        return i + 1, i2s, cgs, sk, o2c, nsk + 1  # change to next new species key

    def update_exist_specie(carry):
        i, i2s, cgs, sk, o2c, nsk = carry

        i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)

        # turn to next species
        return i + 1, i2s, cgs, sk, o2c, nsk

    def speciate_by_threshold(i, i2s, cgs, sk, o2c):
        # distance between this center genome and every genome of the population
        o2p_distance = o2p_distance_func(gene, state, cgs[i], state.pop_genomes)
        close_enough_mask = o2p_distance < state.compatibility_threshold

        # a genome can be taken when it is not assigned yet, or when it is
        # closer to this center than to the center of its current species
        cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
        mask = close_enough_mask & cacheable_mask

        # update species info
        i2s = jnp.where(mask, sk[i], i2s)

        # update distance between centers
        o2c = jnp.where(mask, o2p_distance, o2c)

        return i2s, o2c

    # update idx2species
    _, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop(
        cond_func,
        body_func,
        (0, state.idx2species, state.center_genomes, state.species_info.species_keys, o2c_distances,
         state.next_species_key)
    )

    # if there are still some genomes not assigned to any species, add them to the last species
    # this can only happen when the number of species has reached its upper bound
    idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)

    # complete the info of the species created in this generation: they have
    # a key but no recorded best fitness yet
    new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species_info.best_fitness)
    best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species_info.best_fitness)
    last_improved = jnp.where(new_created_mask, state.generation, state.species_info.last_improved)

    # update members count
    def count_members(idx):
        key = species_keys[idx]
        count = jnp.sum(idx2species == key, dtype=jnp.float32)
        # an empty slot keeps NaN instead of a count of 0
        count = jnp.where(jnp.isnan(key), jnp.nan, count)
        return count

    member_count = vmap(count_members)(jnp.arange(species_size))

    return state.update(
        species_info=SpeciesInfo(species_keys, best_fitness, last_improved, member_count),
        idx2species=idx2species,
        center_genomes=center_genomes,
        next_species_key=next_species_key
    )
def argmin_with_mask(arr, mask):
    """
    Index of the smallest entry of ``arr`` among the positions where ``mask``
    is True. Masked-out positions are treated as +inf.
    """
    return jnp.argmin(jnp.where(mask, arr, jnp.inf))

View File

@@ -1,55 +0,0 @@
from jax.tree_util import register_pytree_node_class
import numpy as np
import jax.numpy as jnp
@register_pytree_node_class
class SpeciesInfo:
    """
    Per-species bookkeeping stored as four parallel arrays, one entry per
    species slot: species key, best fitness, generation of the last
    improvement and member count. The class is registered as a JAX pytree
    whose children are these four arrays.
    """

    def __init__(self, species_keys, best_fitness, last_improved, member_count):
        self.species_keys = species_keys
        self.best_fitness = best_fitness
        self.last_improved = last_improved
        self.member_count = member_count

    @classmethod
    def initialize(cls, state):
        """
        Create the info for ``state.S`` species slots. Only slot 0 is in use:
        key 0, best fitness -inf, last improved at generation 0 and
        ``state.P`` members. Every other entry is NaN.
        """

        def column(first_value):
            values = np.full((state.S,), np.nan, dtype=np.float32)
            values[0] = first_value
            return values

        return cls(column(0), column(-np.inf), column(0), column(state.P))

    def __getitem__(self, i):
        """Return a SpeciesInfo holding entry/entries ``i`` of every array."""
        return SpeciesInfo(*self.get(i))

    def get(self, i):
        """Return (species_key, best_fitness, last_improved, member_count) at ``i``."""
        return (
            self.species_keys[i],
            self.best_fitness[i],
            self.last_improved[i],
            self.member_count[i],
        )

    def set(self, idx, value):
        """
        Return a new SpeciesInfo in which slot ``idx`` is replaced by the
        first four entries of ``value``, in the same order as ``get``.
        """
        children, _ = self.tree_flatten()
        return SpeciesInfo(*(child.at[idx].set(value[k]) for k, child in enumerate(children)))

    def remove(self, idx):
        """Return a new SpeciesInfo in which slot ``idx`` is all NaN."""
        return self.set(idx, jnp.full((4,), jnp.nan))

    def size(self):
        """Number of species slots."""
        return self.species_keys.shape[0]

    def tree_flatten(self):
        children = (self.species_keys, self.best_fitness, self.last_improved, self.member_count)
        return children, None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)