change a lot a lot a lot!!!!!!!

This commit is contained in:
wls2002
2023-07-24 02:16:02 +08:00
parent 48f90c7eef
commit ac295c1921
49 changed files with 1138 additions and 1460 deletions

View File

@@ -1,4 +0,0 @@
from .base import Algorithm
from .state import State
from .neat import NEAT
from .hyperneat import HyperNEAT

View File

@@ -1,17 +0,0 @@
from typing import Callable
from .state import State
EMPTY = lambda *args: args
class Algorithm:
def __init__(self):
self.tell: Callable = EMPTY
self.ask: Callable = EMPTY
self.forward: Callable = EMPTY
self.forward_transform: Callable = EMPTY
def setup(self, randkey, state=State()):
pass

View File

@@ -1,2 +0,0 @@
from .hyperneat import HyperNEAT
from .substrate import BaseSubstrate

View File

@@ -1,70 +0,0 @@
from typing import Type
import jax
import numpy as np
from .substrate import BaseSubstrate, analysis_substrate
from .hyperneat_gene import HyperNEATGene
from algorithm import State, Algorithm, neat
class HyperNEAT(Algorithm):
def __init__(self, config, gene_type: Type[neat.BaseGene], substrate: Type[BaseSubstrate]):
super().__init__()
self.config = config
self.gene_type = gene_type
self.substrate = substrate
self.neat = neat.NEAT(config, gene_type)
self.tell = create_tell(self.neat)
self.forward_transform = create_forward_transform(config, self.neat)
self.forward = HyperNEATGene.create_forward(config)
def setup(self, randkey, state=State()):
state = state.update(
below_threshold=self.config['below_threshold'],
max_weight=self.config['max_weight']
)
state = self.substrate.setup(state, self.config)
h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state)
h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis]
h_conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32)
h_conns[:, 0:2] = correspond_keys
state = state.update(
# h is short for hyperneat
h_input_idx=h_input_idx,
h_output_idx=h_output_idx,
h_hidden_idx=h_hidden_idx,
query_coors=query_coors,
correspond_keys=correspond_keys,
h_nodes=h_nodes,
h_conns=h_conns
)
state = self.neat.setup(randkey, state=state)
self.config['h_input_idx'] = h_input_idx
self.config['h_output_idx'] = h_output_idx
return state
def create_tell(neat_instance):
def tell(state, fitness):
return neat_instance.tell(state, fitness)
return tell
def create_forward_transform(config, neat_instance):
def forward_transform(state, nodes, conns):
t = neat_instance.forward_transform(state, nodes, conns)
batch_forward_func = jax.vmap(neat_instance.forward, in_axes=(0, None))
query_res = batch_forward_func(state.query_coors, t) # hyperneat connections
h_nodes = state.h_nodes
h_conns = state.h_conns.at[:, 2:].set(query_res)
return HyperNEATGene.forward_transform(state, h_nodes, h_conns)
return forward_transform

View File

@@ -1,54 +0,0 @@
import jax
from jax import numpy as jnp, vmap
from algorithm.neat import BaseGene
from algorithm.neat.gene import Activation
from algorithm.neat.gene import Aggregation
class HyperNEATGene(BaseGene):
node_attrs = [] # no node attributes
conn_attrs = ['weight']
@staticmethod
def forward_transform(state, nodes, conns):
N = nodes.shape[0]
u_conns = jnp.zeros((N, N), dtype=jnp.float32)
in_keys = jnp.asarray(conns[:, 0], jnp.int32)
out_keys = jnp.asarray(conns[:, 1], jnp.int32)
weights = conns[:, 2]
u_conns = u_conns.at[in_keys, out_keys].set(weights)
return nodes, u_conns
@staticmethod
def create_forward(config):
act = Activation.name2func[config['h_activation']]
agg = Aggregation.name2func[config['h_aggregation']]
batch_act, batch_agg = vmap(act), vmap(agg)
def forward(inputs, transform):
inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0)
nodes, weights = transform
input_idx = config['h_input_idx']
output_idx = config['h_output_idx']
N = nodes.shape[0]
vals = jnp.full((N,), 0.)
def body_func(i, values):
values = values.at[input_idx].set(inputs_with_bias)
nodes_ins = values * weights.T
values = batch_agg(nodes_ins) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(values) # z = act(z)
return values
vals = jax.lax.fori_loop(0, config['h_activate_times'], body_func, vals)
return vals[output_idx]
return forward

View File

@@ -1,2 +0,0 @@
from .base import BaseSubstrate
from .tools import analysis_substrate

View File

@@ -1,12 +0,0 @@
import numpy as np
class BaseSubstrate:
@staticmethod
def setup(state, config):
return state.update(
input_coors=np.asarray(config['input_coors'], dtype=np.float32),
output_coors=np.asarray(config['output_coors'], dtype=np.float32),
hidden_coors=np.asarray(config['hidden_coors'], dtype=np.float32),
)

View File

@@ -1,53 +0,0 @@
from typing import Type
import numpy as np
from .base import BaseSubstrate
def analysis_substrate(state):
cd = state.input_coors.shape[1] # coordinate dimensions
si = state.input_coors.shape[0] # input coordinate size
so = state.output_coors.shape[0] # output coordinate size
sh = state.hidden_coors.shape[0] # hidden coordinate size
input_idx = np.arange(si)
output_idx = np.arange(si, si + so)
hidden_idx = np.arange(si + so, si + so + sh)
total_conns = si * sh + sh * sh + sh * so
query_coors = np.zeros((total_conns, cd * 2))
correspond_keys = np.zeros((total_conns, 2))
# connect input to hidden
aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, state.input_coors, state.hidden_coors)
query_coors[0: si * sh, :] = aux_coors
correspond_keys[0: si * sh, :] = aux_keys
# connect hidden to hidden
aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors)
query_coors[si * sh: si * sh + sh * sh, :] = aux_coors
correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys
# connect hidden to output
aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, state.hidden_coors, state.output_coors)
query_coors[si * sh + sh * sh:, :] = aux_coors
correspond_keys[si * sh + sh * sh:, :] = aux_keys
return input_idx, output_idx, hidden_idx, query_coors, correspond_keys
def cartesian_product(keys1, keys2, coors1, coors2):
len1 = keys1.shape[0]
len2 = keys2.shape[0]
repeated_coors1 = np.repeat(coors1, len2, axis=0)
repeated_keys1 = np.repeat(keys1, len2)
tiled_coors2 = np.tile(coors2, (len1, 1))
tiled_keys2 = np.tile(keys2, len1)
new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1)
correspond_keys = np.column_stack((repeated_keys1, tiled_keys2))
return new_coors, correspond_keys

View File

@@ -1,2 +0,0 @@
from .neat import NEAT
from .gene import BaseGene, NormalGene, RecurrentGene

View File

@@ -0,0 +1,2 @@
from .crossover import crossover
from .mutate import create_mutate

View File

