change a lot
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
from .state import State
|
||||
from .neat import NEAT
|
||||
|
||||
118
algorithm/config.py
Normal file
118
algorithm/config.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import os
|
||||
import warnings
|
||||
import configparser
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
|
||||
jit_config_keys = [
|
||||
"input_idx",
|
||||
"output_idx",
|
||||
"compatibility_disjoint",
|
||||
"compatibility_weight",
|
||||
"conn_add_prob",
|
||||
"conn_add_trials",
|
||||
"conn_delete_prob",
|
||||
"node_add_prob",
|
||||
"node_delete_prob",
|
||||
"compatibility_threshold",
|
||||
"bias_init_mean",
|
||||
"bias_init_std",
|
||||
"bias_mutate_power",
|
||||
"bias_mutate_rate",
|
||||
"bias_replace_rate",
|
||||
"response_init_mean",
|
||||
"response_init_std",
|
||||
"response_mutate_power",
|
||||
"response_mutate_rate",
|
||||
"response_replace_rate",
|
||||
"activation_default",
|
||||
"activation_options",
|
||||
"activation_replace_rate",
|
||||
"aggregation_default",
|
||||
"aggregation_options",
|
||||
"aggregation_replace_rate",
|
||||
"weight_init_mean",
|
||||
"weight_init_std",
|
||||
"weight_mutate_power",
|
||||
"weight_mutate_rate",
|
||||
"weight_replace_rate",
|
||||
"enable_mutate_rate",
|
||||
"max_stagnation",
|
||||
"pop_size",
|
||||
"genome_elitism",
|
||||
"survival_threshold",
|
||||
"species_elitism",
|
||||
"spawn_number_move_rate"
|
||||
]
|
||||
|
||||
|
||||
class Configer:
|
||||
|
||||
@classmethod
|
||||
def __load_default_config(cls):
|
||||
par_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
default_config_path = os.path.join(par_dir, "default_config.ini")
|
||||
return cls.__load_config(default_config_path)
|
||||
|
||||
@classmethod
|
||||
def __load_config(cls, config_path):
|
||||
c = configparser.ConfigParser()
|
||||
c.read(config_path)
|
||||
config = {}
|
||||
|
||||
for section in c.sections():
|
||||
for key, value in c.items(section):
|
||||
config[key] = eval(value)
|
||||
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def __check_redundant_config(cls, default_config, config):
|
||||
for key in config:
|
||||
if key not in default_config:
|
||||
warnings.warn(f"Redundant config: {key} in {config.name}")
|
||||
|
||||
@classmethod
|
||||
def __complete_config(cls, default_config, config):
|
||||
for key in default_config:
|
||||
if key not in config:
|
||||
config[key] = default_config[key]
|
||||
|
||||
@classmethod
|
||||
def load_config(cls, config_path=None):
|
||||
default_config = cls.__load_default_config()
|
||||
if config_path is None:
|
||||
config = {}
|
||||
elif not os.path.exists(config_path):
|
||||
warnings.warn(f"config file {config_path} not exist!")
|
||||
config = {}
|
||||
else:
|
||||
config = cls.__load_config(config_path)
|
||||
|
||||
cls.__check_redundant_config(default_config, config)
|
||||
cls.__complete_config(default_config, config)
|
||||
|
||||
cls.refactor_activation(config)
|
||||
cls.refactor_aggregation(config)
|
||||
|
||||
config['input_idx'] = np.arange(config['num_inputs'])
|
||||
config['output_idx'] = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
|
||||
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def refactor_activation(cls, config):
|
||||
config['activation_default'] = 0
|
||||
config['activation_options'] = np.arange(len(config['activation_option_names']))
|
||||
|
||||
@classmethod
|
||||
def refactor_aggregation(cls, config):
|
||||
config['aggregation_default'] = 0
|
||||
config['aggregation_options'] = np.arange(len(config['aggregation_option_names']))
|
||||
|
||||
@classmethod
|
||||
def create_jit_config(cls, config):
|
||||
jit_config = {k: config[k] for k in jit_config_keys}
|
||||
|
||||
return jit_config
|
||||
74
algorithm/default_config.ini
Normal file
74
algorithm/default_config.ini
Normal file
@@ -0,0 +1,74 @@
|
||||
[basic]
|
||||
num_inputs = 2
|
||||
num_outputs = 1
|
||||
maximum_nodes = 5
|
||||
maximum_connections = 5
|
||||
maximum_species = 10
|
||||
forward_way = "pop"
|
||||
batch_size = 4
|
||||
random_seed = 0
|
||||
network_type = 'feedforward'
|
||||
|
||||
[population]
|
||||
fitness_threshold = 3.99999
|
||||
generation_limit = 1000
|
||||
fitness_criterion = "max"
|
||||
pop_size = 1000
|
||||
|
||||
[gene]
|
||||
gene_type = "normal"
|
||||
|
||||
[genome]
|
||||
compatibility_disjoint = 1.0
|
||||
compatibility_weight = 0.5
|
||||
conn_add_prob = 0.4
|
||||
conn_add_trials = 1
|
||||
conn_delete_prob = 0.4
|
||||
node_add_prob = 0.2
|
||||
node_delete_prob = 0.2
|
||||
|
||||
[species]
|
||||
compatibility_threshold = 3.0
|
||||
species_elitism = 2
|
||||
max_stagnation = 15
|
||||
genome_elitism = 2
|
||||
survival_threshold = 0.2
|
||||
min_species_size = 1
|
||||
spawn_number_move_rate = 0.5
|
||||
|
||||
[gene-bias]
|
||||
bias_init_mean = 0.0
|
||||
bias_init_std = 1.0
|
||||
bias_mutate_power = 0.5
|
||||
bias_mutate_rate = 0.7
|
||||
bias_replace_rate = 0.1
|
||||
|
||||
[gene-response]
|
||||
response_init_mean = 1.0
|
||||
response_init_std = 0.0
|
||||
response_mutate_power = 0.0
|
||||
response_mutate_rate = 0.0
|
||||
response_replace_rate = 0.0
|
||||
|
||||
[gene-activation]
|
||||
activation_default = "sigmoid"
|
||||
activation_option_names = ["sigmoid"]
|
||||
activation_replace_rate = 0.0
|
||||
|
||||
[gene-aggregation]
|
||||
aggregation_default = "sum"
|
||||
aggregation_option_names = ["sum"]
|
||||
aggregation_replace_rate = 0.0
|
||||
|
||||
[gene-weight]
|
||||
weight_init_mean = 0.0
|
||||
weight_init_std = 1.0
|
||||
weight_mutate_power = 0.5
|
||||
weight_mutate_rate = 0.8
|
||||
weight_replace_rate = 0.1
|
||||
|
||||
[gene-enable]
|
||||
enable_mutate_rate = 0.01
|
||||
|
||||
[visualize]
|
||||
renumber_nodes = True
|
||||
42
algorithm/neat/NEAT.py
Normal file
42
algorithm/neat/NEAT.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from algorithm.state import State
|
||||
from .gene import *
|
||||
from .genome import initialize_genomes
|
||||
|
||||
|
||||
class NEAT:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
if self.config['gene_type'] == 'normal':
|
||||
self.gene_type = NormalGene
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def setup(self, randkey):
|
||||
|
||||
state = State(
|
||||
randkey=randkey,
|
||||
P=self.config['pop_size'],
|
||||
N=self.config['maximum_nodes'],
|
||||
C=self.config['maximum_connections'],
|
||||
S=self.config['maximum_species'],
|
||||
NL=1 + len(self.gene_type.node_attrs), # node length = (key) + attributes
|
||||
CL=3 + len(self.gene_type.conn_attrs), # conn length = (in, out, key) + attributes
|
||||
input_idx=self.config['input_idx'],
|
||||
output_idx=self.config['output_idx']
|
||||
)
|
||||
|
||||
pop_nodes, pop_conns = initialize_genomes(state, self.gene_type)
|
||||
next_node_key = max(*state.input_idx, *state.output_idx) + 2
|
||||
state = state.update(
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
next_node_key=next_node_key
|
||||
)
|
||||
|
||||
return state
|
||||
|
||||
def tell(self, state, fitness):
|
||||
return State()
|
||||
|
||||
def ask(self, state):
|
||||
return State()
|
||||
@@ -0,0 +1 @@
|
||||
from .NEAT import NEAT
|
||||
|
||||
2
algorithm/neat/gene/__init__.py
Normal file
2
algorithm/neat/gene/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .base import BaseGene
|
||||
from .normal import NormalGene
|
||||
108
algorithm/neat/gene/activation.py
Normal file
108
algorithm/neat/gene/activation.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
class Activation:
|
||||
|
||||
@staticmethod
|
||||
def sigmoid_act(z):
|
||||
z = jnp.clip(z * 5, -60, 60)
|
||||
return 1 / (1 + jnp.exp(-z))
|
||||
|
||||
@staticmethod
|
||||
def tanh_act(z):
|
||||
z = jnp.clip(z * 2.5, -60, 60)
|
||||
return jnp.tanh(z)
|
||||
|
||||
@staticmethod
|
||||
def sin_act(z):
|
||||
z = jnp.clip(z * 5, -60, 60)
|
||||
return jnp.sin(z)
|
||||
|
||||
@staticmethod
|
||||
def gauss_act(z):
|
||||
z = jnp.clip(z * 5, -3.4, 3.4)
|
||||
return jnp.exp(-z ** 2)
|
||||
|
||||
@staticmethod
|
||||
def relu_act(z):
|
||||
return jnp.maximum(z, 0)
|
||||
|
||||
@staticmethod
|
||||
def elu_act(z):
|
||||
return jnp.where(z > 0, z, jnp.exp(z) - 1)
|
||||
|
||||
@staticmethod
|
||||
def lelu_act(z):
|
||||
leaky = 0.005
|
||||
return jnp.where(z > 0, z, leaky * z)
|
||||
|
||||
@staticmethod
|
||||
def selu_act(z):
|
||||
lam = 1.0507009873554804934193349852946
|
||||
alpha = 1.6732632423543772848170429916717
|
||||
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
|
||||
|
||||
@staticmethod
|
||||
def softplus_act(z):
|
||||
z = jnp.clip(z * 5, -60, 60)
|
||||
return 0.2 * jnp.log(1 + jnp.exp(z))
|
||||
|
||||
@staticmethod
|
||||
def identity_act(z):
|
||||
return z
|
||||
|
||||
@staticmethod
|
||||
def clamped_act(z):
|
||||
return jnp.clip(z, -1, 1)
|
||||
|
||||
@staticmethod
|
||||
def inv_act(z):
|
||||
z = jnp.maximum(z, 1e-7)
|
||||
return 1 / z
|
||||
|
||||
@staticmethod
|
||||
def log_act(z):
|
||||
z = jnp.maximum(z, 1e-7)
|
||||
return jnp.log(z)
|
||||
|
||||
@staticmethod
|
||||
def exp_act(z):
|
||||
z = jnp.clip(z, -60, 60)
|
||||
return jnp.exp(z)
|
||||
|
||||
@staticmethod
|
||||
def abs_act(z):
|
||||
return jnp.abs(z)
|
||||
|
||||
@staticmethod
|
||||
def hat_act(z):
|
||||
return jnp.maximum(0, 1 - jnp.abs(z))
|
||||
|
||||
@staticmethod
|
||||
def square_act(z):
|
||||
return z ** 2
|
||||
|
||||
@staticmethod
|
||||
def cube_act(z):
|
||||
return z ** 3
|
||||
|
||||
name2func = {
|
||||
'sigmoid': sigmoid_act,
|
||||
'tanh': tanh_act,
|
||||
'sin': sin_act,
|
||||
'gauss': gauss_act,
|
||||
'relu': relu_act,
|
||||
'elu': elu_act,
|
||||
'lelu': lelu_act,
|
||||
'selu': selu_act,
|
||||
'softplus': softplus_act,
|
||||
'identity': identity_act,
|
||||
'clamped': clamped_act,
|
||||
'inv': inv_act,
|
||||
'log': log_act,
|
||||
'exp': exp_act,
|
||||
'abs': abs_act,
|
||||
'hat': hat_act,
|
||||
'square': square_act,
|
||||
'cube': cube_act,
|
||||
}
|
||||
60
algorithm/neat/gene/aggregation.py
Normal file
60
algorithm/neat/gene/aggregation.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
class Aggregation:
|
||||
|
||||
@staticmethod
|
||||
def sum_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), 0, z)
|
||||
return jnp.sum(z, axis=0)
|
||||
|
||||
@staticmethod
|
||||
def product_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), 1, z)
|
||||
return jnp.prod(z, axis=0)
|
||||
|
||||
@staticmethod
|
||||
def max_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
|
||||
return jnp.max(z, axis=0)
|
||||
|
||||
@staticmethod
|
||||
def min_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), jnp.inf, z)
|
||||
return jnp.min(z, axis=0)
|
||||
|
||||
@staticmethod
|
||||
def maxabs_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), 0, z)
|
||||
abs_z = jnp.abs(z)
|
||||
max_abs_index = jnp.argmax(abs_z)
|
||||
return z[max_abs_index]
|
||||
|
||||
@staticmethod
|
||||
def median_agg(z):
|
||||
n = jnp.sum(~jnp.isnan(z), axis=0)
|
||||
|
||||
z = jnp.sort(z) # sort
|
||||
|
||||
idx1, idx2 = (n - 1) // 2, n // 2
|
||||
median = (z[idx1] + z[idx2]) / 2
|
||||
|
||||
return median
|
||||
|
||||
@staticmethod
|
||||
def mean_agg(z):
|
||||
aux = jnp.where(jnp.isnan(z), 0, z)
|
||||
valid_values_sum = jnp.sum(aux, axis=0)
|
||||
valid_values_count = jnp.sum(~jnp.isnan(z), axis=0)
|
||||
mean_without_zeros = valid_values_sum / valid_values_count
|
||||
return mean_without_zeros
|
||||
|
||||
name2func = {
|
||||
'sum': sum_agg,
|
||||
'product': product_agg,
|
||||
'max': max_agg,
|
||||
'min': min_agg,
|
||||
'maxabs': maxabs_agg,
|
||||
'median': median_agg,
|
||||
'mean': mean_agg,
|
||||
}
|
||||
38
algorithm/neat/gene/base.py
Normal file
38
algorithm/neat/gene/base.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from jax import Array, numpy as jnp
|
||||
|
||||
|
||||
class BaseGene:
|
||||
node_attrs = []
|
||||
conn_attrs = []
|
||||
|
||||
@staticmethod
|
||||
def setup(state, config):
|
||||
return state
|
||||
|
||||
@staticmethod
|
||||
def new_node_attrs(state):
|
||||
return jnp.zeros(0)
|
||||
|
||||
@staticmethod
|
||||
def new_conn_attrs(state):
|
||||
return jnp.zeros(0)
|
||||
|
||||
@staticmethod
|
||||
def mutate_node(state, attrs: Array, key):
|
||||
return attrs
|
||||
|
||||
@staticmethod
|
||||
def mutate_conn(state, attrs: Array, key):
|
||||
return attrs
|
||||
|
||||
@staticmethod
|
||||
def distance_node(state, array: Array):
|
||||
return array
|
||||
|
||||
@staticmethod
|
||||
def distance_conn(state, array: Array):
|
||||
return array
|
||||
|
||||
@staticmethod
|
||||
def forward(state, array: Array):
|
||||
return array
|
||||
40
algorithm/neat/gene/normal.py
Normal file
40
algorithm/neat/gene/normal.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from jax import Array, numpy as jnp
|
||||
|
||||
from . import BaseGene
|
||||
|
||||
|
||||
class NormalGene(BaseGene):
|
||||
node_attrs = ['bias', 'response', 'aggregation', 'activation']
|
||||
conn_attrs = ['weight']
|
||||
|
||||
@staticmethod
|
||||
def setup(state, config):
|
||||
return state
|
||||
|
||||
@staticmethod
|
||||
def new_node_attrs(state):
|
||||
return jnp.array([0, 0, 0, 0])
|
||||
|
||||
@staticmethod
|
||||
def new_conn_attrs(state):
|
||||
return jnp.array([0])
|
||||
|
||||
@staticmethod
|
||||
def mutate_node(state, attrs: Array, key):
|
||||
return attrs
|
||||
|
||||
@staticmethod
|
||||
def mutate_conn(state, attrs: Array, key):
|
||||
return attrs
|
||||
|
||||
@staticmethod
|
||||
def distance_node(state, array: Array):
|
||||
return array
|
||||
|
||||
@staticmethod
|
||||
def distance_conn(state, array: Array):
|
||||
return array
|
||||
|
||||
@staticmethod
|
||||
def forward(state, array: Array):
|
||||
return array
|
||||
2
algorithm/neat/genome/__init__.py
Normal file
2
algorithm/neat/genome/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .basic import initialize_genomes
|
||||
from .mutate import create_mutate
|
||||
102
algorithm/neat/genome/basic.py
Normal file
102
algorithm/neat/genome/basic.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from typing import Type, Tuple
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import Array, numpy as jnp
|
||||
|
||||
from algorithm import State
|
||||
from ..gene import BaseGene
|
||||
from ..utils import fetch_first
|
||||
|
||||
|
||||
def initialize_genomes(state: State, gene_type: Type[BaseGene]):
|
||||
o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes
|
||||
o_conns = np.full((state.N, state.CL), np.nan, dtype=np.float32) # original connections
|
||||
|
||||
input_idx = state.input_idx
|
||||
output_idx = state.output_idx
|
||||
new_node_key = max([*input_idx, *output_idx]) + 1
|
||||
|
||||
o_nodes[input_idx, 0] = input_idx
|
||||
o_nodes[output_idx, 0] = output_idx
|
||||
o_nodes[new_node_key, 0] = new_node_key
|
||||
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = jax.device_get(gene_type.new_node_attrs(state))
|
||||
o_nodes[new_node_key, 1:] = jax.device_get(gene_type.new_node_attrs(state))
|
||||
|
||||
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)]
|
||||
o_conns[input_idx, 0:2] = input_conns # in key, out key
|
||||
o_conns[input_idx, 2] = True # enabled
|
||||
o_conns[input_idx, 3:] = jax.device_get(gene_type.new_conn_attrs(state))
|
||||
|
||||
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx]
|
||||
o_conns[output_idx, 0:2] = output_conns # in key, out key
|
||||
o_conns[output_idx, 2] = True # enabled
|
||||
o_conns[output_idx, 3:] = jax.device_get(gene_type.new_conn_attrs(state))
|
||||
|
||||
# repeat origin genome for P times to create population
|
||||
pop_nodes = np.tile(o_nodes, (state.P, 1, 1))
|
||||
pop_conns = np.tile(o_conns, (state.P, 1, 1))
|
||||
|
||||
return pop_nodes, pop_conns
|
||||
|
||||
|
||||
def add_node(nodes: Array, cons: Array, new_key: int, attrs: Array) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Add a new node to the genome.
|
||||
The new node will place at the first NaN row.
|
||||
"""
|
||||
exist_keys = nodes[:, 0]
|
||||
idx = fetch_first(jnp.isnan(exist_keys))
|
||||
nodes = nodes.at[idx, 0].set(new_key)
|
||||
nodes = nodes.at[idx, 1:].set(attrs)
|
||||
return nodes, cons
|
||||
|
||||
|
||||
def delete_node(nodes: Array, cons: Array, node_key: Array) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Delete a node from the genome. Only delete the node, regardless of connections.
|
||||
Delete the node by its key.
|
||||
"""
|
||||
node_keys = nodes[:, 0]
|
||||
idx = fetch_first(node_keys == node_key)
|
||||
return delete_node_by_idx(nodes, cons, idx)
|
||||
|
||||
|
||||
def delete_node_by_idx(nodes: Array, cons: Array, idx: Array) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Delete a node from the genome. Only delete the node, regardless of connections.
|
||||
Delete the node by its idx.
|
||||
"""
|
||||
nodes = nodes.at[idx].set(np.nan)
|
||||
return nodes, cons
|
||||
|
||||
|
||||
def add_connection(nodes: Array, cons: Array, i_key: Array, o_key: Array, enable: bool, attrs: Array) -> Tuple[
|
||||
Array, Array]:
|
||||
"""
|
||||
Add a new connection to the genome.
|
||||
The new connection will place at the first NaN row.
|
||||
"""
|
||||
con_keys = cons[:, 0]
|
||||
idx = fetch_first(jnp.isnan(con_keys))
|
||||
cons = cons.at[idx, 0:3].set(jnp.array([i_key, o_key, enable]))
|
||||
cons = cons.at[idx, 3:].set(attrs)
|
||||
return nodes, cons
|
||||
|
||||
|
||||
def delete_connection(nodes: Array, cons: Array, i_key: Array, o_key: Array) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Delete a connection from the genome.
|
||||
Delete the connection by its input and output node keys.
|
||||
"""
|
||||
idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
|
||||
return delete_connection_by_idx(nodes, cons, idx)
|
||||
|
||||
|
||||
def delete_connection_by_idx(nodes: Array, cons: Array, idx: Array) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Delete a connection from the genome.
|
||||
Delete the connection by its idx.
|
||||
"""
|
||||
cons = cons.at[idx].set(np.nan)
|
||||
return nodes, cons
|
||||
0
algorithm/neat/genome/crossover.py
Normal file
0
algorithm/neat/genome/crossover.py
Normal file
0
algorithm/neat/genome/distance.py
Normal file
0
algorithm/neat/genome/distance.py
Normal file
167
algorithm/neat/genome/graph.py
Normal file
167
algorithm/neat/genome/graph.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Some graph algorithm implemented in jax.
|
||||
Only used in feed-forward networks.
|
||||
"""
|
||||
|
||||
import jax
|
||||
from jax import jit, Array, numpy as jnp
|
||||
|
||||
from ..utils import fetch_first, I_INT
|
||||
|
||||
|
||||
@jit
|
||||
def topological_sort(nodes: Array, connections: Array) -> Array:
|
||||
"""
|
||||
a jit-able version of topological_sort! that's crazy!
|
||||
:param nodes: nodes array
|
||||
:param connections: connections array
|
||||
:return: topological sorted sequence
|
||||
|
||||
Example:
|
||||
nodes = jnp.array([
|
||||
[0],
|
||||
[1],
|
||||
[2],
|
||||
[3]
|
||||
])
|
||||
connections = jnp.array([
|
||||
[
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 1, 1],
|
||||
[0, 0, 0, 1],
|
||||
[0, 0, 0, 0]
|
||||
],
|
||||
[
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 1, 1],
|
||||
[0, 0, 0, 1],
|
||||
[0, 0, 0, 0]
|
||||
]
|
||||
])
|
||||
|
||||
topological_sort(nodes, connections) -> [0, 1, 2, 3]
|
||||
"""
|
||||
connections_enable = connections[1, :, :] == 1 # forward function. thus use enable
|
||||
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0))
|
||||
res = jnp.full(in_degree.shape, I_INT)
|
||||
|
||||
def cond_fun(carry):
|
||||
res_, idx_, in_degree_ = carry
|
||||
i = fetch_first(in_degree_ == 0.)
|
||||
return i != I_INT
|
||||
|
||||
def body_func(carry):
|
||||
res_, idx_, in_degree_ = carry
|
||||
i = fetch_first(in_degree_ == 0.)
|
||||
|
||||
# add to res and flag it is already in it
|
||||
res_ = res_.at[idx_].set(i)
|
||||
in_degree_ = in_degree_.at[i].set(-1)
|
||||
|
||||
# decrease in_degree of all its children
|
||||
children = connections_enable[i, :]
|
||||
in_degree_ = jnp.where(children, in_degree_ - 1, in_degree_)
|
||||
return res_, idx_ + 1, in_degree_
|
||||
|
||||
res, _, _ = jax.lax.while_loop(cond_fun, body_func, (res, 0, in_degree))
|
||||
return res
|
||||
|
||||
|
||||
@jit
|
||||
def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array:
|
||||
"""
|
||||
Check whether a new connection (from_idx -> to_idx) will cause a cycle.
|
||||
|
||||
:param nodes: JAX array
|
||||
The array of nodes.
|
||||
:param connections: JAX array
|
||||
The array of connections.
|
||||
:param from_idx: int
|
||||
The index of the starting node.
|
||||
:param to_idx: int
|
||||
The index of the ending node.
|
||||
:return: JAX array
|
||||
An array indicating if there is a cycle caused by the new connection.
|
||||
|
||||
Example:
|
||||
nodes = jnp.array([
|
||||
[0],
|
||||
[1],
|
||||
[2],
|
||||
[3]
|
||||
])
|
||||
connections = jnp.array([
|
||||
[
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 1, 1],
|
||||
[0, 0, 0, 1],
|
||||
[0, 0, 0, 0]
|
||||
],
|
||||
[
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 1, 1],
|
||||
[0, 0, 0, 1],
|
||||
[0, 0, 0, 0]
|
||||
]
|
||||
])
|
||||
|
||||
check_cycles(nodes, connections, 3, 2) -> True
|
||||
check_cycles(nodes, connections, 2, 3) -> False
|
||||
check_cycles(nodes, connections, 0, 3) -> False
|
||||
check_cycles(nodes, connections, 1, 0) -> False
|
||||
"""
|
||||
|
||||
connections_enable = ~jnp.isnan(connections[0, :, :])
|
||||
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
|
||||
|
||||
visited = jnp.full(nodes.shape[0], False)
|
||||
new_visited = visited.at[to_idx].set(True)
|
||||
|
||||
def cond_func(carry):
|
||||
visited_, new_visited_ = carry
|
||||
end_cond1 = jnp.all(visited_ == new_visited_) # no new nodes been visited
|
||||
end_cond2 = new_visited_[from_idx] # the starting node has been visited
|
||||
return jnp.logical_not(end_cond1 | end_cond2)
|
||||
|
||||
def body_func(carry):
|
||||
_, visited_ = carry
|
||||
new_visited_ = jnp.dot(visited_, connections_enable)
|
||||
new_visited_ = jnp.logical_or(visited_, new_visited_)
|
||||
return visited_, new_visited_
|
||||
|
||||
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
|
||||
return visited[from_idx]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
nodes = jnp.array([
|
||||
[0],
|
||||
[1],
|
||||
[2],
|
||||
[3],
|
||||
[jnp.nan]
|
||||
])
|
||||
connections = jnp.array([
|
||||
[
|
||||
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
|
||||
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
||||
],
|
||||
[
|
||||
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
|
||||
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
print(topological_sort(nodes, connections))
|
||||
|
||||
print(check_cycles(nodes, connections, 3, 2))
|
||||
print(check_cycles(nodes, connections, 2, 3))
|
||||
print(check_cycles(nodes, connections, 0, 3))
|
||||
print(check_cycles(nodes, connections, 1, 0))
|
||||
206
algorithm/neat/genome/mutate.py
Normal file
206
algorithm/neat/genome/mutate.py
Normal file
@@ -0,0 +1,206 @@
|
||||
from typing import Dict, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import Array, numpy as jnp, vmap
|
||||
|
||||
from algorithm import State
|
||||
from .basic import add_node, add_connection, delete_node_by_idx, delete_connection_by_idx
|
||||
from .graph import check_cycles
|
||||
from ..utils import fetch_random, fetch_first, I_INT, unflatten_connections
|
||||
from ..gene import BaseGene
|
||||
|
||||
|
||||
def create_mutate(config: Dict, gene_type: Type[BaseGene]):
|
||||
"""
|
||||
Create function to mutate the whole population
|
||||
"""
|
||||
|
||||
def mutate_structure(state: State, randkey, nodes, cons, new_node_key):
|
||||
def nothing(*args):
|
||||
return nodes, cons
|
||||
|
||||
def mutate_add_node(key_):
|
||||
i_key, o_key, idx = choice_connection_key(key_, nodes, cons)
|
||||
|
||||
def successful_add_node():
|
||||
# disable the connection
|
||||
aux_nodes, aux_cons = nodes, cons
|
||||
|
||||
# set enable to false
|
||||
aux_cons = aux_cons.at[idx, 2].set(False)
|
||||
|
||||
# add a new node
|
||||
aux_nodes, aux_cons = add_node(aux_nodes, aux_cons, new_node_key, gene_type.new_node_attrs(state))
|
||||
|
||||
# add two new connections
|
||||
aux_nodes, aux_cons = add_connection(aux_nodes, aux_cons, i_key, new_node_key, True,
|
||||
gene_type.new_conn_attrs(state))
|
||||
aux_nodes, aux_cons = add_connection(aux_nodes, aux_cons, new_node_key, o_key, True,
|
||||
gene_type.new_conn_attrs(state))
|
||||
|
||||
return aux_nodes, aux_cons
|
||||
|
||||
# if from_idx == I_INT, that means no connection exist, do nothing
|
||||
return jax.lax.cond(idx == I_INT, nothing, successful_add_node)
|
||||
|
||||
def mutate_delete_node(key_):
|
||||
# TODO: Do we really need to delete a node?
|
||||
# randomly choose a node
|
||||
key, idx = choice_node_key(key_, nodes, config['input_idx'], config['output_idx'],
|
||||
allow_input_keys=False, allow_output_keys=False)
|
||||
|
||||
def successful_delete_node():
|
||||
# delete the node
|
||||
aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, idx)
|
||||
|
||||
# delete all connections
|
||||
aux_cons = jnp.where(((aux_cons[:, 0] == key) | (aux_cons[:, 1] == key))[:, None],
|
||||
jnp.nan, aux_cons)
|
||||
|
||||
return aux_nodes, aux_cons
|
||||
|
||||
return jax.lax.cond(idx == I_INT, nothing, successful_delete_node)
|
||||
|
||||
def mutate_add_conn(key_):
|
||||
# randomly choose two nodes
|
||||
k1_, k2_ = jax.random.split(key_, num=2)
|
||||
i_key, from_idx = choice_node_key(k1_, nodes, config['input_idx'], config['output_idx'],
|
||||
allow_input_keys=True, allow_output_keys=True)
|
||||
o_key, to_idx = choice_node_key(k2_, nodes, config['input_idx'], config['output_idx'],
|
||||
allow_input_keys=False, allow_output_keys=True)
|
||||
|
||||
con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
|
||||
|
||||
def successful():
|
||||
new_nodes, new_cons = add_connection(nodes, cons, i_key, o_key, True, gene_type.new_conn_attrs(state))
|
||||
return new_nodes, new_cons
|
||||
|
||||
def already_exist():
|
||||
new_cons = cons.at[con_idx, 2].set(True)
|
||||
return nodes, new_cons
|
||||
|
||||
is_already_exist = con_idx != I_INT
|
||||
|
||||
if config['network_type'] == 'feedforward':
|
||||
u_cons = unflatten_connections(nodes, cons)
|
||||
is_cycle = check_cycles(nodes, u_cons, from_idx, to_idx)
|
||||
|
||||
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
|
||||
return jax.lax.switch(choice, [already_exist, nothing, successful])
|
||||
|
||||
elif config['network_type'] == 'recurrent':
|
||||
return jax.lax.cond(is_already_exist, already_exist, successful)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid network type: {config['network_type']}")
|
||||
|
||||
def mutate_delete_conn(key_):
|
||||
# randomly choose a connection
|
||||
i_key, o_key, idx = choice_connection_key(key_, nodes, cons)
|
||||
|
||||
def successfully_delete_connection():
|
||||
return delete_connection_by_idx(nodes, cons, idx)
|
||||
|
||||
return jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection)
|
||||
|
||||
k, k1, k2, k3, k4 = jax.random.split(randkey, num=5)
|
||||
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
|
||||
|
||||
nodes, cons = jax.lax.cond(r1 < config['node_add_prob'], mutate_add_node, nothing, k1)
|
||||
nodes, cons = jax.lax.cond(r2 < config['node_delete_prob'], mutate_delete_node, nothing, k2)
|
||||
nodes, cons = jax.lax.cond(r3 < config['conn_add_prob'], mutate_add_conn, nothing, k3)
|
||||
nodes, cons = jax.lax.cond(r4 < config['conn_delete_prob'], mutate_delete_conn, nothing, k4)
|
||||
return nodes, cons
|
||||
|
||||
def mutate_values(state: State, randkey, nodes, conns):
|
||||
k1, k2 = jax.random.split(randkey, num=2)
|
||||
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
|
||||
conns_keys = jax.random.split(k2, num=conns.shape[0])
|
||||
|
||||
nodes_attrs, conns_attrs = nodes[:, 1:], conns[:, 3:]
|
||||
|
||||
new_nodes_attrs = vmap(gene_type.mutate_node, in_axes=(None, 0, 0))(state, nodes_attrs, nodes_keys)
|
||||
new_conns_attrs = vmap(gene_type.mutate_conn, in_axes=(None, 0, 0))(state, conns_attrs, conns_keys)
|
||||
|
||||
# nan nodes not changed
|
||||
new_nodes_attrs = jnp.where(jnp.isnan(nodes_attrs), jnp.nan, new_nodes_attrs)
|
||||
new_conns_attrs = jnp.where(jnp.isnan(conns_attrs), jnp.nan, new_conns_attrs)
|
||||
|
||||
new_nodes = nodes.at[:, 1:].set(new_nodes_attrs)
|
||||
new_conns = conns.at[:, 3:].set(new_conns_attrs)
|
||||
|
||||
return new_nodes, new_conns
|
||||
|
||||
def mutate(state):
|
||||
pop_nodes, pop_conns = state.pop_nodes, state.pop_conns
|
||||
pop_size = pop_nodes.shape[0]
|
||||
|
||||
new_node_keys = jnp.arange(pop_size) + state.next_node_key
|
||||
k1, k2, randkey = jax.random.split(state.randkey, num=3)
|
||||
structure_randkeys = jax.random.split(k1, num=pop_size)
|
||||
values_randkeys = jax.random.split(k2, num=pop_size)
|
||||
|
||||
structure_func = jax.vmap(mutate_structure, in_axes=(None, 0, 0, 0, 0))
|
||||
pop_nodes, pop_conns = structure_func(state, structure_randkeys, pop_nodes, pop_conns, new_node_keys)
|
||||
|
||||
values_func = jax.vmap(mutate_values, in_axes=(None, 0, 0, 0))
|
||||
pop_nodes, pop_conns = values_func(state, values_randkeys, pop_nodes, pop_conns)
|
||||
|
||||
# update next node key
|
||||
all_nodes_keys = pop_nodes[:, :, 0]
|
||||
max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
|
||||
next_node_key = max_node_key + 1
|
||||
|
||||
return state.update(
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
next_node_key=next_node_key,
|
||||
randkey=randkey
|
||||
)
|
||||
|
||||
return mutate
|
||||
|
||||
|
||||
def choice_node_key(rand_key: Array, nodes: Array,
|
||||
input_keys: Array, output_keys: Array,
|
||||
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node.
|
||||
:param rand_key:
|
||||
:param nodes:
|
||||
:param input_keys:
|
||||
:param output_keys:
|
||||
:param allow_input_keys:
|
||||
:param allow_output_keys:
|
||||
:return: return its key and position(idx)
|
||||
"""
|
||||
|
||||
node_keys = nodes[:, 0]
|
||||
mask = ~jnp.isnan(node_keys)
|
||||
|
||||
if not allow_input_keys:
|
||||
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_keys))
|
||||
|
||||
if not allow_output_keys:
|
||||
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_keys))
|
||||
|
||||
idx = fetch_random(rand_key, mask)
|
||||
key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
|
||||
return key, idx
|
||||
|
||||
|
||||
def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]:
|
||||
"""
|
||||
Randomly choose a connection key from the given connections.
|
||||
:param rand_key:
|
||||
:param nodes:
|
||||
:param cons:
|
||||
:return: i_key, o_key, idx
|
||||
"""
|
||||
|
||||
idx = fetch_random(rand_key, ~jnp.isnan(cons[:, 0]))
|
||||
i_key = jnp.where(idx != I_INT, cons[idx, 0], jnp.nan)
|
||||
o_key = jnp.where(idx != I_INT, cons[idx, 1], jnp.nan)
|
||||
|
||||
return i_key, o_key, idx
|
||||
71
algorithm/neat/utils.py
Normal file
71
algorithm/neat/utils.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import numpy as jnp, Array, jit, vmap
|
||||
|
||||
I_INT = np.iinfo(jnp.int32).max # infinite int
|
||||
EMPTY_NODE = np.full((1, 5), jnp.nan)
|
||||
EMPTY_CON = np.full((1, 4), jnp.nan)
|
||||
|
||||
|
||||
@jit
|
||||
def unflatten_connections(nodes: Array, cons: Array):
|
||||
"""
|
||||
transform the (C, 4) connections to (2, N, N)
|
||||
:param nodes: (N, 5)
|
||||
:param cons: (C, 4)
|
||||
:return:
|
||||
"""
|
||||
N = nodes.shape[0]
|
||||
node_keys = nodes[:, 0]
|
||||
i_keys, o_keys = cons[:, 0], cons[:, 1]
|
||||
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
|
||||
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
|
||||
res = jnp.full((2, N, N), jnp.nan)
|
||||
|
||||
# Is interesting that jax use clip when attach data in array
|
||||
# however, it will do nothing set values in an array
|
||||
res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
|
||||
res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def key_to_indices(key, keys):
|
||||
return fetch_first(key == keys)
|
||||
|
||||
|
||||
@jit
|
||||
def fetch_first(mask, default=I_INT) -> Array:
|
||||
"""
|
||||
fetch the first True index
|
||||
:param mask: array of bool
|
||||
:param default: the default value if no element satisfying the condition
|
||||
:return: the index of the first element satisfying the condition. if no element satisfying the condition, return default value
|
||||
"""
|
||||
idx = jnp.argmax(mask)
|
||||
return jnp.where(mask[idx], idx, default)
|
||||
|
||||
|
||||
@jit
|
||||
def fetch_random(rand_key, mask, default=I_INT) -> Array:
|
||||
"""
|
||||
similar to fetch_first, but fetch a random True index
|
||||
"""
|
||||
true_cnt = jnp.sum(mask)
|
||||
cumsum = jnp.cumsum(mask)
|
||||
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
|
||||
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
|
||||
return fetch_first(mask, default)
|
||||
|
||||
|
||||
@partial(jit, static_argnames=['reverse'])
|
||||
def rank_elements(array, reverse=False):
|
||||
"""
|
||||
rank the element in the array.
|
||||
if reverse is True, the rank is from small to large. default large to small
|
||||
"""
|
||||
if not reverse:
|
||||
array = -array
|
||||
return jnp.argsort(jnp.argsort(array))
|
||||
@@ -1,4 +1,4 @@
|
||||
from jax.tree_util import register_pytree_node_class, tree_map
|
||||
from jax.tree_util import register_pytree_node_class
|
||||
|
||||
|
||||
@register_pytree_node_class
|
||||
@@ -20,10 +20,12 @@ class State:
|
||||
return f"State ({self.state_dict})"
|
||||
|
||||
def tree_flatten(self):
|
||||
print('tree_flatten_cal')
|
||||
children = list(self.state_dict.values())
|
||||
aux_data = list(self.state_dict.keys())
|
||||
return children, aux_data
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
print('tree_unflatten_cal')
|
||||
return cls(**dict(zip(aux_data, children)))
|
||||
|
||||
Reference in New Issue
Block a user