@@ -1,8 +1,10 @@
import jax
from jax import jit, Array, numpy as jnp
from jax import Array, numpy as jnp
from core import Genome
def crossover(randkey, nodes1: Array, conns1: Array, nodes2: Array, conns2: Array):
def crossover(randkey, genome1: Genome, genome2: Genome):
"""
use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
@@ -10,20 +12,22 @@ def crossover(randkey, nodes1: Array, conns1: Array, nodes2: Array, conns2: Arra
randkey_1, randkey_2, key= jax.random.split(randkey, 3)
# crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
keys1, keys2 = genome1.nodes[:, 0], genome2.nodes[:, 0]
# make homologous genes align in nodes2 align with nodes1
nodes2 = align_array(keys1, keys2, nodes2, False)
nodes2 = align_array(keys1, keys2, genome2.nodes, False)
nodes1 = genome1.nodes
# For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
# crossover connections
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
cons2 = align_array(con_keys1, con_keys2, conns2, True)
new_cons = jnp.where(jnp.isnan(conns1) | jnp.isnan(cons2), conns1, crossover_gene(randkey_2, conns1, cons2))
con_keys1, con_keys2 = genome1.conns[:, :2], genome2.conns[:, :2]
conns2 = align_array(con_keys1, con_keys2, genome2.conns, True)
conns1 = genome1.conns
return new_nodes, new_cons
new_cons = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, crossover_gene(randkey_2, conns1, conns2))
return genome1.update(new_nodes, new_cons)
def align_array(seq1: Array, seq2: Array, ar2: Array, is_conn: bool) -> Array:

189
algorithm/neat/ga/mutate.py Normal file
View File

@@ -0,0 +1,189 @@
from typing import Tuple, Type
import jax
from jax import Array, numpy as jnp, vmap
from config import NeatConfig
from core import State, Gene, Genome
from utils import check_cycles, fetch_random, fetch_first, I_INT, unflatten_conns
def create_mutate(config: NeatConfig, gene_type: Type[Gene]):
"""
Create function to mutate a single genome
"""
def mutate_structure(state: State, randkey, genome: Genome, new_node_key):
def mutate_add_node(key_, genome_: Genome):
i_key, o_key, idx = choice_connection_key(key_, genome_.conns)
def nothing():
return genome_
def successful_add_node():
# disable the connection
new_genome = genome_.update_conns(genome_.conns.at[idx, 2].set(False))
# add a new node
new_genome = new_genome.add_node(new_node_key, gene_type.new_node_attrs(state))
# add two new connections
new_genome = new_genome.add_conn(i_key, new_node_key, True, gene_type.new_conn_attrs(state))
new_genome = new_genome.add_conn(new_node_key, o_key, True, gene_type.new_conn_attrs(state))
return new_genome
# if from_idx == I_INT, that means no connection exist, do nothing
return jax.lax.cond(idx == I_INT, nothing, successful_add_node)
def mutate_delete_node(key_, genome_: Genome):
# TODO: Do we really need to delete a node?
# randomly choose a node
key, idx = choice_node_key(key_, genome_.nodes, state.input_idx, state.output_idx,
allow_input_keys=False, allow_output_keys=False)
def nothing():
return genome_
def successful_delete_node():
# delete the node
new_genome = genome_.delete_node_by_pos(idx)
# delete all connections
new_conns = jnp.where(((new_genome.conns[:, 0] == key) | (new_genome.conns[:, 1] == key))[:, None],
jnp.nan, new_genome.conns)
return new_genome.update_conns(new_conns)
return jax.lax.cond(idx == I_INT, nothing, successful_delete_node)
def mutate_add_conn(key_, genome_: Genome):
# randomly choose two nodes
k1_, k2_ = jax.random.split(key_, num=2)
i_key, from_idx = choice_node_key(k1_, genome_.nodes, state.input_idx, state.output_idx,
allow_input_keys=True, allow_output_keys=True)
o_key, to_idx = choice_node_key(k2_, genome_.nodes, state.input_idx, state.output_idx,
allow_input_keys=False, allow_output_keys=True)
conn_pos = fetch_first((genome_.conns[:, 0] == i_key) & (genome_.conns[:, 1] == o_key))
def nothing():
return genome_
def successful():
return genome_.add_conn(i_key, o_key, True, gene_type.new_conn_attrs(state))
def already_exist():
return genome_.update_conns(genome_.conns.at[conn_pos, 2].set(True))
is_already_exist = conn_pos != I_INT
if config.network_type == 'feedforward':
u_cons = unflatten_conns(genome_.nodes, genome_.conns)
cons_exist = jnp.where(~jnp.isnan(u_cons[0, :, :]), True, False)
is_cycle = check_cycles(genome_.nodes, cons_exist, from_idx, to_idx)
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
return jax.lax.switch(choice, [already_exist, nothing, successful])
elif config.network_type == 'recurrent':
return jax.lax.cond(is_already_exist, already_exist, successful)
else:
raise ValueError(f"Invalid network type: {config.network_type}")
def mutate_delete_conn(key_, genome_: Genome):
# randomly choose a connection
i_key, o_key, idx = choice_connection_key(key_, genome_.conns)
def nothing():
return genome_
def successfully_delete_connection():
return genome_.delete_conn_by_pos(idx)
return jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection)
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
def no(k, g):
return g
genome = jax.lax.cond(r1 < config.node_add, mutate_add_node, no, k1, genome)
genome = jax.lax.cond(r2 < config.node_delete, mutate_delete_node, no, k2, genome)
genome = jax.lax.cond(r3 < config.conn_add, mutate_add_conn, no, k3, genome)
genome = jax.lax.cond(r4 < config.conn_delete, mutate_delete_conn, no, k4, genome)
return genome
def mutate_values(state: State, randkey, genome: Genome):
k1, k2 = jax.random.split(randkey, num=2)
nodes_keys = jax.random.split(k1, num=genome.nodes.shape[0])
conns_keys = jax.random.split(k2, num=genome.conns.shape[0])
nodes_attrs, conns_attrs = genome.nodes[:, 1:], genome.conns[:, 3:]
new_nodes_attrs = vmap(gene_type.mutate_node, in_axes=(None, 0, 0))(state, nodes_attrs, nodes_keys)
new_conns_attrs = vmap(gene_type.mutate_conn, in_axes=(None, 0, 0))(state, conns_attrs, conns_keys)
# nan nodes not changed
new_nodes_attrs = jnp.where(jnp.isnan(nodes_attrs), jnp.nan, new_nodes_attrs)
new_conns_attrs = jnp.where(jnp.isnan(conns_attrs), jnp.nan, new_conns_attrs)
new_nodes = genome.nodes.at[:, 1:].set(new_nodes_attrs)
new_conns = genome.conns.at[:, 3:].set(new_conns_attrs)
return genome.update(new_nodes, new_conns)
def mutate(state, randkey, genome: Genome, new_node_key):
k1, k2 = jax.random.split(randkey)
genome = mutate_structure(state, k1, genome, new_node_key)
genome = mutate_values(state, k2, genome)
return genome
return mutate
def choice_node_key(rand_key: Array, nodes: Array,
input_keys: Array, output_keys: Array,
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
"""
Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node.
:param rand_key:
:param nodes:
:param input_keys:
:param output_keys:
:param allow_input_keys:
:param allow_output_keys:
:return: return its key and position(idx)
"""
node_keys = nodes[:, 0]
mask = ~jnp.isnan(node_keys)
if not allow_input_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_keys))
if not allow_output_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_keys))
idx = fetch_random(rand_key, mask)
key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
return key, idx
def choice_connection_key(rand_key: Array, conns: Array):
"""
Randomly choose a connection key from the given connections.
:return: i_key, o_key, idx
"""
idx = fetch_random(rand_key, ~jnp.isnan(conns[:, 0]))
i_key = jnp.where(idx != I_INT, conns[idx, 0], jnp.nan)
o_key = jnp.where(idx != I_INT, conns[idx, 1], jnp.nan)
return i_key, o_key, idx

View File

@@ -1,6 +1 @@
from .base import BaseGene
from .normal import NormalGene
from .activation import Activation
from .aggregation import Aggregation
from .recurrent import RecurrentGene
from .normal import NormalGene, NormalGeneConfig

View File

@@ -1,42 +0,0 @@
from jax import Array, numpy as jnp, vmap
class BaseGene:
node_attrs = []
conn_attrs = []
@staticmethod
def setup(state, config):
return state
@staticmethod
def new_node_attrs(state):
return jnp.zeros(0)
@staticmethod
def new_conn_attrs(state):
return jnp.zeros(0)
@staticmethod
def mutate_node(state, attrs: Array, key):
return attrs
@staticmethod
def mutate_conn(state, attrs: Array, key):
return attrs
@staticmethod
def distance_node(state, node1: Array, node2: Array):
return node1
@staticmethod
def distance_conn(state, conn1: Array, conn2: Array):
return conn1
@staticmethod
def forward_transform(state, nodes, conns):
return nodes, conns
@staticmethod
def create_forward(config):
return None

View File

@@ -1,45 +1,100 @@
from dataclasses import dataclass
from typing import Tuple
import jax
from jax import Array, numpy as jnp
from .base import BaseGene
from .activation import Activation
from .aggregation import Aggregation
from algorithm.utils import unflatten_connections, I_INT
from ..genome import topological_sort
from config import GeneConfig
from core import Gene, Genome, State
from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT
class NormalGene(BaseGene):
@dataclass(frozen=True)
class NormalGeneConfig(GeneConfig):
bias_init_mean: float = 0.0
bias_init_std: float = 1.0
bias_mutate_power: float = 0.5
bias_mutate_rate: float = 0.7
bias_replace_rate: float = 0.1
response_init_mean: float = 1.0
response_init_std: float = 0.0
response_mutate_power: float = 0.5
response_mutate_rate: float = 0.7
response_replace_rate: float = 0.1
activation_default: str = 'sigmoid'
activation_options: Tuple[str] = ('sigmoid',)
activation_replace_rate: float = 0.1
aggregation_default: str = 'sum'
aggregation_options: Tuple[str] = ('sum',)
aggregation_replace_rate: float = 0.1
weight_init_mean: float = 0.0
weight_init_std: float = 1.0
weight_mutate_power: float = 0.5
weight_mutate_rate: float = 0.8
weight_replace_rate: float = 0.1
def __post_init__(self):
assert self.bias_init_std >= 0.0
assert self.bias_mutate_power >= 0.0
assert self.bias_mutate_rate >= 0.0
assert self.bias_replace_rate >= 0.0
assert self.response_init_std >= 0.0
assert self.response_mutate_power >= 0.0
assert self.response_mutate_rate >= 0.0
assert self.response_replace_rate >= 0.0
assert self.activation_default == self.activation_options[0]
for name in self.activation_options:
assert name in Activation.name2func, f"Activation function: {name} not found"
assert self.aggregation_default == self.aggregation_options[0]
assert self.aggregation_default in Aggregation.name2func, \
f"Aggregation function: {self.aggregation_default} not found"
for name in self.aggregation_options:
assert name in Aggregation.name2func, f"Aggregation function: {name} not found"
class NormalGene(Gene):
node_attrs = ['bias', 'response', 'aggregation', 'activation']
conn_attrs = ['weight']
@staticmethod
def setup(state, config):
def setup(config: NormalGeneConfig, state: State = State()):
return state.update(
bias_init_mean=config['bias_init_mean'],
bias_init_std=config['bias_init_std'],
bias_mutate_power=config['bias_mutate_power'],
bias_mutate_rate=config['bias_mutate_rate'],
bias_replace_rate=config['bias_replace_rate'],
bias_init_mean=config.bias_init_mean,
bias_init_std=config.bias_init_std,
bias_mutate_power=config.bias_mutate_power,
bias_mutate_rate=config.bias_mutate_rate,
bias_replace_rate=config.bias_replace_rate,
response_init_mean=config['response_init_mean'],
response_init_std=config['response_init_std'],
response_mutate_power=config['response_mutate_power'],
response_mutate_rate=config['response_mutate_rate'],
response_replace_rate=config['response_replace_rate'],
response_init_mean=config.response_init_mean,
response_init_std=config.response_init_std,
response_mutate_power=config.response_mutate_power,
response_mutate_rate=config.response_mutate_rate,
response_replace_rate=config.response_replace_rate,
activation_default=config['activation_default'],
activation_options=config['activation_options'],
activation_replace_rate=config['activation_replace_rate'],
activation_replace_rate=config.activation_replace_rate,
activation_default=0,
activation_options=jnp.arange(len(config.activation_options)),
aggregation_default=config['aggregation_default'],
aggregation_options=config['aggregation_options'],
aggregation_replace_rate=config['aggregation_replace_rate'],
aggregation_replace_rate=config.aggregation_replace_rate,
aggregation_default=0,
aggregation_options=jnp.arange(len(config.aggregation_options)),
weight_init_mean=config['weight_init_mean'],
weight_init_std=config['weight_init_std'],
weight_mutate_power=config['weight_mutate_power'],
weight_mutate_rate=config['weight_mutate_rate'],
weight_replace_rate=config['weight_replace_rate'],
weight_init_mean=config.weight_init_mean,
weight_init_std=config.weight_init_std,
weight_mutate_power=config.weight_mutate_power,
weight_mutate_rate=config.weight_mutate_rate,
weight_replace_rate=config.weight_replace_rate,
)
@staticmethod
@@ -84,20 +139,20 @@ class NormalGene(BaseGene):
return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight
@staticmethod
def forward_transform(state, nodes, conns):
u_conns = unflatten_connections(nodes, conns)
def forward_transform(state: State, genome: Genome):
u_conns = unflatten_conns(genome.nodes, genome.conns)
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
# remove enable attr
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
seqs = topological_sort(nodes, conn_enable)
seqs = topological_sort(genome.nodes, conn_enable)
return seqs, nodes, u_conns
return seqs, genome.nodes, u_conns
@staticmethod
def create_forward(config):
config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']]
config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']]
def create_forward(state: State, config: NormalGeneConfig):
activation_funcs = [Activation.name2func[name] for name in config.activation_options]
aggregation_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
def act(idx, z):
"""
@@ -105,7 +160,7 @@ class NormalGene(BaseGene):
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, config['activation_funcs'], z)
res = jax.lax.switch(idx, activation_funcs, z)
return res
def agg(idx, z):
@@ -118,14 +173,13 @@ class NormalGene(BaseGene):
return 0.
def not_all_nan():
return jax.lax.switch(idx, config['aggregation_funcs'], z)
return jax.lax.switch(idx, aggregation_funcs, z)
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
def forward(inputs, transform) -> Array:
def forward(inputs, transformed) -> Array:
"""
jax forward for single input shaped (input_num, )
nodes, connections are a single genome
forward for single input shaped (input_num, )
:argument inputs: (input_num, )
:argument cal_seqs: (N, )
@@ -135,10 +189,10 @@ class NormalGene(BaseGene):
:return (output_num, )
"""
cal_seqs, nodes, cons = transform
cal_seqs, nodes, cons = transformed
input_idx = config['input_idx']
output_idx = config['output_idx']
input_idx = state.input_idx
output_idx = state.output_idx
N = nodes.shape[0]
ini_vals = jnp.full((N,), jnp.nan)

View File

@@ -1,90 +0,0 @@
import jax
from jax import Array, numpy as jnp, vmap
from .normal import NormalGene
from .activation import Activation
from .aggregation import Aggregation
from algorithm.utils import unflatten_connections
class RecurrentGene(NormalGene):
@staticmethod
def forward_transform(state, nodes, conns):
u_conns = unflatten_connections(nodes, conns)
# remove un-enable connections and remove enable attr
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
return nodes, u_conns
@staticmethod
def create_forward(config):
config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']]
config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']]
def act(idx, z):
"""
calculate activation function for each node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, config['activation_funcs'], z)
return res
def agg(idx, z):
"""
calculate activation function for inputs of node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
def all_nan():
return 0.
def not_all_nan():
return jax.lax.switch(idx, config['aggregation_funcs'], z)
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
batch_act, batch_agg = vmap(act), vmap(agg)
def forward(inputs, transform) -> Array:
"""
jax forward for single input shaped (input_num, )
nodes, connections are a single genome
:argument inputs: (input_num, )
:argument cal_seqs: (N, )
:argument nodes: (N, 5)
:argument connections: (2, N, N)
:return (output_num, )
"""
nodes, cons = transform
input_idx = config['input_idx']
output_idx = config['output_idx']
N = nodes.shape[0]
vals = jnp.full((N,), 0.)
weights = cons[0, :]
def body_func(i, values):
values = values.at[input_idx].set(inputs)
nodes_ins = values * weights.T
values = batch_agg(nodes[:, 4], nodes_ins) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(nodes[:, 3], values) # z = act(z)
return values
# for i in range(config['activate_times']):
# vals = body_func(i, vals)
#
# return vals[output_idx]
vals = jax.lax.fori_loop(0, config['activate_times'], body_func, vals)
return vals[output_idx]
return forward

View File

@@ -1,5 +0,0 @@
from .basic import initialize_genomes
from .mutate import create_mutate
from .distance import create_distance
from .crossover import crossover
from .graph import topological_sort

View File

@@ -1,111 +0,0 @@
from typing import Type, Tuple
import numpy as np
import jax
from jax import Array, numpy as jnp
from algorithm import State
from ..gene import BaseGene
from algorithm.utils import fetch_first
def initialize_genomes(state: State, gene_type: Type[BaseGene]):
o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes
o_conns = np.full((state.C, state.CL), np.nan, dtype=np.float32) # original connections
input_idx = state.input_idx
output_idx = state.output_idx
new_node_key = max([*input_idx, *output_idx]) + 1
o_nodes[input_idx, 0] = input_idx
o_nodes[output_idx, 0] = output_idx
o_nodes[new_node_key, 0] = new_node_key
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = jax.device_get(gene_type.new_node_attrs(state))
o_nodes[new_node_key, 1:] = jax.device_get(gene_type.new_node_attrs(state))
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)]
o_conns[input_idx, 0:2] = input_conns # in key, out key
o_conns[input_idx, 2] = True # enabled
o_conns[input_idx, 3:] = jax.device_get(gene_type.new_conn_attrs(state))
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx]
o_conns[output_idx, 0:2] = output_conns # in key, out key
o_conns[output_idx, 2] = True # enabled
o_conns[output_idx, 3:] = jax.device_get(gene_type.new_conn_attrs(state))
# repeat origin genome for P times to create population
pop_nodes = np.tile(o_nodes, (state.P, 1, 1))
pop_conns = np.tile(o_conns, (state.P, 1, 1))
return jax.device_put([pop_nodes, pop_conns])
def count(nodes: Array, cons: Array):
"""
Count how many nodes and connections are in the genome.
"""
node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0]))
cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0]))
return node_cnt, cons_cnt
def add_node(nodes: Array, cons: Array, new_key: int, attrs: Array) -> Tuple[Array, Array]:
"""
Add a new node to the genome.
The new node will place at the first NaN row.
"""
exist_keys = nodes[:, 0]
idx = fetch_first(jnp.isnan(exist_keys))
nodes = nodes.at[idx, 0].set(new_key)
nodes = nodes.at[idx, 1:].set(attrs)
return nodes, cons
def delete_node(nodes: Array, cons: Array, node_key: Array) -> Tuple[Array, Array]:
"""
Delete a node from the genome. Only delete the node, regardless of connections.
Delete the node by its key.
"""
node_keys = nodes[:, 0]
idx = fetch_first(node_keys == node_key)
return delete_node_by_idx(nodes, cons, idx)
def delete_node_by_idx(nodes: Array, cons: Array, idx: Array) -> Tuple[Array, Array]:
"""
Delete a node from the genome. Only delete the node, regardless of connections.
Delete the node by its idx.
"""
nodes = nodes.at[idx].set(np.nan)
return nodes, cons
def add_connection(nodes: Array, cons: Array, i_key: Array, o_key: Array, enable: bool, attrs: Array) -> Tuple[
Array, Array]:
"""
Add a new connection to the genome.
The new connection will place at the first NaN row.
"""
con_keys = cons[:, 0]
idx = fetch_first(jnp.isnan(con_keys))
cons = cons.at[idx, 0:3].set(jnp.array([i_key, o_key, enable]))
cons = cons.at[idx, 3:].set(attrs)
return nodes, cons
def delete_connection(nodes: Array, cons: Array, i_key: Array, o_key: Array) -> Tuple[Array, Array]:
"""
Delete a connection from the genome.
Delete the connection by its input and output node keys.
"""
idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
return delete_connection_by_idx(nodes, cons, idx)
def delete_connection_by_idx(nodes: Array, cons: Array, idx: Array) -> Tuple[Array, Array]:
"""
Delete a connection from the genome.
Delete the connection by its idx.
"""
cons = cons.at[idx].set(np.nan)
return nodes, cons

View File

@@ -1,205 +0,0 @@
from typing import Dict, Tuple, Type
import jax
from jax import Array, numpy as jnp, vmap
from algorithm import State
from .basic import add_node, add_connection, delete_node_by_idx, delete_connection_by_idx
from .graph import check_cycles
from algorithm.utils import fetch_random, fetch_first, I_INT, unflatten_connections
from ..gene import BaseGene
def create_mutate(config: Dict, gene_type: Type[BaseGene]):
"""
Create function to mutate a single genome
"""
def mutate_structure(state: State, randkey, nodes, conns, new_node_key):
def mutate_add_node(key_, nodes_, conns_):
i_key, o_key, idx = choice_connection_key(key_, nodes_, conns_)
def nothing():
return nodes_, conns_
def successful_add_node():
# disable the connection
aux_nodes, aux_conns = nodes_, conns_
# set enable to false
aux_conns = aux_conns.at[idx, 2].set(False)
# add a new node
aux_nodes, aux_conns = add_node(aux_nodes, aux_conns, new_node_key, gene_type.new_node_attrs(state))
# add two new connections
aux_nodes, aux_conns = add_connection(aux_nodes, aux_conns, i_key, new_node_key, True,
gene_type.new_conn_attrs(state))
aux_nodes, aux_conns = add_connection(aux_nodes, aux_conns, new_node_key, o_key, True,
gene_type.new_conn_attrs(state))
return aux_nodes, aux_conns
# if from_idx == I_INT, that means no connection exist, do nothing
new_nodes, new_conns = jax.lax.cond(idx == I_INT, nothing, successful_add_node)
return new_nodes, new_conns
def mutate_delete_node(key_, nodes_, conns_):
# TODO: Do we really need to delete a node?
# randomly choose a node
key, idx = choice_node_key(key_, nodes_, config['input_idx'], config['output_idx'],
allow_input_keys=False, allow_output_keys=False)
def nothing():
return nodes_, conns_
def successful_delete_node():
# delete the node
aux_nodes, aux_cons = delete_node_by_idx(nodes_, conns_, idx)
# delete all connections
aux_cons = jnp.where(((aux_cons[:, 0] == key) | (aux_cons[:, 1] == key))[:, None],
jnp.nan, aux_cons)
return aux_nodes, aux_cons
return jax.lax.cond(idx == I_INT, nothing, successful_delete_node)
def mutate_add_conn(key_, nodes_, conns_):
# randomly choose two nodes
k1_, k2_ = jax.random.split(key_, num=2)
i_key, from_idx = choice_node_key(k1_, nodes_, config['input_idx'], config['output_idx'],
allow_input_keys=True, allow_output_keys=True)
o_key, to_idx = choice_node_key(k2_, nodes_, config['input_idx'], config['output_idx'],
allow_input_keys=False, allow_output_keys=True)
con_idx = fetch_first((conns_[:, 0] == i_key) & (conns_[:, 1] == o_key))
def nothing():
return nodes_, conns_
def successful():
new_nodes, new_cons = add_connection(nodes_, conns_, i_key, o_key, True, gene_type.new_conn_attrs(state))
return new_nodes, new_cons
def already_exist():
new_cons = conns_.at[con_idx, 2].set(True)
return nodes_, new_cons
is_already_exist = con_idx != I_INT
if config['network_type'] == 'feedforward':
u_cons = unflatten_connections(nodes_, conns_)
cons_exist = jnp.where(~jnp.isnan(u_cons[0, :, :]), True, False)
is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx)
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
return jax.lax.switch(choice, [already_exist, nothing, successful])
elif config['network_type'] == 'recurrent':
return jax.lax.cond(is_already_exist, already_exist, successful)
else:
raise ValueError(f"Invalid network type: {config['network_type']}")
def mutate_delete_conn(key_, nodes_, conns_):
# randomly choose a connection
i_key, o_key, idx = choice_connection_key(key_, nodes_, conns_)
def nothing():
return nodes_, conns_
def successfully_delete_connection():
return delete_connection_by_idx(nodes_, conns_, idx)
return jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection)
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
def no(k, n, c):
return n, c
nodes, conns = jax.lax.cond(r1 < config['node_add_prob'], mutate_add_node, no, k1, nodes, conns)
nodes, conns = jax.lax.cond(r2 < config['node_delete_prob'], mutate_delete_node, no, k2, nodes, conns)
nodes, conns = jax.lax.cond(r3 < config['conn_add_prob'], mutate_add_conn, no, k3, nodes, conns)
nodes, conns = jax.lax.cond(r4 < config['conn_delete_prob'], mutate_delete_conn, no, k4, nodes, conns)
return nodes, conns
def mutate_values(state: State, randkey, nodes, conns):
k1, k2 = jax.random.split(randkey, num=2)
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
conns_keys = jax.random.split(k2, num=conns.shape[0])
nodes_attrs, conns_attrs = nodes[:, 1:], conns[:, 3:]
new_nodes_attrs = vmap(gene_type.mutate_node, in_axes=(None, 0, 0))(state, nodes_attrs, nodes_keys)
new_conns_attrs = vmap(gene_type.mutate_conn, in_axes=(None, 0, 0))(state, conns_attrs, conns_keys)
# nan nodes not changed
new_nodes_attrs = jnp.where(jnp.isnan(nodes_attrs), jnp.nan, new_nodes_attrs)
new_conns_attrs = jnp.where(jnp.isnan(conns_attrs), jnp.nan, new_conns_attrs)
new_nodes = nodes.at[:, 1:].set(new_nodes_attrs)
new_conns = conns.at[:, 3:].set(new_conns_attrs)
return new_nodes, new_conns
def mutate(state, randkey, nodes, conns, new_node_key):
k1, k2 = jax.random.split(randkey)
nodes, conns = mutate_structure(state, k1, nodes, conns, new_node_key)
nodes, conns = mutate_values(state, k2, nodes, conns)
return nodes, conns
return mutate
def choice_node_key(rand_key: Array, nodes: Array,
input_keys: Array, output_keys: Array,
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
"""
Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node.
:param rand_key:
:param nodes:
:param input_keys:
:param output_keys:
:param allow_input_keys:
:param allow_output_keys:
:return: return its key and position(idx)
"""
node_keys = nodes[:, 0]
mask = ~jnp.isnan(node_keys)
if not allow_input_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_keys))
if not allow_output_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_keys))
idx = fetch_random(rand_key, mask)
key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
return key, idx
def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]:
"""
Randomly choose a connection key from the given connections.
:param rand_key:
:param nodes:
:param cons:
:return: i_key, o_key, idx
"""
idx = fetch_random(rand_key, ~jnp.isnan(cons[:, 0]))
i_key = jnp.where(idx != I_INT, cons[idx, 0], jnp.nan)
o_key = jnp.where(idx != I_INT, cons[idx, 1], jnp.nan)
return i_key, o_key, idx

View File

@@ -1,67 +1,84 @@
from typing import Type
import jax
import jax.numpy as jnp
from jax import numpy as jnp, Array, vmap
import numpy as np
from algorithm import Algorithm, State
from .gene import BaseGene
from .genome import initialize_genomes
from .population import create_tell
from config import Config
from core import Algorithm, State, Gene, Genome
from .ga import crossover, create_mutate
from .species import update_species, create_speciate
class NEAT(Algorithm):
def __init__(self, config, gene_type: Type[BaseGene]):
super().__init__()
def __init__(self, config: Config, gene_type: Type[Gene]):
self.config = config
self.gene_type = gene_type
self.tell = create_tell(config, self.gene_type)
self.ask = None
self.forward = self.gene_type.create_forward(config)
self.forward_transform = self.gene_type.forward_transform
self.forward_func = None
self.tell_func = None
def setup(self, randkey, state: State = State()):
"""initialize the state of the algorithm"""
input_idx = np.arange(self.config.basic.num_inputs)
output_idx = np.arange(self.config.basic.num_inputs,
self.config.basic.num_inputs + self.config.basic.num_outputs)
def setup(self, randkey, state=State()):
state = state.update(
P=self.config['pop_size'],
N=self.config['maximum_nodes'],
C=self.config['maximum_conns'],
S=self.config['maximum_species'],
P=self.config.basic.pop_size,
N=self.config.neat.maximum_nodes,
C=self.config.neat.maximum_conns,
S=self.config.neat.maximum_species,
NL=1 + len(self.gene_type.node_attrs), # node length = (key) + attributes
CL=3 + len(self.gene_type.conn_attrs), # conn length = (in, out, key) + attributes
input_idx=self.config['input_idx'],
output_idx=self.config['output_idx'],
max_stagnation=self.config['max_stagnation'],
species_elitism=self.config['species_elitism'],
spawn_number_change_rate=self.config['spawn_number_change_rate'],
genome_elitism=self.config['genome_elitism'],
survival_threshold=self.config['survival_threshold'],
compatibility_threshold=self.config['compatibility_threshold'],
max_stagnation=self.config.neat.max_stagnation,
species_elitism=self.config.neat.species_elitism,
spawn_number_change_rate=self.config.neat.spawn_number_change_rate,
genome_elitism=self.config.neat.genome_elitism,
survival_threshold=self.config.neat.survival_threshold,
compatibility_threshold=self.config.neat.compatibility_threshold,
compatibility_disjoint=self.config.neat.compatibility_disjoint,
compatibility_weight=self.config.neat.compatibility_weight,
input_idx=input_idx,
output_idx=output_idx,
)
state = self.gene_type.setup(state, self.config)
state = self.gene_type.setup(self.config.gene, state)
pop_genomes = self._initialize_genomes(state)
randkey = randkey
pop_nodes, pop_conns = initialize_genomes(state, self.gene_type)
species_info = jnp.full((state.S, 4), jnp.nan,
dtype=jnp.float32) # (species_key, best_fitness, last_improved, size)
species_info = species_info.at[0, :].set([0, -jnp.inf, 0, state.P])
species_keys = np.full((state.S,), np.nan, dtype=np.float32)
best_fitness = np.full((state.S,), np.nan, dtype=np.float32)
last_improved = np.full((state.S,), np.nan, dtype=np.float32)
member_count = np.full((state.S,), np.nan, dtype=np.float32)
idx2species = jnp.zeros(state.P, dtype=jnp.float32)
species_keys[0] = 0
best_fitness[0] = -np.inf
last_improved[0] = 0
member_count[0] = state.P
center_nodes = jnp.full((state.S, state.N, state.NL), jnp.nan, dtype=jnp.float32)
center_conns = jnp.full((state.S, state.C, state.CL), jnp.nan, dtype=jnp.float32)
center_nodes = center_nodes.at[0, :, :].set(pop_nodes[0, :, :])
center_conns = center_conns.at[0, :, :].set(pop_conns[0, :, :])
center_nodes = center_nodes.at[0, :, :].set(pop_genomes.nodes[0, :, :])
center_conns = center_conns.at[0, :, :].set(pop_genomes.conns[0, :, :])
center_genomes = vmap(Genome)(center_nodes, center_conns)
generation = 0
next_node_key = max(*state.input_idx, *state.output_idx) + 2
next_species_key = 1
state = state.update(
randkey=randkey,
pop_nodes=pop_nodes,
pop_conns=pop_conns,
species_info=species_info,
pop_genomes=pop_genomes,
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
center_genomes=center_genomes,
# avoid jax auto cast from int to float. that would cause re-compilation.
generation=jnp.asarray(generation, dtype=jnp.int32),
@@ -69,7 +86,112 @@ class NEAT(Algorithm):
next_species_key=jnp.asarray(next_species_key, dtype=jnp.float32),
)
# move to device
state = jax.device_put(state)
self.forward_func = self.gene_type.create_forward(state, self.config.gene)
self.tell_func = self._create_tell()
return state
return jax.device_put(state)
def ask(self, state: State):
"""require the population to be evaluated"""
return state.pop_genomes
def tell(self, state: State, fitness):
"""update the state of the algorithm"""
return self.tell_func(state, fitness)
def forward(self, inputs: Array, transformed: Array):
"""the forward function of a single forward transformation"""
return self.forward_func(inputs, transformed)
def forward_transform(self, state: State, genome: Genome):
"""create the forward transformation of a genome"""
return self.gene_type.forward_transform(state, genome)
def _initialize_genomes(self, state):
o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes
o_conns = np.full((state.C, state.CL), np.nan, dtype=np.float32) # original connections
input_idx = state.input_idx
output_idx = state.output_idx
new_node_key = max([*input_idx, *output_idx]) + 1
o_nodes[input_idx, 0] = input_idx
o_nodes[output_idx, 0] = output_idx
o_nodes[new_node_key, 0] = new_node_key
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = self.gene_type.new_node_attrs(state)
o_nodes[new_node_key, 1:] = self.gene_type.new_node_attrs(state)
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)]
o_conns[input_idx, 0:2] = input_conns # in key, out key
o_conns[input_idx, 2] = True # enabled
o_conns[input_idx, 3:] = self.gene_type.new_conn_attrs(state)
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx]
o_conns[output_idx, 0:2] = output_conns # in key, out key
o_conns[output_idx, 2] = True # enabled
o_conns[output_idx, 3:] = self.gene_type.new_conn_attrs(state)
# repeat origin genome for P times to create population
pop_nodes = np.tile(o_nodes, (state.P, 1, 1))
pop_conns = np.tile(o_conns, (state.P, 1, 1))
return vmap(Genome)(pop_nodes, pop_conns)
def _create_tell(self):
mutate = create_mutate(self.config.neat, self.gene_type)
def create_next_generation(state, randkey, winner, loser, elite_mask):
# prepare random keys
pop_size = state.idx2species.shape[0]
new_node_keys = jnp.arange(pop_size) + state.next_node_key
k1, k2 = jax.random.split(randkey, 2)
crossover_rand_keys = jax.random.split(k1, pop_size)
mutate_rand_keys = jax.random.split(k2, pop_size)
# batch crossover
wpn, wpc = state.pop_genomes.nodes[winner], state.pop_genomes.conns[winner]
lpn, lpc = state.pop_genomes.nodes[loser], state.pop_genomes.conns[loser]
n_genomes = vmap(crossover)(crossover_rand_keys, Genome(wpn, wpc), Genome(lpn, lpc))
# batch mutation
mutate_func = vmap(mutate, in_axes=(None, 0, 0, 0))
m_n_genomes = mutate_func(state, mutate_rand_keys, n_genomes, new_node_keys) # mutate_new_pop_nodes
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_genomes.nodes, m_n_genomes.nodes)
pop_conns = jnp.where(elite_mask[:, None, None], n_genomes.conns, m_n_genomes.conns)
# update next node key
all_nodes_keys = pop_nodes[:, :, 0]
max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
next_node_key = max_node_key + 1
return state.update(
pop_genomes=Genome(pop_nodes, pop_conns),
next_node_key=next_node_key,
)
speciate = create_speciate(self.gene_type)
def tell(state, fitness):
"""
Main update function in NEAT.
"""
k1, k2, randkey = jax.random.split(state.randkey, 3)
state = state.update(
generation=state.generation + 1,
randkey=randkey
)
state, winner, loser, elite_mask = update_species(state, k1, fitness)
state = create_next_generation(state, k2, winner, loser, elite_mask)
state = speciate(state)
return state
return tell

View File

@@ -1,363 +0,0 @@
from typing import Type
import jax
from jax import numpy as jnp, vmap
from algorithm.utils import rank_elements, fetch_first
from .genome import create_mutate, create_distance, crossover
from .gene import BaseGene
def create_tell(config, gene_type: Type[BaseGene]):
mutate = create_mutate(config, gene_type)
distance = create_distance(config, gene_type)
def update_species(state, randkey, fitness):
# update the fitness of each species
species_fitness = update_species_fitness(state, fitness)
# stagnation species
state, species_fitness = stagnation(state, species_fitness)
# sort species_info by their fitness. (push nan to the end)
sort_indices = jnp.argsort(species_fitness)[::-1]
state = state.update(
species_info=state.species_info[sort_indices],
center_nodes=state.center_nodes[sort_indices],
center_conns=state.center_conns[sort_indices],
)
# decide the number of members of each species by their fitness
spawn_number = cal_spawn_numbers(state)
# crossover info
winner, loser, elite_mask = create_crossover_pair(state, randkey, spawn_number, fitness)
return state, winner, loser, elite_mask
def update_species_fitness(state, fitness):
"""
obtain the fitness of the species by the fitness of each individual.
use max criterion.
"""
def aux_func(idx):
species_key = state.species_info[idx, 0]
s_fitness = jnp.where(state.idx2species == species_key, fitness, -jnp.inf)
f = jnp.max(s_fitness)
return f
return vmap(aux_func)(jnp.arange(state.species_info.shape[0]))
def stagnation(state, species_fitness):
"""
stagnation species.
those species whose fitness is not better than the best fitness of the species for a long time will be stagnation.
elitism species never stagnation
"""
def aux_func(idx):
s_fitness = species_fitness[idx]
species_key, best_score, last_update, members_count = state.species_info[idx]
st = (s_fitness <= best_score) & (state.generation - last_update > state.max_stagnation)
last_update = jnp.where(s_fitness > best_score, state.generation, last_update)
best_score = jnp.where(s_fitness > best_score, s_fitness, best_score)
# stagnation condition
return st, jnp.array([species_key, best_score, last_update, members_count])
spe_st, species_info = vmap(aux_func)(jnp.arange(species_fitness.shape[0]))
# elite species will not be stagnation
species_rank = rank_elements(species_fitness)
spe_st = jnp.where(species_rank < state.species_elitism, False, spe_st) # elitism never stagnation
# set stagnation species to nan
species_info = jnp.where(spe_st[:, None], jnp.nan, species_info)
center_nodes = jnp.where(spe_st[:, None, None], jnp.nan, state.center_nodes)
center_conns = jnp.where(spe_st[:, None, None], jnp.nan, state.center_conns)
species_fitness = jnp.where(spe_st, -jnp.inf, species_fitness)
state = state.update(
species_info=species_info,
center_nodes=center_nodes,
center_conns=center_conns,
)
return state, species_fitness
def cal_spawn_numbers(state):
"""
decide the number of members of each species by their fitness rank.
the species with higher fitness will have more members
Linear ranking selection
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
"""
is_species_valid = ~jnp.isnan(state.species_info[:, 0])
valid_species_num = jnp.sum(is_species_valid)
denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6
rank_score = valid_species_num - jnp.arange(state.species_info.shape[0]) # obtain [3, 2, 1]
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member
# Avoid too much variation of numbers in a species
previous_size = state.species_info[:, 3].astype(jnp.int32)
spawn_number = previous_size + (target_spawn_number - previous_size) * state.spawn_number_change_rate
# jax.debug.print("previous_size: {}, spawn_number: {}", previous_size, spawn_number)
spawn_number = spawn_number.astype(jnp.int32)
# spawn_number = target_spawn_number.astype(jnp.int32)
# must control the sum of spawn_number to be equal to pop_size
error = state.P - jnp.sum(spawn_number)
spawn_number = spawn_number.at[0].add(
error) # add error to the first species to control the sum of spawn_number
return spawn_number
def create_crossover_pair(state, randkey, spawn_number, fitness):
species_size = state.species_info.shape[0]
pop_size = fitness.shape[0]
s_idx = jnp.arange(species_size)
p_idx = jnp.arange(pop_size)
# def aux_func(key, idx):
def aux_func(key, idx):
members = state.idx2species == state.species_info[idx, 0]
members_num = jnp.sum(members)
members_fitness = jnp.where(members, fitness, -jnp.inf)
sorted_member_indices = jnp.argsort(members_fitness)[::-1]
elite_size = state.genome_elitism
survive_size = jnp.floor(state.survival_threshold * members_num).astype(jnp.int32)
select_pro = (p_idx < survive_size) / survive_size
fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro)
# elite
fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa)
ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma)
elite = jnp.where(p_idx < elite_size, True, False)
return fa, ma, elite
fas, mas, elites = vmap(aux_func)(jax.random.split(randkey, species_size), s_idx)
spawn_number_cum = jnp.cumsum(spawn_number)
def aux_func(idx):
loc = jnp.argmax(idx < spawn_number_cum)
# elite genomes are at the beginning of the species
idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]
part1, part2, elite_mask = vmap(aux_func)(p_idx)
is_part1_win = fitness[part1] >= fitness[part2]
winner = jnp.where(is_part1_win, part1, part2)
loser = jnp.where(is_part1_win, part2, part1)
return winner, loser, elite_mask
def create_next_generation(state, randkey, winner, loser, elite_mask):
# prepare random keys
pop_size = state.pop_nodes.shape[0]
new_node_keys = jnp.arange(pop_size) + state.next_node_key
k1, k2 = jax.random.split(randkey, 2)
crossover_rand_keys = jax.random.split(k1, pop_size)
mutate_rand_keys = jax.random.split(k2, pop_size)
# batch crossover
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner] # winner pop nodes, winner pop connections
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser] # loser pop nodes, loser pop connections
npn, npc = vmap(crossover)(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
# batch mutation
mutate_func = vmap(mutate, in_axes=(None, 0, 0, 0, 0))
m_npn, m_npc = mutate_func(state, mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
pop_conns = jnp.where(elite_mask[:, None, None], npc, m_npc)
# update next node key
all_nodes_keys = pop_nodes[:, :, 0]
max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
next_node_key = max_node_key + 1
return state.update(
pop_nodes=pop_nodes,
pop_conns=pop_conns,
next_node_key=next_node_key,
)
def speciate(state):
pop_size, species_size = state.pop_nodes.shape[0], state.center_nodes.shape[0]
# prepare distance functions
o2p_distance_func = vmap(distance, in_axes=(None, None, None, 0, 0)) # one to population
# idx to specie key
idx2specie = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species
# the distance between genomes to its center genomes
o2c_distances = jnp.full((pop_size,), jnp.inf)
# step 1: find new centers
def cond_func(carry):
i, i2s, cn, cc, o2c = carry
species_key = state.species_info[i, 0]
# jax.debug.print("{}, {}", i, species_key)
return (i < species_size) & (~jnp.isnan(species_key)) # current species is existing
def body_func(carry):
i, i2s, cn, cc, o2c = carry
distances = o2p_distance_func(state, cn[i], cc[i], state.pop_nodes, state.pop_conns)
# find the closest one
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
# jax.debug.print("closest_idx: {}", closest_idx)
i2s = i2s.at[closest_idx].set(state.species_info[i, 0])
cn = cn.at[i].set(state.pop_nodes[closest_idx])
cc = cc.at[i].set(state.pop_conns[closest_idx])
# the genome with closest_idx will become the new center, thus its distance to center is 0.
o2c = o2c.at[closest_idx].set(0)
return i + 1, i2s, cn, cc, o2c
_, idx2specie, center_nodes, center_conns, o2c_distances = \
jax.lax.while_loop(cond_func, body_func,
(0, idx2specie, state.center_nodes, state.center_conns, o2c_distances))
# part 2: assign members to each species
def cond_func(carry):
i, i2s, cn, cc, si, o2c, nsk = carry # si is short for species_info, nsk is short for next_species_key
current_species_existed = ~jnp.isnan(si[i, 0])
not_all_assigned = jnp.any(jnp.isnan(i2s))
not_reach_species_upper_bounds = i < species_size
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
def body_func(carry):
i, i2s, cn, cc, si, o2c, nsk = carry # scn is short for spe_center_nodes, scc is short for spe_center_conns
_, i2s, scn, scc, si, o2c, nsk = jax.lax.cond(
jnp.isnan(si[i, 0]), # whether the current species is existing or not
create_new_species, # if not existing, create a new specie
update_exist_specie, # if existing, update the specie
(i, i2s, cn, cc, si, o2c, nsk)
)
return i + 1, i2s, scn, scc, si, o2c, nsk
def create_new_species(carry):
i, i2s, cn, cc, si, o2c, nsk = carry
# pick the first one who has not been assigned to any species
idx = fetch_first(jnp.isnan(i2s))
# assign it to the new species
# [key, best score, last update generation, members_count]
si = si.at[i].set(jnp.array([nsk, -jnp.inf, state.generation, 0]))
i2s = i2s.at[idx].set(nsk)
o2c = o2c.at[idx].set(0)
# update center genomes
cn = cn.at[i].set(state.pop_nodes[idx])
cc = cc.at[i].set(state.pop_conns[idx])
i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))
# when a new species is created, it needs to be updated, thus do not change i
return i + 1, i2s, cn, cc, si, o2c, nsk + 1 # change to next new speciate key
def update_exist_specie(carry):
i, i2s, cn, cc, si, o2c, nsk = carry
i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))
# turn to next species
return i + 1, i2s, cn, cc, si, o2c, nsk
def speciate_by_threshold(carry):
i, i2s, cn, cc, si, o2c = carry
# distance between such center genome and ppo genomes
o2p_distance = o2p_distance_func(state, cn[i], cc[i], state.pop_nodes, state.pop_conns)
close_enough_mask = o2p_distance < state.compatibility_threshold
# when a genome is not assigned or the distance between its current center is bigger than this center
cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
# jax.debug.print("{}", o2p_distance)
mask = close_enough_mask & cacheable_mask
# update species info
i2s = jnp.where(mask, si[i, 0], i2s)
# update distance between centers
o2c = jnp.where(mask, o2p_distance, o2c)
return i2s, o2c
# update idx2specie
_, idx2specie, center_nodes, center_conns, species_info, _, next_species_key = jax.lax.while_loop(
cond_func,
body_func,
(0, idx2specie, center_nodes, center_conns, state.species_info, o2c_distances, state.next_species_key)
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition can only happen when the number of species is reached species upper bounds
idx2specie = jnp.where(jnp.isnan(idx2specie), species_info[-1, 0], idx2specie)
# update members count
def count_members(idx):
key = species_info[idx, 0]
count = jnp.sum(idx2specie == key)
count = jnp.where(jnp.isnan(key), jnp.nan, count)
return count
species_member_counts = vmap(count_members)(jnp.arange(species_size))
species_info = species_info.at[:, 3].set(species_member_counts)
return state.update(
idx2species=idx2specie,
center_nodes=center_nodes,
center_conns=center_conns,
species_info=species_info,
next_species_key=next_species_key
)
def tell(state, fitness):
"""
Main update function in NEAT.
"""
k1, k2, randkey = jax.random.split(state.randkey, 3)
state = state.update(
generation=state.generation + 1,
randkey=randkey
)
state, winner, loser, elite_mask = update_species(state, k1, fitness)
state = create_next_generation(state, k2, winner, loser, elite_mask)
state = speciate(state)
return state
return tell
def argmin_with_mask(arr, mask):
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)
return min_idx

View File

@@ -0,0 +1 @@
from .operations import update_species, create_speciate

View File

@@ -1,11 +1,11 @@
from typing import Dict, Type
from typing import Type
from jax import Array, numpy as jnp, vmap
from ..gene import BaseGene
from core import Gene
def create_distance(config: Dict, gene_type: Type[BaseGene]):
def create_distance(gene_type: Type[Gene]):
def node_distance(state, nodes1: Array, nodes2: Array):
"""
Calculate the distance between nodes of two genomes.
@@ -35,8 +35,7 @@ def create_distance(config: Dict, gene_type: Type[BaseGene]):
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
homologous_distance = jnp.sum(hnd * intersect_mask)
val = non_homologous_cnt * config['compatibility_disjoint'] + homologous_distance * config[
'compatibility_weight']
val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
@@ -64,13 +63,11 @@ def create_distance(config: Dict, gene_type: Type[BaseGene]):
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
homologous_distance = jnp.sum(hcd * intersect_mask)
val = non_homologous_cnt * config['compatibility_disjoint'] + homologous_distance * config[
'compatibility_weight']
val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight
return jnp.where(max_cnt == 0, 0, val / max_cnt)
def distance(state, nodes1, conns1, nodes2, conns2):
return node_distance(state, nodes1, nodes2) + connection_distance(state, conns1, conns2)
def distance(state, genome1, genome2):
return node_distance(state, genome1.nodes, genome2.nodes) + connection_distance(state, genome1.conns, genome2.conns)
return distance

View File

@@ -0,0 +1,334 @@
from typing import Type
import jax
from jax import numpy as jnp, vmap
from core import Gene, Genome
from utils import rank_elements, fetch_first
from .distance import create_distance
def update_species(state, randkey, fitness):
# update the fitness of each species
species_fitness = update_species_fitness(state, fitness)
# stagnation species
state, species_fitness = stagnation(state, species_fitness)
# sort species_info by their fitness. (push nan to the end)
sort_indices = jnp.argsort(species_fitness)[::-1]
center_nodes = state.center_genomes.nodes[sort_indices]
center_conns = state.center_genomes.conns[sort_indices]
state = state.update(
species_keys=state.species_keys[sort_indices],
best_fitness=state.best_fitness[sort_indices],
last_improved=state.last_improved[sort_indices],
member_count=state.member_count[sort_indices],
center_genomes=Genome(center_nodes, center_conns),
)
# decide the number of members of each species by their fitness
spawn_number = cal_spawn_numbers(state)
# crossover info
winner, loser, elite_mask = create_crossover_pair(state, randkey, spawn_number, fitness)
return state, winner, loser, elite_mask
def update_species_fitness(state, fitness):
"""
obtain the fitness of the species by the fitness of each individual.
use max criterion.
"""
def aux_func(idx):
s_fitness = jnp.where(state.idx2species == state.species_keys[idx], fitness, -jnp.inf)
f = jnp.max(s_fitness)
return f
return vmap(aux_func)(jnp.arange(state.species_keys.shape[0]))
def stagnation(state, species_fitness):
"""
stagnation species.
those species whose fitness is not better than the best fitness of the species for a long time will be stagnation.
elitism species never stagnation
"""
def aux_func(idx):
s_fitness = species_fitness[idx]
sk, bf, li = state.species_keys[idx], state.best_fitness[idx], state.last_improved[idx]
st = (s_fitness <= bf) & (state.generation - li > state.max_stagnation)
li = jnp.where(s_fitness > bf, state.generation, li)
bf = jnp.where(s_fitness > bf, s_fitness, bf)
return st, sk, bf, li
spe_st, species_keys, best_fitness, last_improved = vmap(aux_func)(jnp.arange(species_fitness.shape[0]))
# elite species will not be stagnation
species_rank = rank_elements(species_fitness)
spe_st = jnp.where(species_rank < state.species_elitism, False, spe_st) # elitism never stagnation
# set stagnation species to nan
species_keys = jnp.where(spe_st, jnp.nan, species_keys)
best_fitness = jnp.where(spe_st, jnp.nan, best_fitness)
last_improved = jnp.where(spe_st, jnp.nan, last_improved)
member_count = jnp.where(spe_st, jnp.nan, state.member_count)
species_fitness = jnp.where(spe_st, -jnp.inf, species_fitness)
center_nodes = jnp.where(spe_st[:, None, None], jnp.nan, state.center_genomes.nodes)
center_conns = jnp.where(spe_st[:, None, None], jnp.nan, state.center_genomes.conns)
state = state.update(
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
center_genomes=state.center_genomes.update(center_nodes, center_conns)
)
return state, species_fitness
def cal_spawn_numbers(state):
"""
decide the number of members of each species by their fitness rank.
the species with higher fitness will have more members
Linear ranking selection
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
"""
is_species_valid = ~jnp.isnan(state.species_keys)
valid_species_num = jnp.sum(is_species_valid)
denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6
rank_score = valid_species_num - jnp.arange(state.species_keys.shape[0]) # obtain [3, 2, 1]
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member
# Avoid too much variation of numbers in a species
previous_size = state.member_count
spawn_number = previous_size + (target_spawn_number - previous_size) * state.spawn_number_change_rate
# jax.debug.print("previous_size: {}, spawn_number: {}", previous_size, spawn_number)
spawn_number = spawn_number.astype(jnp.int32)
# must control the sum of spawn_number to be equal to pop_size
error = state.P - jnp.sum(spawn_number)
spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number
return spawn_number
def create_crossover_pair(state, randkey, spawn_number, fitness):
species_size = state.species_keys.shape[0]
pop_size = fitness.shape[0]
s_idx = jnp.arange(species_size)
p_idx = jnp.arange(pop_size)
# def aux_func(key, idx):
def aux_func(key, idx):
members = state.idx2species == state.species_keys[idx]
members_num = jnp.sum(members)
members_fitness = jnp.where(members, fitness, -jnp.inf)
sorted_member_indices = jnp.argsort(members_fitness)[::-1]
elite_size = state.genome_elitism
survive_size = jnp.floor(state.survival_threshold * members_num).astype(jnp.int32)
select_pro = (p_idx < survive_size) / survive_size
fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro)
# elite
fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa)
ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma)
elite = jnp.where(p_idx < elite_size, True, False)
return fa, ma, elite
fas, mas, elites = vmap(aux_func)(jax.random.split(randkey, species_size), s_idx)
spawn_number_cum = jnp.cumsum(spawn_number)
def aux_func(idx):
loc = jnp.argmax(idx < spawn_number_cum)
# elite genomes are at the beginning of the species
idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]
part1, part2, elite_mask = vmap(aux_func)(p_idx)
is_part1_win = fitness[part1] >= fitness[part2]
winner = jnp.where(is_part1_win, part1, part2)
loser = jnp.where(is_part1_win, part2, part1)
return winner, loser, elite_mask
def create_speciate(gene_type: Type[Gene]):
distance = create_distance(gene_type)
def speciate(state):
pop_size, species_size = state.idx2species.shape[0], state.species_keys.shape[0]
# prepare distance functions
o2p_distance_func = vmap(distance, in_axes=(None, None, 0)) # one to population
# idx to specie key
idx2species = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species
# the distance between genomes to its center genomes
o2c_distances = jnp.full((pop_size,), jnp.inf)
# step 1: find new centers
def cond_func(carry):
i, i2s, cgs, o2c = carry
return (i < species_size) & (~jnp.isnan(state.species_keys[i])) # current species is existing
def body_func(carry):
i, i2s, cgs, o2c = carry
distances = o2p_distance_func(state, Genome(cgs.nodes[i], cgs.conns[i]), state.pop_genomes)
# find the closest one
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
# jax.debug.print("closest_idx: {}", closest_idx)
i2s = i2s.at[closest_idx].set(state.species_keys[i])
cn = cgs.nodes.at[i].set(state.pop_genomes.nodes[closest_idx])
cc = cgs.conns.at[i].set(state.pop_genomes.conns[closest_idx])
# the genome with closest_idx will become the new center, thus its distance to center is 0.
o2c = o2c.at[closest_idx].set(0)
return i + 1, i2s, Genome(cn, cc), o2c
_, idx2species, center_genomes, o2c_distances = \
jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances))
state = state.update(
idx2species=idx2species,
center_genomes=center_genomes,
)
# part 2: assign members to each species
def cond_func(carry):
i, i2s, cgs, sk, o2c, nsk = carry
current_species_existed = ~jnp.isnan(sk[i])
not_all_assigned = jnp.any(jnp.isnan(i2s))
not_reach_species_upper_bounds = i < species_size
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
def body_func(carry):
i, i2s, cgs, sk, o2c, nsk = carry
_, i2s, cgs, sk, o2c, nsk = jax.lax.cond(
jnp.isnan(sk[i]), # whether the current species is existing or not
create_new_species, # if not existing, create a new specie
update_exist_specie, # if existing, update the specie
(i, i2s, cgs, sk, o2c, nsk)
)
return i + 1, i2s, cgs, sk, o2c, nsk
def create_new_species(carry):
i, i2s, cgs, sk, o2c, nsk = carry
# pick the first one who has not been assigned to any species
idx = fetch_first(jnp.isnan(i2s))
# assign it to the new species
# [key, best score, last update generation, members_count]
sk = sk.at[i].set(nsk)
i2s = i2s.at[idx].set(nsk)
o2c = o2c.at[idx].set(0)
# update center genomes
cn = cgs.nodes.at[i].set(state.pop_genomes.nodes[idx])
cc = cgs.conns.at[i].set(state.pop_genomes.conns[idx])
cgs = Genome(cn, cc)
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
# when a new species is created, it needs to be updated, thus do not change i
return i + 1, i2s, cgs, sk, o2c, nsk + 1 # change to next new speciate key
def update_exist_specie(carry):
i, i2s, cgs, sk, o2c, nsk = carry
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
# turn to next species
return i + 1, i2s, cgs, sk, o2c, nsk
def speciate_by_threshold(i, i2s, cgs, sk, o2c):
# distance between such center genome and ppo genomes
center = Genome(cgs.nodes[i], cgs.conns[i])
o2p_distance = o2p_distance_func(state, center, state.pop_genomes)
close_enough_mask = o2p_distance < state.compatibility_threshold
# when a genome is not assigned or the distance between its current center is bigger than this center
cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
# jax.debug.print("{}", o2p_distance)
mask = close_enough_mask & cacheable_mask
# update species info
i2s = jnp.where(mask, sk[i], i2s)
# update distance between centers
o2c = jnp.where(mask, o2p_distance, o2c)
return i2s, o2c
# update idx2species
_, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop(
cond_func,
body_func,
(0, state.idx2species, state.center_genomes, state.species_keys, o2c_distances, state.next_species_key)
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition can only happen when the number of species is reached species upper bounds
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
# complete info of species which is created in this generation
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.best_fitness)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.best_fitness)
last_improved = jnp.where(new_created_mask, state.generation, state.last_improved)
# update members count
def count_members(idx):
key = species_keys[idx]
count = jnp.sum(idx2species == key)
count = jnp.where(jnp.isnan(key), jnp.nan, count)
return count
member_count = vmap(count_members)(jnp.arange(species_size))
return state.update(
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
members_count=member_count,
idx2species=idx2species,
center_genomes=center_genomes,
next_species_key=next_species_key
)
return speciate
def argmin_with_mask(arr, mask):
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)
return min_idx

View File

@@ -1 +1,2 @@
from .config import Configer
from .config import *

View File

@@ -1,70 +1,103 @@
import os
import warnings
import configparser
import numpy as np
from dataclasses import dataclass
from typing import Union
class Configer:
@dataclass(frozen=True)
class BasicConfig:
seed: int = 42
fitness_target: float = 1
generation_limit: int = 1000
num_inputs: int = 2
num_outputs: int = 1
pop_size: int = 100
@classmethod
def __load_default_config(cls):
par_dir = os.path.dirname(os.path.abspath(__file__))
default_config_path = os.path.join(par_dir, "default_config.ini")
return cls.__load_config(default_config_path)
def __post_init__(self):
assert self.num_inputs > 0, "the inputs number of the problem must be greater than 0"
assert self.num_outputs > 0, "the outputs number of the problem must be greater than 0"
assert self.pop_size > 0, "the population size must be greater than 0"
@classmethod
def __load_config(cls, config_path):
c = configparser.ConfigParser()
c.read(config_path)
config = {}
for section in c.sections():
for key, value in c.items(section):
config[key] = eval(value)
@dataclass(frozen=True)
class NeatConfig:
network_type: str = "feedforward"
activate_times: Union[int, None] = None # None means the network is feedforward
maximum_nodes: int = 100
maximum_conns: int = 50
maximum_species: int = 10
return config
# genome config
compatibility_disjoint: float = 1
compatibility_weight: float = 0.5
conn_add: float = 0.4
conn_delete: float = 0.4
node_add: float = 0.2
node_delete: float = 0.2
@classmethod
def __check_redundant_config(cls, default_config, config):
for key in config:
if key not in default_config:
warnings.warn(f"Redundant config: {key} in config!")
# species config
compatibility_threshold: float = 3.0
species_elitism: int = 2
max_stagnation: int = 15
genome_elitism: int = 2
survival_threshold: float = 0.2
min_species_size: int = 1
spawn_number_change_rate: float = 0.5
@classmethod
def __complete_config(cls, default_config, config):
for key in default_config:
if key not in config:
config[key] = default_config[key]
@classmethod
def load_config(cls, config_path=None):
default_config = cls.__load_default_config()
if config_path is None:
config = {}
elif not os.path.exists(config_path):
warnings.warn(f"config file {config_path} not exist!")
config = {}
def __post_init__(self):
assert self.network_type in ["feedforward", "recurrent"], "the network type must be feedforward or recurrent"
if self.network_type == "feedforward":
assert self.activate_times is None, "the activate times of feedforward network must be None"
else:
config = cls.__load_config(config_path)
assert isinstance(self.activate_times, int), "the activate times of recurrent network must be int"
assert self.activate_times > 0, "the activate times of recurrent network must be greater than 0"
cls.__check_redundant_config(default_config, config)
cls.__complete_config(default_config, config)
assert self.maximum_nodes > 0, "the maximum nodes must be greater than 0"
assert self.maximum_conns > 0, "the maximum connections must be greater than 0"
assert self.maximum_species > 0, "the maximum species must be greater than 0"
cls.refactor_activation(config)
cls.refactor_aggregation(config)
assert self.compatibility_disjoint > 0, "the compatibility disjoint must be greater than 0"
assert self.compatibility_weight > 0, "the compatibility weight must be greater than 0"
assert self.conn_add > 0, "the connection add probability must be greater than 0"
assert self.conn_delete > 0, "the connection delete probability must be greater than 0"
assert self.node_add > 0, "the node add probability must be greater than 0"
assert self.node_delete > 0, "the node delete probability must be greater than 0"
config['input_idx'] = np.arange(config['num_inputs'])
config['output_idx'] = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
assert self.compatibility_threshold > 0, "the compatibility threshold must be greater than 0"
assert self.species_elitism > 0, "the species elitism must be greater than 0"
assert self.max_stagnation > 0, "the max stagnation must be greater than 0"
assert self.genome_elitism > 0, "the genome elitism must be greater than 0"
assert self.survival_threshold > 0, "the survival threshold must be greater than 0"
assert self.min_species_size > 0, "the min species size must be greater than 0"
assert self.spawn_number_change_rate > 0, "the spawn number change rate must be greater than 0"
return config
@classmethod
def refactor_activation(cls, config):
config['activation_default'] = 0
config['activation_options'] = np.arange(len(config['activation_option_names']))
@dataclass(frozen=True)
class HyperNeatConfig:
below_threshold: float = 0.2
max_weight: float = 3
activation: str = "sigmoid"
aggregation: str = "sum"
activate_times: int = 5
@classmethod
def refactor_aggregation(cls, config):
config['aggregation_default'] = 0
config['aggregation_options'] = np.arange(len(config['aggregation_option_names']))
def __post_init__(self):
assert self.below_threshold > 0, "the below threshold must be greater than 0"
assert self.max_weight > 0, "the max weight must be greater than 0"
assert self.activate_times > 0, "the activate times must be greater than 0"
@dataclass(frozen=True)
class GeneConfig:
pass
@dataclass(frozen=True)
class SubstrateConfig:
pass
@dataclass(frozen=True)
class Config:
basic: BasicConfig = BasicConfig()
neat: NeatConfig = NeatConfig()
hyper_neat: HyperNeatConfig = HyperNeatConfig()
gene: GeneConfig = GeneConfig()
substrate: SubstrateConfig = SubstrateConfig()

View File

@@ -1,8 +1,6 @@
[basic]
random_seed = 0
generation_limit = 1000
[problem]
fitness_threshold = 3.9999
num_inputs = 2
num_outputs = 1
@@ -14,6 +12,13 @@ maximum_nodes = 50
maximum_conns = 50
maximum_species = 10
compatibility_disjoint = 1.0
compatibility_weight = 0.5
conn_add_prob = 0.4
conn_delete_prob = 0
node_add_prob = 0.2
node_delete_prob = 0
[hyperneat]
below_threshold = 0.2
max_weight = 3
@@ -26,17 +31,6 @@ input_coors = [[-1, 1], [0, 1], [1, 1]]
hidden_coors = [[-1, 0], [0, 0], [1, 0]]
output_coors = [[0, -1]]
[population]
pop_size = 10
[genome]
compatibility_disjoint = 1.0
compatibility_weight = 0.5
conn_add_prob = 0.4
conn_delete_prob = 0
node_add_prob = 0.2
node_delete_prob = 0
[species]
compatibility_threshold = 3.0
species_elitism = 2

5
core/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
from .algorithm import Algorithm
from .state import State
from .genome import Genome
from .gene import Gene

28
core/algorithm.py Normal file
View File

@@ -0,0 +1,28 @@
from jax import Array
from .state import State
from .genome import Genome
EMPTY = lambda *args: args
class Algorithm:
def setup(self, randkey, state: State = State()):
"""initialize the state of the algorithm"""
pass
def ask(self, state: State):
"""require the population to be evaluated"""
pass
def tell(self, state: State, fitness):
"""update the state of the algorithm"""
pass
def forward(self, inputs: Array, transformed: Array):
"""the forward function of a single forward transformation"""
pass
def forward_transform(self, state: State, genome: Genome):
"""create the forward transformation of a genome"""
pass

46
core/gene.py Normal file
View File

@@ -0,0 +1,46 @@
from jax import Array, numpy as jnp
from config import GeneConfig
from .state import State
from .genome import Genome
class Gene:
node_attrs = []
conn_attrs = []
@staticmethod
def setup(config: GeneConfig, state: State):
return state
@staticmethod
def new_node_attrs(state: State):
return jnp.zeros(0)
@staticmethod
def new_conn_attrs(state: State):
return jnp.zeros(0)
@staticmethod
def mutate_node(state: State, attrs: Array, randkey: Array):
return attrs
@staticmethod
def mutate_conn(state: State, attrs: Array, randkey: Array):
return attrs
@staticmethod
def distance_node(state: State, node1: Array, node2: Array):
return node1
@staticmethod
def distance_conn(state: State, conn1: Array, conn2: Array):
return conn1
@staticmethod
def forward_transform(state: State, genome: Genome):
return jnp.zeros(0) # transformed
@staticmethod
def create_forward(state: State, config: GeneConfig):
return lambda *args: args # forward function

77
core/genome.py Normal file
View File

@@ -0,0 +1,77 @@
from jax.tree_util import register_pytree_node_class
from jax import numpy as jnp
from utils.tools import fetch_first
@register_pytree_node_class
class Genome:
def __init__(self, nodes, conns):
self.nodes = nodes
self.conns = conns
def update(self, nodes, conns):
return self.__class__(nodes, conns)
def update_nodes(self, nodes):
return self.update(nodes, self.conns)
def update_conns(self, conns):
return self.update(self.nodes, conns)
def count(self):
"""Count how many nodes and connections are in the genome."""
nodes_cnt = jnp.sum(~jnp.isnan(self.nodes[:, 0]))
conns_cnt = jnp.sum(~jnp.isnan(self.conns[:, 0]))
return nodes_cnt, conns_cnt
def add_node(self, new_key: int, attrs):
"""
Add a new node to the genome.
The new node will place at the first NaN row.
"""
exist_keys = self.nodes[:, 0]
pos = fetch_first(jnp.isnan(exist_keys))
new_nodes = self.nodes.at[pos, 0].set(new_key)
new_nodes = new_nodes.at[pos, 1:].set(attrs)
return self.update_nodes(new_nodes)
def delete_node_by_pos(self, pos):
"""
Delete a node from the genome.
Delete the node by its pos in nodes.
"""
nodes = self.nodes.at[pos].set(jnp.nan)
return self.update_nodes(nodes)
def add_conn(self, i_key, o_key, enable: bool, attrs):
"""
Add a new connection to the genome.
The new connection will place at the first NaN row.
"""
con_keys = self.conns[:, 0]
pos = fetch_first(jnp.isnan(con_keys))
new_conns = self.conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable]))
new_conns = new_conns.at[pos, 3:].set(attrs)
return self.update_conns(new_conns)
def delete_conn_by_pos(self, pos):
"""
Delete a connection from the genome.
Delete the connection by its idx.
"""
conns = self.conns.at[pos].set(jnp.nan)
return self.update_conns(conns)
def tree_flatten(self):
children = self.nodes, self.conns
aux_data = None
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)
def __repr__(self):
return f"Genome(nodes={self.nodes}, conns={self.conns})"

View File

@@ -1,11 +1,28 @@
import numpy as np
import jax.numpy as jnp
import jax
from jax import numpy as jnp
a = jnp.zeros((5, 5))
k1 = jnp.array([1, 2, 3])
k2 = jnp.array([2, 3, 4])
v = jnp.array([1, 1, 1])
from config import Config
from core import Genome
a = a.at[k1, k2].set(v)
config = Config()
from dataclasses import asdict
print(asdict(config))
pop_nodes = jnp.ones((Config.basic.pop_size, Config.neat.maximum_nodes, 3))
pop_conns = jnp.ones((Config.basic.pop_size, Config.neat.maximum_conns, 5))
pop_genomes1 = jax.vmap(Genome)(pop_nodes, pop_conns)
pop_genomes2 = Genome(pop_nodes, pop_conns)
print(pop_genomes)
print(pop_genomes[0])
@jax.vmap
def pop_cnts(genome):
return genome.count()
cnts = pop_cnts(pop_genomes)
print(cnts)
print(a)

19
examples/b.py Normal file
View File

@@ -0,0 +1,19 @@
from enum import Enum
from jax import jit
class NetworkType(Enum):
ANN = 0
SNN = 1
LSTM = 2
@jit
def func(d):
return d[0] + 1
d = {0: 1, 1: NetworkType.ANN.value}
print(func(d))

View File

@@ -1,44 +0,0 @@
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
@register_pytree_node_class
class Genome:
def __init__(self, nodes, conns):
self.nodes = nodes
self.conns = conns
def update_nodes(self, nodes):
return Genome(nodes, self.conns)
def update_conns(self, conns):
return Genome(self.nodes, conns)
def tree_flatten(self):
children = self.nodes, self.conns
aux_data = None
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)
def __repr__(self):
return f"Genome ({self.nodes}, \n\t{self.conns})"
@jax.jit
def add_node(self, a: int):
nodes = self.nodes.at[0, :].set(a)
return self.update_nodes(nodes)
nodes, conns = jnp.array([[1, 2, 3, 4, 5]]), jnp.array([[1, 2, 3, 4]])
g = Genome(nodes, conns)
print(g)
g = g.add_node(1)
print(g)
g = jax.jit(g.add_node)(2)
print(g)

View File

@@ -1,12 +0,0 @@
[basic]
activate_times = 5
fitness_threshold = 4
[population]
pop_size = 1000
[neat]
network_type = "recurrent"
num_inputs = 4
num_outputs = 1

View File

@@ -1,10 +1,10 @@
import jax
import numpy as np
from config import Config, BasicConfig
from pipeline import Pipeline
from config import Configer
from algorithm import NEAT
from algorithm.neat import RecurrentGene
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from algorithm.neat.neat import NEAT
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
@@ -21,13 +21,11 @@ def evaluate(forward_func):
return fitnesses
def main():
config = Configer.load_config("xor.ini")
algorithm = NEAT(config, RecurrentGene)
pipeline = Pipeline(config, algorithm)
best = pipeline.auto_run(evaluate)
print(best)
if __name__ == '__main__':
main()
config = Config(
basic=BasicConfig(fitness_target=4),
gene=NormalGeneConfig()
)
algorithm = NEAT(config, NormalGene)
pipeline = Pipeline(config, algorithm)
pipeline.auto_run(evaluate)

View File

@@ -1,33 +0,0 @@
import jax
import numpy as np
from pipeline import Pipeline
from config import Configer
from algorithm import NEAT, HyperNEAT
from algorithm.neat import RecurrentGene
from algorithm.hyperneat import BaseSubstrate
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def evaluate(forward_func):
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses
def main():
config = Configer.load_config("xor.ini")
algorithm = HyperNEAT(config, RecurrentGene, BaseSubstrate)
pipeline = Pipeline(config, algorithm)
pipeline.auto_run(evaluate)
if __name__ == '__main__':
main()

View File

@@ -5,7 +5,8 @@ import jax
from jax import vmap, jit
import numpy as np
from algorithm import Algorithm
from config import Config
from core import Algorithm, Genome
class Pipeline:
@@ -13,11 +14,11 @@ class Pipeline:
Neat algorithm pipeline.
"""
def __init__(self, config, algorithm: Algorithm):
def __init__(self, config: Config, algorithm: Algorithm):
self.config = config
self.algorithm = algorithm
randkey = jax.random.PRNGKey(config['random_seed'])
randkey = jax.random.PRNGKey(config.basic.seed)
self.state = algorithm.setup(randkey)
self.best_genome = None
@@ -29,18 +30,18 @@ class Pipeline:
self.forward_func = jit(self.algorithm.forward)
self.batch_forward_func = jit(vmap(self.forward_func, in_axes=(0, None)))
self.pop_batch_forward_func = jit(vmap(self.batch_forward_func, in_axes=(None, 0)))
self.forward_transform_func = jit(vmap(self.algorithm.forward_transform, in_axes=(None, 0, 0)))
self.forward_transform_func = jit(vmap(self.algorithm.forward_transform, in_axes=(None, 0)))
self.tell_func = jit(self.algorithm.tell)
def ask(self):
pop_transforms = self.forward_transform_func(self.state, self.state.pop_nodes, self.state.pop_conns)
pop_transforms = self.forward_transform_func(self.state, self.state.pop_genomes)
return lambda inputs: self.pop_batch_forward_func(inputs, pop_transforms)
def tell(self, fitness):
self.state = self.tell_func(self.state, fitness)
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config['generation_limit']):
for _ in range(self.config.basic.generation_limit):
forward_func = self.ask()
fitnesses = fitness_func(forward_func)
@@ -52,7 +53,7 @@ class Pipeline:
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
analysis(fitnesses)
if max(fitnesses) >= self.config['fitness_threshold']:
if max(fitnesses) >= self.config.basic.fitness_target:
print("Fitness limit reached!")
return self.best_genome
@@ -70,9 +71,9 @@ class Pipeline:
max_idx = np.argmax(fitnesses)
if fitnesses[max_idx] > self.best_fitness:
self.best_fitness = fitnesses[max_idx]
self.best_genome = (self.state.pop_nodes[max_idx], self.state.pop_conns[max_idx])
self.best_genome = Genome(self.state.pop_genomes.nodes[max_idx], self.state.pop_genomes.conns[max_idx])
member_count = jax.device_get(self.state.species_info[:, 3])
member_count = jax.device_get(self.state.member_count)
species_sizes = [int(i) for i in member_count if i > 0]
print(f"Generation: {self.state.generation}",

View File

View File

View File

@@ -1,56 +0,0 @@
import numpy as np
from algorithm.hyperneat.substrate.tools import cartesian_product
def test01():
keys1 = np.array([1, 2, 3])
keys2 = np.array([4, 5, 6, 7])
coors1 = np.array([
[1, 1, 1],
[2, 2, 2],
[3, 3, 3]
])
coors2 = np.array([
[4, 4, 4],
[5, 5, 5],
[6, 6, 6],
[7, 7, 7]
])
target_coors = np.array([
[1, 1, 1, 4, 4, 4],
[1, 1, 1, 5, 5, 5],
[1, 1, 1, 6, 6, 6],
[1, 1, 1, 7, 7, 7],
[2, 2, 2, 4, 4, 4],
[2, 2, 2, 5, 5, 5],
[2, 2, 2, 6, 6, 6],
[2, 2, 2, 7, 7, 7],
[3, 3, 3, 4, 4, 4],
[3, 3, 3, 5, 5, 5],
[3, 3, 3, 6, 6, 6],
[3, 3, 3, 7, 7, 7]
])
target_keys = np.array([
[1, 4],
[1, 5],
[1, 6],
[1, 7],
[2, 4],
[2, 5],
[2, 6],
[2, 7],
[3, 4],
[3, 5],
[3, 6],
[3, 7]
])
new_coors, correspond_keys = cartesian_product(keys1, keys2, coors1, coors2)
assert np.array_equal(new_coors, target_coors)
assert np.array_equal(correspond_keys, target_keys)

View File

@@ -1,32 +0,0 @@
import jax.numpy as jnp
from algorithm.neat.genome.graph import topological_sort, check_cycles
from algorithm.utils import I_INT
nodes = jnp.array([
[0],
[1],
[2],
[3],
[jnp.nan]
])
# {(0, 2), (1, 2), (1, 3), (2, 3)}
conns = jnp.array([
[0, 0, 1, 0, 0],
[0, 0, 1, 1, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]
])
def test_topological_sort():
assert jnp.all(topological_sort(nodes, conns) == jnp.array([0, 1, 2, 3, I_INT]))
def test_check_cycles():
assert check_cycles(nodes, conns, 3, 2)
assert ~check_cycles(nodes, conns, 2, 3)
assert ~check_cycles(nodes, conns, 0, 3)
assert ~check_cycles(nodes, conns, 1, 0)

View File

@@ -1,33 +0,0 @@
import jax.numpy as jnp
from algorithm.utils import unflatten_connections
def test_unflatten():
nodes = jnp.array([
[0, 0, 0, 0],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan]
])
conns = jnp.array([
[0, 1, True, 0.1, 0.11],
[0, 2, False, 0.2, 0.22],
[1, 2, True, 0.3, 0.33],
[1, 3, False, 0.4, 0.44],
])
res = unflatten_connections(nodes, conns)
assert jnp.all(res[:, 0, 1] == jnp.array([True, 0.1, 0.11]))
assert jnp.all(res[:, 0, 2] == jnp.array([False, 0.2, 0.22]))
assert jnp.all(res[:, 1, 2] == jnp.array([True, 0.3, 0.33]))
assert jnp.all(res[:, 1, 3] == jnp.array([False, 0.4, 0.44]))
# Create a mask that excludes the indices we've already checked
mask = jnp.ones(res.shape, dtype=bool)
mask = mask.at[:, [0, 0, 1, 1], [1, 2, 2, 3]].set(False)
# Ensure all other places are jnp.nan
assert jnp.all(jnp.isnan(res[mask]))

4
utils/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from .activation import Activation
from .aggregation import Aggregation
from .tools import *
from .graph import *

View File

@@ -88,6 +88,7 @@ class Activation:
def cube_act(z):
return z ** 3
Activation.name2func = {
'sigmoid': Activation.sigmoid_act,
'tanh': Activation.tanh_act,

View File

@@ -6,13 +6,14 @@ Only used in feed-forward networks.
import jax
from jax import jit, Array, numpy as jnp
from algorithm.utils import fetch_first, I_INT
from .tools import fetch_first, I_INT
@jit
def topological_sort(nodes: Array, conns: Array) -> Array:
"""
a jit-able version of topological_sort!
conns: Array[N, N]
"""
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(conns, axis=0))

View File

@@ -9,12 +9,9 @@ EMPTY_NODE = np.full((1, 5), jnp.nan)
EMPTY_CON = np.full((1, 4), jnp.nan)
@jit
def unflatten_connections(nodes: Array, conns: Array):
def unflatten_conns(nodes, conns):
"""
transform the (C, CL) connections to (CL-2, N, N)
:param nodes: (N, NL)
:param cons: (C, CL)
:return:
"""
N = nodes.shape[0]