create state

This commit is contained in:
wls2002
2023-07-14 17:27:22 +08:00
parent 7265e33c43
commit a0a1ef6c58
41 changed files with 43 additions and 2882 deletions

View File

@@ -1,12 +0,0 @@
# NEATAX: Tensorized NEAT Implementation for Parallel Hardware Acceleration
NEATAX is a powerful tool that utilizes JAX to implement the NEAT (NeuroEvolution of Augmenting Topologies) algorithm. It provides support for parallel execution of tasks such as forward network computation, mutation, and crossover at the population level.
## Performance
One of the standout features of NEATAX is its speed. Compared to traditional CPU implementations, NEATAX significantly improves the efficiency of the NEAT algorithm. It has been observed to boost the running speed of the NEAT algorithm by more than 10 times, offering considerable advantage in larger-scale and time-sensitive applications.
## Installation
Clone this repository with `git clone`.
A working JAX environment is required.

View File

29
algorithm/state.py Normal file
View File

@@ -0,0 +1,29 @@
from jax.tree_util import register_pytree_node_class, tree_map
@register_pytree_node_class
class State:
    """Immutable keyword container that is also a JAX pytree.

    Values passed as keyword arguments are stored in a single dict and exposed
    as read-only attributes. ``update`` returns a new ``State`` instead of
    mutating the existing one.
    """

    def __init__(self, **kwargs):
        # Written through __dict__ because __setattr__ is disabled below.
        self.__dict__['state_dict'] = kwargs

    def update(self, **kwargs):
        """Return a new State with ``kwargs`` merged over the current entries."""
        return State(**{**self.state_dict, **kwargs})

    def __getattr__(self, name):
        # Only called when normal lookup fails. Go through __dict__ directly so
        # that a missing 'state_dict' (e.g. on an instance created by
        # copy/pickle via __new__) cannot recurse back into __getattr__, and
        # raise AttributeError so hasattr()/getattr(..., default) work.
        try:
            return self.__dict__['state_dict'][name]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

    def __setattr__(self, name, value):
        raise AttributeError("State is immutable")

    def __repr__(self):
        return f"State ({self.state_dict})"

    def tree_flatten(self):
        """Split into (children, aux_data): the values and the matching keys."""
        children = list(self.state_dict.values())
        # aux_data becomes part of the treedef, which JAX hashes and compares;
        # a tuple is hashable whereas a list is not.
        aux_data = tuple(self.state_dict.keys())
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a State from the keys (aux_data) and values (children)."""
        return cls(**dict(zip(aux_data, children)))

View File

@@ -1,10 +0,0 @@
"""
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
"""
from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes
from .population import update_species, create_next_generation, speciate, tell, initialize
from .genome.activations import act_name2func
from .genome.aggregations import agg_name2func
from .visualize import Genome

View File

@@ -1,7 +0,0 @@
from .mutate import mutate
from .distance import distance
from .crossover import crossover
from .graph import topological_sort, check_cycles
from .utils import unflatten_connections, I_INT, fetch_first, rank_elements
from .forward import create_forward_function
from .genome import initialize_genomes

View File

@@ -1,106 +0,0 @@
import jax.numpy as jnp
def sigmoid_act(z):
    """Logistic sigmoid of ``5 * z``, with ``5 * z`` clipped to [-60, 60]."""
    scaled = jnp.clip(z * 5, -60, 60)
    denominator = 1 + jnp.exp(-scaled)
    return 1 / denominator
def tanh_act(z):
    """Hyperbolic tangent of ``2.5 * z``, with ``2.5 * z`` clipped to [-60, 60]."""
    scaled = jnp.clip(z * 2.5, -60, 60)
    return jnp.tanh(scaled)
def sin_act(z):
    """Sine of ``5 * z``, with ``5 * z`` clipped to [-60, 60]."""
    scaled = jnp.clip(z * 5, -60, 60)
    return jnp.sin(scaled)
def gauss_act(z):
    """Gaussian bump ``exp(-(5z)^2)``, with ``5z`` clipped to [-3.4, 3.4]."""
    scaled = jnp.clip(z * 5, -3.4, 3.4)
    return jnp.exp(-scaled ** 2)
def relu_act(z):
    """Rectified linear unit: element-wise maximum of ``z`` and 0."""
    floor = 0
    return jnp.maximum(z, floor)
def elu_act(z):
    """Exponential linear unit: ``z`` where positive, ``exp(z) - 1`` elsewhere."""
    negative_branch = jnp.exp(z) - 1
    return jnp.where(z > 0, z, negative_branch)
def lelu_act(z):
    """Leaky ReLU with a slope of 0.005 for non-positive inputs."""
    slope = 0.005
    leaked = slope * z
    return jnp.where(z > 0, z, leaked)
def selu_act(z):
    """Scaled ELU: ``lam * z`` where positive, ``lam * alpha * (exp(z) - 1)`` elsewhere."""
    lam = 1.0507009873554804934193349852946
    alpha = 1.6732632423543772848170429916717
    positive_branch = lam * z
    negative_branch = lam * alpha * (jnp.exp(z) - 1)
    return jnp.where(z > 0, positive_branch, negative_branch)
def softplus_act(z):
    """Scaled softplus ``0.2 * log(1 + exp(5z))``, with ``5z`` clipped to [-60, 60]."""
    scaled = jnp.clip(z * 5, -60, 60)
    log_term = jnp.log(1 + jnp.exp(scaled))
    return 0.2 * log_term
def identity_act(z):
    """Return the input unchanged."""
    result = z
    return result
def clamped_act(z):
    """Clamp the input to the interval [-1, 1]."""
    lower, upper = -1, 1
    return jnp.clip(z, lower, upper)
def inv_act(z):
    """Reciprocal ``1 / z`` with the magnitude of ``z`` kept at or above 1e-7.

    The previous implementation clamped with ``maximum(z, 1e-7)``, which
    turned every negative input into 1e-7 and returned 1e7 for it. The sign
    is now preserved: non-negative inputs are clamped up to 1e-7 (so 0 still
    maps to 1e7 as before) and negative inputs are clamped down to -1e-7.
    """
    eps = 1e-7
    safe = jnp.where(z >= 0, jnp.maximum(z, eps), jnp.minimum(z, -eps))
    return 1 / safe
def log_act(z):
    """Natural logarithm of ``z`` after raising ``z`` to at least 1e-7."""
    floored = jnp.maximum(z, 1e-7)
    return jnp.log(floored)
def exp_act(z):
    """Exponential of ``z`` with ``z`` clipped to [-60, 60]."""
    bounded = jnp.clip(z, -60, 60)
    return jnp.exp(bounded)
def abs_act(z):
    """Element-wise absolute value."""
    magnitude = jnp.abs(z)
    return magnitude
def hat_act(z):
    """Triangular "hat": ``1 - |z|`` where that is positive, otherwise 0."""
    height = 1 - jnp.abs(z)
    return jnp.maximum(0, height)
def square_act(z):
    """Element-wise square of the input."""
    squared = z ** 2
    return squared
def cube_act(z):
    """Element-wise cube of the input."""
    cubed = z ** 3
    return cubed
# Lookup table from activation name to its implementation.
# Insertion order matches the order of the definitions above.
act_name2func = dict(
    sigmoid=sigmoid_act,
    tanh=tanh_act,
    sin=sin_act,
    gauss=gauss_act,
    relu=relu_act,
    elu=elu_act,
    lelu=lelu_act,
    selu=selu_act,
    softplus=softplus_act,
    identity=identity_act,
    clamped=clamped_act,
    inv=inv_act,
    log=log_act,
    exp=exp_act,
    abs=abs_act,
    hat=hat_act,
    square=square_act,
    cube=cube_act,
)

View File

@@ -1,59 +0,0 @@
import jax.numpy as jnp
def sum_agg(z):
    """Sum along axis 0, counting NaN entries as 0."""
    cleaned = jnp.where(jnp.isnan(z), 0, z)
    return jnp.sum(cleaned, axis=0)
def product_agg(z):
    """Product along axis 0, counting NaN entries as 1."""
    cleaned = jnp.where(jnp.isnan(z), 1, z)
    return jnp.prod(cleaned, axis=0)
def max_agg(z):
    """Maximum along axis 0, with NaN entries replaced by -inf first."""
    cleaned = jnp.where(jnp.isnan(z), -jnp.inf, z)
    return jnp.max(cleaned, axis=0)
def min_agg(z):
    """Minimum along axis 0, with NaN entries replaced by +inf first."""
    cleaned = jnp.where(jnp.isnan(z), jnp.inf, z)
    return jnp.min(cleaned, axis=0)
def maxabs_agg(z):
    """Return, along axis 0, the entry with the largest absolute value.

    NaN entries are treated as 0. The selection is done along axis 0, like
    the other aggregators in this module; previously ``argmax`` ran over the
    flattened array, which is only equivalent for 1-D input.
    """
    z = jnp.where(jnp.isnan(z), 0, z)
    winner = jnp.argmax(jnp.abs(z), axis=0)
    # take_along_axis needs the index array to have the same rank as z.
    picked = jnp.take_along_axis(z, jnp.expand_dims(winner, axis=0), axis=0)
    return picked[0]
def median_agg(z):
    """Median of the non-NaN entries.

    The array is sorted (NaNs end up last), and the two middle positions of
    the non-NaN prefix are averaged; for an odd count both positions coincide.
    """
    valid_count = jnp.sum(~jnp.isnan(z), axis=0)
    ordered = jnp.sort(z)
    lower = (valid_count - 1) // 2
    upper = valid_count // 2
    return (ordered[lower] + ordered[upper]) / 2
def mean_agg(z):
    """Mean along axis 0 over the non-NaN entries only."""
    is_valid = ~jnp.isnan(z)
    # NaN entries contribute 0 to the total and are left out of the count.
    total = jnp.sum(jnp.where(jnp.isnan(z), 0, z), axis=0)
    count = jnp.sum(is_valid, axis=0)
    return total / count
# Lookup table from aggregation name to its implementation.
# Insertion order matches the order of the definitions above.
agg_name2func = dict(
    sum=sum_agg,
    product=product_agg,
    max=max_agg,
    min=min_agg,
    maxabs=maxabs_agg,
    median=median_agg,
    mean=mean_agg,
)

View File

@@ -1,80 +0,0 @@
"""
Crossover two genomes to generate a new genome.
The calculation method is the same as the crossover operation in NEAT-python.
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.configure_crossover
"""
from typing import Tuple
import jax
from jax import jit, Array, numpy as jnp
@jit
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) -> Tuple[Array, Array]:
    """
    Build a child genome from two parents.
    Genome 1 is expected to be the fitter parent (the winner): genes that have
    no homologous partner in genome 2 are copied from genome 1 unchanged.
    :param randkey: random key, split once for nodes and once for connections
    :param nodes1: node array of the winner
    :param cons1: connection array of the winner
    :param nodes2: node array of the other parent
    :param cons2: connection array of the other parent
    :return: (new_nodes, new_cons)
    """
    node_rng, con_rng = jax.random.split(randkey)

    # nodes: reorder parent 2 so that homologous nodes share a row with parent 1
    aligned_nodes2 = align_array(nodes1[:, 0], nodes2[:, 0], nodes2, 'node')
    node_has_gap = jnp.isnan(nodes1) | jnp.isnan(aligned_nodes2)
    mixed_nodes = crossover_gene(node_rng, nodes1, aligned_nodes2)
    # wherever either side is NaN keep the winner's value, otherwise use the mix
    new_nodes = jnp.where(node_has_gap, nodes1, mixed_nodes)

    # connections: same procedure, keyed on the first two columns
    aligned_cons2 = align_array(cons1[:, :2], cons2[:, :2], cons2, 'connection')
    con_has_gap = jnp.isnan(cons1) | jnp.isnan(aligned_cons2)
    mixed_cons = crossover_gene(con_rng, cons1, aligned_cons2)
    new_cons = jnp.where(con_has_gap, cons1, mixed_cons)

    return new_nodes, new_cons
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
    """
    Reorder the rows of ar2 so that they line up with the keys in seq1.
    :param seq1: keys of the reference genome; one key per row, or a pair of
        columns per row when gene_type is 'connection'
    :param seq2: keys of ar2, in the same layout as seq1
    :param ar2: the array whose rows are re-ordered
    :param gene_type: 'connection' compares both key columns; any other value
        compares a single key
    :return: an array with one row per key in seq1. Row i holds the row of ar2
        whose key equals seq1[i]; rows whose key is NaN or has no match in
        seq2 are filled with NaN.

    The row index is recovered with a dot product between the match mask and
    the row positions of ar2, which assumes a key matches at most one row of
    ar2. The positions are taken from ar2's own length, so the two genomes no
    longer need the same number of rows.
    """
    seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
    # mask[i, j] is True when key i of seq1 equals key j of seq2 (NaN never matches)
    mask = (seq1 == seq2) & (~jnp.isnan(seq1))

    if gene_type == 'connection':
        # a connection matches only when both key columns match
        mask = jnp.all(mask, axis=2)

    intersect_mask = mask.any(axis=1)
    # positions along the seq2 / ar2 axis, i.e. the second axis of mask
    idx = jnp.arange(0, ar2.shape[0])
    idx_fixed = jnp.dot(mask, idx)

    refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan)

    return refactor_ar2
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
    """
    Mix two gene arrays element by element.
    :param rand_key: random key for the per-element draw
    :param g1: first gene array
    :param g2: second gene array
    :return: an array that takes each element from g1 when its uniform draw
        exceeds 0.5 and from g2 otherwise
    The caller only mixes genes that share a key, so the key columns are
    identical on both sides and need no special handling.
    """
    draw = jax.random.uniform(rand_key, shape=g1.shape)
    take_first = draw > 0.5
    return jnp.where(take_first, g1, g2)

View File

@@ -1,118 +0,0 @@
"""
Calculate the distance between two genomes.
The calculation method is the same as the distance calculation in NEAT-python.
See https://github.com/CodeReclaimers/neat-python/blob/master/neat/genome.py
"""
from typing import Dict
from jax import jit, vmap, Array, numpy as jnp
from .utils import EMPTY_NODE, EMPTY_CON
@jit
def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, jit_config: Dict) -> Array:
    """
    Compatibility distance between two genomes.

    args:
        nodes1: node array of the first genome
        cons1: connection array of the first genome
        nodes2: node array of the second genome
        cons2: connection array of the second genome
        jit_config: configuration forwarded to the node and connection distances
    returns:
        the sum of the node distance and the connection distance
    """
    node_part = node_distance(nodes1, nodes2, jit_config)
    con_part = connection_distance(cons1, cons2, jit_config)
    return node_part + con_part
@jit
def node_distance(nodes1: Array, nodes2: Array, jit_config: Dict):
    """
    Distance between the node genes of two genomes.

    Non-homologous nodes are weighted by jit_config['compatibility_disjoint'],
    the summed difference of homologous nodes by
    jit_config['compatibility_weight'], and the total is divided by the larger
    node count (0 is returned when both genomes have no nodes).
    """
    # a row counts as a node when its key (column 0) is not NaN
    cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
    cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
    larger_cnt = jnp.maximum(cnt1, cnt2)

    # Stack both genomes and sort by key so that homologous nodes become
    # neighbours (similar in spirit to np.intersect1d).
    stacked = jnp.concatenate((nodes1, nodes2), axis=0)
    order = jnp.argsort(stacked[:, 0], axis=0)
    stacked = stacked[order]
    # trailing NaN row so that every row has a successor to compare with
    stacked = jnp.concatenate([stacked, EMPTY_NODE], axis=0)
    fr, sr = stacked[:-1], stacked[1:]

    # a row and its successor with the same non-NaN key form a homologous pair
    is_pair = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(fr[:, 0])

    unmatched_cnt = cnt1 + cnt2 - 2 * jnp.sum(is_pair)

    pair_dist = vmap(homologous_node_distance)(fr, sr)
    pair_dist = jnp.where(jnp.isnan(pair_dist), 0, pair_dist)
    matched_dist = jnp.sum(pair_dist * is_pair)

    val = unmatched_cnt * jit_config['compatibility_disjoint'] + matched_dist * jit_config[
        'compatibility_weight']
    # the where() guards the division when both genomes are empty
    return jnp.where(larger_cnt == 0, 0, val / larger_cnt)
@jit
def connection_distance(cons1: Array, cons2: Array, jit_config: Dict):
    """
    Distance between the connection genes of two genomes.

    Works like the node distance, except that a connection is identified by
    its first two columns, so rows are ordered with a lexicographic sort on
    that key pair.
    """
    cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
    cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
    larger_cnt = jnp.maximum(cnt1, cnt2)

    stacked = jnp.concatenate((cons1, cons2), axis=0)
    key_pairs = stacked[:, :2]
    # lexsort treats its last key as primary, hence the reversed column order
    order = jnp.lexsort(key_pairs.T[::-1])
    stacked = stacked[order]
    # trailing NaN row so that every row has a successor to compare with
    stacked = jnp.concatenate([stacked, EMPTY_CON], axis=0)
    fr, sr = stacked[:-1], stacked[1:]

    # the same non-NaN key pair in two neighbouring rows: both genomes have it
    is_pair = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])

    unmatched_cnt = cnt1 + cnt2 - 2 * jnp.sum(is_pair)

    pair_dist = vmap(homologous_connection_distance)(fr, sr)
    pair_dist = jnp.where(jnp.isnan(pair_dist), 0, pair_dist)
    matched_dist = jnp.sum(pair_dist * is_pair)

    val = unmatched_cnt * jit_config['compatibility_disjoint'] + matched_dist * jit_config[
        'compatibility_weight']

    return jnp.where(larger_cnt == 0, 0, val / larger_cnt)
@jit
def homologous_node_distance(n1: Array, n2: Array):
    """
    Distance between two nodes that share a key: absolute difference of
    columns 1 (bias) and 2 (response), plus 1 for each of columns 3
    (activation) and 4 (aggregation) that differs.
    """
    bias_gap = jnp.abs(n1[1] - n2[1])
    response_gap = jnp.abs(n1[2] - n2[2])
    act_differs = n1[3] != n2[3]
    agg_differs = n1[4] != n2[4]
    return 0 + bias_gap + response_gap + act_differs + agg_differs
@jit
def homologous_connection_distance(c1: Array, c2: Array):
    """
    Distance between two connections that share a key: absolute difference of
    column 2 (weight), plus 1 when column 3 (enable flag) differs.
    """
    weight_gap = jnp.abs(c1[2] - c2[2])
    enable_differs = c1[3] != c2[3]
    return 0 + weight_gap + enable_differs

View File

@@ -1,108 +0,0 @@
import jax
from jax import Array, numpy as jnp, jit, vmap
from .utils import I_INT
from .activations import act_name2func
from .aggregations import agg_name2func
def create_forward_function(config):
    """
    Meta method that builds the jitted forward function for a configuration.

    Reads config['activation_option_names'], config['aggregation_option_names'],
    config['input_idx'], config['output_idx'] and config['forward_way'], and
    stores the resolved callables back into config under 'activation_funcs'
    and 'aggregation_funcs'.

    config['forward_way'] selects the batching of the returned function:
        'single': one input vector, one genome
        'batch':  a batch of inputs, one genome
        'pop':    one batch of inputs per genome of a population
        'common': one batch of inputs shared by every genome of a population

    :raises ValueError: if config['forward_way'] is none of the values above
        (previously None was returned silently in that case)
    """
    config['activation_funcs'] = [act_name2func[name] for name in config['activation_option_names']]
    config['aggregation_funcs'] = [agg_name2func[name] for name in config['aggregation_option_names']]

    def act(idx, z):
        """
        apply the activation function selected by idx to z
        """
        # idx is read from a float array, lax.switch needs an integer index
        idx = jnp.asarray(idx, dtype=jnp.int32)
        res = jax.lax.switch(idx, config['activation_funcs'], z)
        return res

    def agg(idx, z):
        """
        apply the aggregation function selected by idx to the inputs of a node
        """
        idx = jnp.asarray(idx, dtype=jnp.int32)

        def all_nan():
            # a node without any incoming value aggregates to 0
            return 0.

        def not_all_nan():
            return jax.lax.switch(idx, config['aggregation_funcs'], z)

        return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)

    def forward(inputs: Array, cal_seqs: Array, nodes: Array, cons: Array) -> Array:
        """
        jax forward for single input shaped (input_num, )
        nodes, connections are a single genome

        :argument inputs: (input_num, )
        :argument cal_seqs: (N, ) node indices in evaluation order, padded with I_INT
        :argument nodes: (N, 5)
        :argument cons: (2, N, N); cons[0] holds the weights, and entries that
            are NaN in cons[1] are masked out
        :return (output_num, )
        """

        input_idx = config['input_idx']
        output_idx = config['output_idx']

        N = nodes.shape[0]
        ini_vals = jnp.full((N,), jnp.nan)
        ini_vals = ini_vals.at[input_idx].set(inputs)

        weights = jnp.where(jnp.isnan(cons[1, :, :]), jnp.nan, cons[0, :, :])  # enabled

        def cond_fun(carry):
            values, idx = carry
            # stop at the end of the sequence or at the first padding entry
            return (idx < N) & (cal_seqs[idx] != I_INT)

        def body_func(carry):
            values, idx = carry
            i = cal_seqs[idx]

            def hit():
                ins = values * weights[:, i]
                z = agg(nodes[i, 4], ins)  # z = agg(ins)
                z = z * nodes[i, 2] + nodes[i, 1]  # z = z * response + bias
                z = act(nodes[i, 3], z)  # z = act(z)

                new_values = values.at[i].set(z)
                return new_values

            def miss():
                return values

            # the val of input nodes is obtained by the task, not by calculation
            values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)

            return values, idx + 1

        vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))

        return vals[output_idx]

    # (batch_size, inputs_nums) -> (batch_size, outputs_nums)
    batch_forward = vmap(forward, in_axes=(0, None, None, None))

    # (pop_size, batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
    pop_batch_forward = vmap(batch_forward, in_axes=(0, 0, 0, 0))

    # (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
    common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))

    forward_ways = {
        'single': forward,
        'batch': batch_forward,
        'pop': pop_batch_forward,
        'common': common_forward,
    }
    way = config['forward_way']
    if way not in forward_ways:
        raise ValueError(
            f"Unknown forward_way: {way!r}, expected one of {sorted(forward_ways)}"
        )
    return jit(forward_ways[way])

View File

@@ -1,132 +0,0 @@
"""
Vectorization of genome representation.
Utilizes Tuple[nodes: Array(N, 5), connections: Array(C, 4)] to encode the genome, where:
nodes: [key, bias, response, act, agg]
connections: [in_key, out_key, weight, enable]
N: Maximum number of nodes in the network.
C: Maximum number of connections in the network.
"""
from typing import Tuple, Dict
import numpy as np
from numpy.typing import NDArray
from jax import jit, numpy as jnp
from .utils import fetch_first
def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
    """
    Initialize a population of genomes in which every input node is connected
    to every output node.

    Args:
        N (int): Maximum number of nodes in the network.
        C (int): Maximum number of connections in the network.
        config (Dict): Configuration dictionary. Reads 'pop_size', 'num_inputs',
            'num_outputs', 'input_idx', 'output_idx', the '*_init_mean' /
            '*_init_std' entries for bias, response and weight, and
            'activation_options' / 'aggregation_options'.

    Returns:
        Tuple[NDArray, NDArray]: pop_nodes with shape (pop_size, N, 5) and
        pop_cons with shape (pop_size, C, 4), both float32. Unused rows are NaN.

    Raises:
        AssertionError: if N or C is too small for the requested inputs/outputs.
    """
    # Reserve one row for potential mutation adding an extra node
    assert config['num_inputs'] + config['num_outputs'] + 1 <= N, \
        f"Too small N: {N} for input_size: {config['num_inputs']} and output_size: {config['num_outputs']}!"

    assert config['num_inputs'] * config['num_outputs'] + 1 <= C, \
        f"Too small C: {C} for input_size: {config['num_inputs']} and output_size: {config['num_outputs']}!"

    pop_nodes = np.full((config['pop_size'], N, 5), np.nan, dtype=np.float32)
    pop_cons = np.full((config['pop_size'], C, 4), np.nan, dtype=np.float32)
    input_idx = config['input_idx']
    output_idx = config['output_idx']

    # node keys: input and output nodes use their index as key
    pop_nodes[:, input_idx, 0] = input_idx
    pop_nodes[:, output_idx, 0] = output_idx

    # Each draw below has shape (pop_size, 1) and is broadcast over the output
    # nodes, so all output nodes of one genome start with the same value.
    # NOTE(review): confirm that sharing the value across outputs is intended.
    pop_nodes[:, output_idx, 1] = np.random.normal(loc=config['bias_init_mean'], scale=config['bias_init_std'],
                                                   size=(config['pop_size'], 1))
    pop_nodes[:, output_idx, 2] = np.random.normal(loc=config['response_init_mean'], scale=config['response_init_std'],
                                                   size=(config['pop_size'], 1))
    pop_nodes[:, output_idx, 3] = np.random.choice(config['activation_options'], size=(config['pop_size'], 1))
    pop_nodes[:, output_idx, 4] = np.random.choice(config['aggregation_options'], size=(config['pop_size'], 1))

    # every (input, output) pair becomes one enabled connection
    grid_a, grid_b = np.meshgrid(input_idx, output_idx)
    grid_a, grid_b = grid_a.flatten(), grid_b.flatten()

    p = config['num_inputs'] * config['num_outputs']
    pop_cons[:, :p, 0] = grid_a
    pop_cons[:, :p, 1] = grid_b
    pop_cons[:, :p, 2] = np.random.normal(loc=config['weight_init_mean'], scale=config['weight_init_std'],
                                          size=(config['pop_size'], p))
    pop_cons[:, :p, 3] = 1

    return pop_nodes, pop_cons
@jit
def add_node(nodes: NDArray, cons: NDArray, new_key: int,
             bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[NDArray, NDArray]:
    """
    Insert a node into the genome.
    The node [new_key, bias, response, act, agg] is written to the row that
    fetch_first selects among the rows whose key is NaN (i.e. the first free
    row); the connections are returned untouched.
    """
    free_row = fetch_first(jnp.isnan(nodes[:, 0]))
    new_row = jnp.array([new_key, bias, response, act, agg])
    return nodes.at[free_row].set(new_row), cons
@jit
def delete_node(nodes: NDArray, cons: NDArray, node_key: int) -> Tuple[NDArray, NDArray]:
    """
    Remove the node with the given key from the genome.
    Only the node row is cleared; connections that reference the node are
    left as they are.
    """
    row = fetch_first(nodes[:, 0] == node_key)
    return delete_node_by_idx(nodes, cons, row)
@jit
def delete_node_by_idx(nodes: NDArray, cons: NDArray, idx: int) -> Tuple[NDArray, NDArray]:
    """
    Remove the node stored at row idx by overwriting that row with NaN.
    Only the node row is cleared; the connections are returned untouched.
    """
    cleared_nodes = nodes.at[idx].set(np.nan)
    return cleared_nodes, cons
@jit
def add_connection(nodes: NDArray, cons: NDArray, i_key: int, o_key: int,
                   weight: float = 1.0, enabled: bool = True) -> Tuple[NDArray, NDArray]:
    """
    Insert a connection into the genome.
    The connection [i_key, o_key, weight, enabled] is written to the row that
    fetch_first selects among the rows whose first column is NaN (i.e. the
    first free row); the nodes are returned untouched.
    """
    free_row = fetch_first(jnp.isnan(cons[:, 0]))
    new_row = jnp.array([i_key, o_key, weight, enabled])
    return nodes, cons.at[free_row].set(new_row)
@jit
def delete_connection(nodes: NDArray, cons: NDArray, i_key: int, o_key: int) -> Tuple[NDArray, NDArray]:
    """
    Remove the connection identified by its input and output node keys.
    """
    same_input = cons[:, 0] == i_key
    same_output = cons[:, 1] == o_key
    row = fetch_first(same_input & same_output)
    return delete_connection_by_idx(nodes, cons, row)
@jit
def delete_connection_by_idx(nodes: NDArray, cons: NDArray, idx: int) -> Tuple[NDArray, NDArray]:
    """
    Remove the connection stored at row idx by overwriting that row with NaN.
    The nodes are returned untouched.
    """
    cleared_cons = cons.at[idx].set(np.nan)
    return nodes, cleared_cons

View File

@@ -1,167 +0,0 @@
"""
Some graph algorithm implemented in jax.
Only used in feed-forward networks.
"""
import jax
from jax import jit, Array, numpy as jnp
from algorithms.neat.genome.utils import fetch_first, I_INT
@jit
def topological_sort(nodes: Array, connections: Array) -> Array:
    """
    A jit-able topological sort.

    :param nodes: nodes array; rows whose key (column 0) is NaN are ignored
    :param connections: array indexed as [2, from, to]; an edge from -> to is
        used when connections[1, from, to] == 1
    :return: node indices in topological order, padded with I_INT

    Example:
        nodes = jnp.array([[0], [1], [2], [3]])
        connections = jnp.array([
            [[0, 0, 1, 0], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]],
            [[0, 0, 1, 0], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]],
        ])
        topological_sort(nodes, connections) -> [0, 1, 2, 3]
    """
    enabled = connections[1, :, :] == 1  # forward function. thus use enable
    incoming = jnp.sum(enabled, axis=0)
    # NaN in-degree for empty rows, so they never compare equal to 0
    in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, incoming)
    order = jnp.full(in_degree.shape, I_INT)

    def has_ready_node(carry):
        _, _, degrees = carry
        # fetch_first yields I_INT when no node with in-degree 0 is left
        return fetch_first(degrees == 0.) != I_INT

    def emit_ready_node(carry):
        order_, slot, degrees = carry
        ready = fetch_first(degrees == 0.)

        # record the node and mark it so that it is not picked again
        order_ = order_.at[slot].set(ready)
        degrees = degrees.at[ready].set(-1)

        # every child of the node loses one incoming edge
        degrees = jnp.where(enabled[ready, :], degrees - 1, degrees)
        return order_, slot + 1, degrees

    order, _, _ = jax.lax.while_loop(has_ready_node, emit_ready_node, (order, 0, in_degree))
    return order
@jit
def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array:
    """
    Check whether adding the connection (from_idx -> to_idx) creates a cycle.

    :param nodes: nodes array; only its number of rows is used
    :param connections: array indexed as [2, from, to]; an edge from -> to
        exists when connections[0, from, to] is not NaN
    :param from_idx: index of the node the new connection starts at
    :param to_idx: index of the node the new connection ends at
    :return: boolean array, True when from_idx can be reached from to_idx once
        the new connection is in place

    Example (with the graph 0 -> 2, 1 -> 2, 1 -> 3, 2 -> 3):
        check_cycles(nodes, connections, 3, 2) -> True
        check_cycles(nodes, connections, 2, 3) -> False
        check_cycles(nodes, connections, 0, 3) -> False
        check_cycles(nodes, connections, 1, 0) -> False
    """
    adjacency = ~jnp.isnan(connections[0, :, :])
    adjacency = adjacency.at[from_idx, to_idx].set(True)

    nothing_seen = jnp.full(nodes.shape[0], False)
    seeded = nothing_seen.at[to_idx].set(True)

    def keep_searching(carry):
        previous, current = carry
        stalled = jnp.all(previous == current)  # no new nodes been visited
        reached_start = current[from_idx]  # the starting node has been visited
        return jnp.logical_not(stalled | reached_start)

    def expand_frontier(carry):
        _, current = carry
        # nodes reachable in one step from any node visited so far
        successors = jnp.dot(current, adjacency)
        return current, jnp.logical_or(current, successors)

    _, reachable = jax.lax.while_loop(keep_searching, expand_frontier, (nothing_seen, seeded))
    return reachable[from_idx]
if __name__ == '__main__':
    # Small demo: four nodes plus one empty (NaN) row, with the edges
    # 0 -> 2, 1 -> 2, 1 -> 3 and 2 -> 3. Both planes of the connection
    # array carry the same pattern.
    nan = jnp.nan
    nodes = jnp.array([[0], [1], [2], [3], [nan]])
    layer = [
        [nan, nan, 1, nan, nan],
        [nan, nan, 1, 1, nan],
        [nan, nan, nan, 1, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
    ]
    connections = jnp.array([layer, layer])

    print(topological_sort(nodes, connections))

    for src, dst in ((3, 2), (2, 3), (0, 3), (1, 0)):
        print(check_cycles(nodes, connections, src, dst))

View File

@@ -1,349 +0,0 @@
"""
Mutate a genome.
The calculation method is the same as the mutation operation in NEAT-python.
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate
"""
from typing import Tuple, Dict
import jax
from jax import numpy as jnp, jit, Array
from .utils import fetch_random, fetch_first, I_INT, unflatten_connections
from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
from .graph import check_cycles
@jit
def mutate(rand_key: Array, nodes: Array, connections: Array, new_node_key: int, jit_config: Dict):
    """
    Apply the structural mutations (add node, add connection, delete node,
    delete connection) and then the value mutations to one genome.

    Every structural mutation is always computed; its result is kept only when
    an independent uniform draw falls below the matching probability in
    jit_config ('node_add_prob', 'conn_add_prob', 'node_delete_prob',
    'conn_delete_prob').

    Each random draw uses its own sub-key. Previously the key of a probability
    draw was also passed to a mutation, which correlated the two.

    :param rand_key: random key for the whole mutation
    :param nodes: (N, 5) node array
    :param connections: flat connection table with four columns, as consumed
        by the value-mutation step
    :param new_node_key: key to give to a node created by the add-node mutation
    :param jit_config: configuration with the probabilities and value settings
    :return: (nodes, connections) after mutation
    """
    (p_add_node, k_add_node,
     p_add_conn, k_add_conn,
     p_del_node, k_del_node,
     p_del_conn, k_del_conn,
     k_values) = jax.random.split(rand_key, 9)

    # structural mutations
    # mutate add node
    r = rand(p_add_node)
    aux_nodes, aux_connections = mutate_add_node(k_add_node, nodes, connections, new_node_key, jit_config)
    nodes = jnp.where(r < jit_config['node_add_prob'], aux_nodes, nodes)
    connections = jnp.where(r < jit_config['node_add_prob'], aux_connections, connections)

    # mutate add connection
    r = rand(p_add_conn)
    aux_nodes, aux_connections = mutate_add_connection(k_add_conn, nodes, connections, jit_config)
    nodes = jnp.where(r < jit_config['conn_add_prob'], aux_nodes, nodes)
    connections = jnp.where(r < jit_config['conn_add_prob'], aux_connections, connections)

    # mutate delete node
    r = rand(p_del_node)
    aux_nodes, aux_connections = mutate_delete_node(k_del_node, nodes, connections, jit_config)
    nodes = jnp.where(r < jit_config['node_delete_prob'], aux_nodes, nodes)
    connections = jnp.where(r < jit_config['node_delete_prob'], aux_connections, connections)

    # mutate delete connection
    r = rand(p_del_conn)
    aux_nodes, aux_connections = mutate_delete_connection(k_del_conn, nodes, connections)
    nodes = jnp.where(r < jit_config['conn_delete_prob'], aux_nodes, nodes)
    connections = jnp.where(r < jit_config['conn_delete_prob'], aux_connections, connections)

    # value mutations
    nodes, connections = mutate_values(k_values, nodes, connections, jit_config)

    return nodes, connections
def mutate_values(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
    """
    Mutate the values stored in the node and connection tables.

    Args:
        rand_key: A random key for generating random values.
        nodes: Node table with columns [key, bias, response, act, agg].
        cons: Connection table with columns [in_key, out_key, weight, enable].
        jit_config: A dict containing configuration for jit-able functions.

    Returns:
        A tuple (nodes, cons) with the keys unchanged and every other column
        mutated.
    """
    bias_rng, response_rng, weight_rng, act_rng, agg_rng, enable_rng = jax.random.split(rand_key, num=6)

    # float-valued columns: perturb or replace
    new_bias = mutate_float_values(
        bias_rng, nodes[:, 1],
        jit_config['bias_init_mean'], jit_config['bias_init_std'],
        jit_config['bias_mutate_power'], jit_config['bias_mutate_rate'],
        jit_config['bias_replace_rate'],
    )
    new_response = mutate_float_values(
        response_rng, nodes[:, 2],
        jit_config['response_init_mean'], jit_config['response_init_std'],
        jit_config['response_mutate_power'], jit_config['response_mutate_rate'],
        jit_config['response_replace_rate'],
    )
    new_weight = mutate_float_values(
        weight_rng, cons[:, 2],
        jit_config['weight_init_mean'], jit_config['weight_init_std'],
        jit_config['weight_mutate_power'], jit_config['weight_mutate_rate'],
        jit_config['weight_replace_rate'],
    )

    # option-valued columns: replace by a random option
    new_act = mutate_int_values(act_rng, nodes[:, 3], jit_config['activation_options'],
                                jit_config['activation_replace_rate'])
    new_agg = mutate_int_values(agg_rng, nodes[:, 4], jit_config['aggregation_options'],
                                jit_config['aggregation_replace_rate'])

    # enable flag: flip with probability enable_mutate_rate
    old_enabled = cons[:, 3]
    flip_draw = jax.random.uniform(enable_rng, old_enabled.shape)
    new_enabled = jnp.where(flip_draw < jit_config['enable_mutate_rate'], 1 - old_enabled, old_enabled)

    # reassemble the tables around the untouched key columns
    mutated_nodes = jnp.column_stack([nodes[:, 0], new_bias, new_response, new_act, new_agg])
    mutated_cons = jnp.column_stack([cons[:, 0], cons[:, 1], new_weight, new_enabled])

    return mutated_nodes, mutated_cons
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
                        mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
    """
    Mutate an array of float values.

    For every element one uniform number r is drawn:
        r < mutate_rate:                              add Gaussian noise scaled by mutate_strength
        mutate_rate < r < mutate_rate + replace_rate: replace by a fresh sample from N(mean, std)
        otherwise:                                    keep the old value
    Elements that are NaN in old_vals stay NaN.

    Args:
        rand_key: A random key for generating random values.
        old_vals: A 1D array of float values to be mutated.
        mean: Mean of the replacement distribution.
        std: Standard deviation of the replacement distribution.
        mutate_strength: Scale of the additive noise.
        mutate_rate: Probability of adding noise.
        replace_rate: Probability of replacing the value.

    Returns:
        A mutated 1D array of float values.
    """
    noise_rng, replace_rng, draw_rng, _ = jax.random.split(rand_key, num=4)
    perturbation = jax.random.normal(noise_rng, old_vals.shape) * mutate_strength
    fresh = jax.random.normal(replace_rng, old_vals.shape) * std + mean
    draw = jax.random.uniform(draw_rng, old_vals.shape)

    perturb_here = draw < mutate_rate
    replace_here = (mutate_rate < draw) & (draw < mutate_rate + replace_rate)

    result = jnp.where(perturb_here, old_vals + perturbation, old_vals)
    # adding result * 0.0 keeps NaN entries NaN instead of giving them a value
    result = jnp.where(replace_here, fresh + result * 0.0, result)

    return jnp.where(jnp.isnan(old_vals), jnp.nan, result)
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
    """
    Mutate option values (act, agg) of a given array.

    Each element is replaced, with probability replace_rate, by a value drawn
    at random from val_list. Elements that are NaN stay NaN.

    Args:
        rand_key: A random key for generating random values.
        old_vals: A 1D array of option values to be mutated.
        val_list: The values to choose replacements from.
        replace_rate: Probability of replacing a value.

    Returns:
        A mutated 1D array of option values.
    """
    pick_rng, draw_rng, _ = jax.random.split(rand_key, num=3)
    candidates = jax.random.choice(pick_rng, val_list, old_vals.shape)
    draw = jax.random.uniform(draw_rng, old_vals.shape)
    # adding old_vals * 0.0 keeps NaN entries NaN instead of giving them a value
    replacement = candidates + old_vals * 0.0
    return jnp.where(draw < replace_rate, replacement, old_vals)
def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: int,
                    jit_config: Dict) -> Tuple[Array, Array]:
    """
    Split a randomly chosen connection by inserting a new node into it.
    The chosen connection is disabled, and two connections are added:
    in -> new node with weight 1, and new node -> out with the old weight.
    :param rand_key: random key used to choose the connection
    :param nodes: node array
    :param cons: connection array
    :param new_node_key: key of the node to insert
    :param jit_config: supplies 'activation_default' and 'aggregation_default'
    :return: (nodes, cons); unchanged when there is no connection to split
    """
    i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons)

    def keep_genome():  # there is no connection to split
        return nodes, cons

    def split_connection():
        # switch the old connection off
        grown_cons = cons.at[idx, 3].set(False)

        grown_nodes, grown_cons = add_node(nodes, grown_cons, new_node_key, bias=0, response=1,
                                           act=jit_config['activation_default'],
                                           agg=jit_config['aggregation_default'])

        # the outgoing half inherits the weight of the split connection
        old_weight = grown_cons[idx, 2]
        grown_nodes, grown_cons = add_connection(grown_nodes, grown_cons, i_key, new_node_key,
                                                 weight=1, enabled=True)
        grown_nodes, grown_cons = add_connection(grown_nodes, grown_cons, new_node_key, o_key,
                                                 weight=old_weight, enabled=True)
        return grown_nodes, grown_cons

    # idx == I_INT means that no connection exists, so there is nothing to do
    out_nodes, out_cons = jax.lax.cond(idx == I_INT, keep_genome, split_connection)

    return out_nodes, out_cons
# TODO: Do we really need to delete a node?
def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
    """
    Delete a randomly chosen node together with every connection that
    touches it. Input and output nodes are never chosen.
    :param rand_key: random key used to choose the node
    :param nodes: node array
    :param cons: connection array
    :param jit_config: supplies 'input_idx' and 'output_idx'
    :return: (nodes, cons); unchanged when no node is eligible
    """
    key, idx = choice_node_key(rand_key, nodes, jit_config['input_idx'], jit_config['output_idx'],
                               allow_input_keys=False, allow_output_keys=False)

    def keep_genome():
        return nodes, cons

    def remove_node():
        pruned_nodes, pruned_cons = delete_node_by_idx(nodes, cons, idx)

        # clear every connection that starts or ends at the removed node
        touches_node = (pruned_cons[:, 0] == key) | (pruned_cons[:, 1] == key)
        pruned_cons = jnp.where(touches_node[:, None], jnp.nan, pruned_cons)

        return pruned_nodes, pruned_cons

    out_nodes, out_cons = jax.lax.cond(idx == I_INT, keep_genome, remove_node)

    return out_nodes, out_cons
def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
    """
    Add a connection between two randomly chosen nodes.
    The target node is never an input node. One of three things happens:
        the connection already exists: its enable flag is set
        the connection would create a cycle: the genome is returned unchanged
        otherwise: the connection is added with weight 1
    :param rand_key: random key, split once per chosen node
    :param nodes: node array
    :param cons: connection array
    :param jit_config: supplies 'input_idx' and 'output_idx'
    :return: (nodes, cons)
    """
    src_rng, dst_rng = jax.random.split(rand_key, num=2)
    i_key, from_idx = choice_node_key(src_rng, nodes, jit_config['input_idx'], jit_config['output_idx'],
                                      allow_input_keys=True, allow_output_keys=True)
    o_key, to_idx = choice_node_key(dst_rng, nodes, jit_config['input_idx'], jit_config['output_idx'],
                                    allow_input_keys=False, allow_output_keys=True)

    existing_row = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))

    def enable_existing():
        return nodes, cons.at[existing_row, 3].set(True)

    def reject_cycle():
        return nodes, cons

    def insert_new():
        return add_connection(nodes, cons, i_key, o_key, weight=1, enabled=True)

    already_present = existing_row != I_INT
    dense_cons = unflatten_connections(nodes, cons)
    makes_cycle = check_cycles(nodes, dense_cons, from_idx, to_idx)

    # an existing connection takes precedence over the cycle check
    branch = jnp.where(already_present, 0, jnp.where(makes_cycle, 1, 2))
    out_nodes, out_cons = jax.lax.switch(branch, [enable_existing, reject_cycle, insert_new])
    return out_nodes, out_cons
def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array):
    """
    Delete a randomly chosen connection.
    :param rand_key: random key used to choose the connection
    :param nodes: node array
    :param cons: connection array
    :return: (nodes, cons); unchanged when there is no connection to delete
    """
    _, _, idx = choice_connection_key(rand_key, nodes, cons)

    def keep_genome():
        return nodes, cons

    def remove_connection():
        return delete_connection_by_idx(nodes, cons, idx)

    out_nodes, out_cons = jax.lax.cond(idx == I_INT, keep_genome, remove_connection)

    return out_nodes, out_cons
def choice_node_key(rand_key: Array, nodes: Array,
                    input_keys: Array, output_keys: Array,
                    allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
    """
    Pick a random existing node (a row whose key is not NaN).

    Input and/or output nodes are excluded from the draw unless the matching
    allow_* flag is True.

    :param rand_key: random key used for the draw
    :param nodes: node array; column 0 holds the node key
    :param input_keys: keys of the input nodes
    :param output_keys: keys of the output nodes
    :param allow_input_keys: when False, input nodes cannot be chosen
    :param allow_output_keys: when False, output nodes cannot be chosen
    :return: (key, idx) of the chosen node; (NaN, I_INT) when no candidate exists
    """
    node_keys = nodes[:, 0]
    candidates = ~jnp.isnan(node_keys)
    if not allow_input_keys:
        candidates = candidates & ~jnp.isin(node_keys, input_keys)
    if not allow_output_keys:
        candidates = candidates & ~jnp.isin(node_keys, output_keys)

    idx = fetch_random(rand_key, candidates)
    found = idx != I_INT
    key = jnp.where(found, nodes[idx, 0], jnp.nan)
    return key, idx
def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]:
    """
    Pick a random existing connection (a row whose input key is not NaN).

    :param rand_key: random key used for the draw
    :param nodes: node array (not read by this function)
    :param cons: connection array; columns 0 and 1 hold the input/output node keys
    :return: (i_key, o_key, idx); (NaN, NaN, I_INT) when there is no connection
    """
    occupied = ~jnp.isnan(cons[:, 0])
    idx = fetch_random(rand_key, occupied)
    found = idx != I_INT
    i_key = jnp.where(found, cons[idx, 0], jnp.nan)
    o_key = jnp.where(found, cons[idx, 1], jnp.nan)
    return i_key, o_key, idx
def rand(rand_key):
    """Draw one scalar from jax.random.uniform (default range) using rand_key."""
    scalar_shape = ()
    return jax.random.uniform(rand_key, scalar_shape)

View File

@@ -1,71 +0,0 @@
from functools import partial
import numpy as np
import jax
from jax import numpy as jnp, Array, jit, vmap
# Largest int32 value; fetch_first / fetch_random use it as their default "nothing found" index.
I_INT = np.iinfo(jnp.int32).max  # infinite int
# A single all-NaN row with 5 columns (the node row width used in this module).
EMPTY_NODE = np.full((1, 5), jnp.nan)
# A single all-NaN row with 4 columns (the connection row width used in this module).
EMPTY_CON = np.full((1, 4), jnp.nan)
@jit
def unflatten_connections(nodes: Array, cons: Array):
    """
    Expand the (C, 4) connection list into a dense (2, N, N) array.

    Entry [0, i, o] receives column 2 of the connection whose endpoints sit at
    node rows i and o, entry [1, i, o] receives column 3. Every other entry is NaN.

    :param nodes: (N, 5)
    :param cons: (C, 4)
    :return: (2, N, N) array
    """
    node_count = nodes.shape[0]
    node_keys = nodes[:, 0]

    # translate each endpoint key into the row index of that node
    locate = vmap(key_to_indices, in_axes=(0, None))
    src_rows = locate(cons[:, 0], node_keys)
    dst_rows = locate(cons[:, 1], node_keys)

    dense = jnp.full((2, node_count, node_count), jnp.nan)

    # Keys that match no node map to I_INT, which is out of range. As the original
    # author noted, JAX clips out-of-range indices when reading but does nothing
    # for them when setting, so such rows leave `dense` untouched.
    dense = dense.at[0, src_rows, dst_rows].set(cons[:, 2])
    dense = dense.at[1, src_rows, dst_rows].set(cons[:, 3])
    return dense
def key_to_indices(key, keys):
    """Return the first position in `keys` equal to `key` (fetch_first's default if none)."""
    matches = keys == key
    return fetch_first(matches)
@jit
def fetch_first(mask, default=I_INT) -> Array:
    """
    Locate the first True entry of a boolean array.

    :param mask: array of bool
    :param default: value returned when no entry is True
    :return: index of the first True entry, or `default` when there is none
    """
    # argmax gives the first maximum; on an all-False mask that is position 0,
    # so the value at that position tells whether anything was actually found.
    first = jnp.argmax(mask)
    found = mask[first]
    return jnp.where(found, first, default)
@jit
def fetch_random(rand_key, mask, default=I_INT) -> Array:
    """
    Like fetch_first, but return the index of a randomly chosen True entry.

    :param rand_key: random key used for the draw
    :param mask: array of bool
    :param default: value returned when no entry is True
    :return: index of the chosen True entry, or `default` when there is none
    """
    running_count = jnp.cumsum(mask)
    total = jnp.sum(mask)
    # choose which True entry to take, counting from 1
    wanted = jax.random.randint(rand_key, shape=(), minval=1, maxval=total + 1)
    # the first position whose running count reaches `wanted` is that entry
    reached = jnp.where(total == 0, False, running_count >= wanted)
    return fetch_first(reached, default)
@partial(jit, static_argnames=['reverse'])
def rank_elements(array, reverse=False):
    """
    Return the rank of every element of `array`.

    By default the largest element gets rank 0 (ranking from large to small).
    With reverse=True the smallest element gets rank 0 (small to large).
    """
    # argsort of argsort turns a sort order into per-element ranks (ascending),
    # so the values are negated to rank from large to small
    sort_key = array if reverse else -array
    order = jnp.argsort(sort_key)
    return jnp.argsort(order)

View File

@@ -1,441 +0,0 @@
"""
Contains operations on the population: creating the next generation and population speciation.
The value tuple (P, N, C, S) is determined when the algorithm is initialized.
P: population size
N: maximum number of nodes in any genome
C: maximum number of connections in any genome
S: maximum number of species in NEAT
These arrays are used in the algorithm:
fitness: Array[(P,), float], the fitness of each individual
randkey: Array[2, uint], the random key
pop_nodes: Array[(P, N, 5), float], nodes part of the population. [key, bias, response, act, agg]
pop_cons: Array[(P, C, 4), float], connections part of the population. [in_node, out_node, weight, enabled]
species_info: Array[(S, 4), float], the information of each species. [key, best_score, last_update, members_count]
idx2species: Array[(P,), float], map the individual to its species keys
center_nodes: Array[(S, N, 5), float], the center nodes of each species
center_cons: Array[(S, C, 4), float], the center connections of each species
generation: int, the current generation
next_node_key: float, the key that will be assigned to the next newly created node
next_species_key: float, the key that will be assigned to the next newly created species
jit_config: Configer, the config used in jit-able functions
"""
# TODO: Complete python doc
import numpy as np
import jax
from jax import jit, vmap, Array, numpy as jnp
from .genome import initialize_genomes, distance, mutate, crossover, fetch_first, rank_elements
def initialize(config):
    """
    Build the initial NEAT state.

    All genomes start in species 0, whose center is genome 0; every other
    species slot is NaN.

    :param config: config dict; reads 'pop_size', 'maximum_nodes', 'maximum_connections',
        'maximum_species', 'random_seed', 'num_inputs' and 'num_outputs'
    :return: list placed on device: [randkey, pop_nodes, pop_cons, species_info, idx2species,
        center_nodes, center_cons, generation, next_node_key, next_species_key]
    """
    pop_size = config['pop_size']
    max_nodes = config['maximum_nodes']
    max_cons = config['maximum_connections']
    max_species = config['maximum_species']
    seed = config['random_seed']

    randkey = jax.random.PRNGKey(seed)
    np.random.seed(seed)  # also seeds NumPy's global generator
    pop_nodes, pop_cons = initialize_genomes(max_nodes, max_cons, config)

    # species 0: [key=0, best_score=-inf, last_update=0, members_count=pop_size]
    species_info = np.full((max_species, 4), np.nan, dtype=np.float32)
    species_info[0] = (0, -np.inf, 0, pop_size)
    idx2species = np.zeros(pop_size, dtype=np.float32)

    # genome 0 is the center of species 0
    center_nodes = np.full((max_species, max_nodes, 5), np.nan, dtype=np.float32)
    center_cons = np.full((max_species, max_cons, 4), np.nan, dtype=np.float32)
    center_nodes[0] = pop_nodes[0]
    center_cons[0] = pop_cons[0]

    generation = np.array(0, dtype=np.int32)
    next_node_key = np.array(config['num_inputs'] + config['num_outputs'], dtype=np.float32)
    next_species_key = np.array(1, dtype=np.float32)

    state = [
        randkey,
        pop_nodes,
        pop_cons,
        species_info,
        idx2species,
        center_nodes,
        center_cons,
        generation,
        next_node_key,
        next_species_key,
    ]
    return jax.device_put(state)
@jit
def tell(fitness,
         randkey,
         pop_nodes,
         pop_cons,
         species_info,
         idx2species,
         center_nodes,
         center_cons,
         generation,
         next_node_key,
         next_species_key,
         jit_config):
    """
    Main NEAT update: advance the state by one generation.

    Runs, in order: update_species (species bookkeeping and parent selection),
    create_next_generation (crossover and mutation) and speciate (re-assigning
    the new population to species).

    :return: (randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes,
        center_cons, generation, next_node_key, next_species_key)
    """
    generation = generation + 1

    # one key for species update, one for reproduction, one carried to the next call
    species_rand, reproduce_rand, randkey = jax.random.split(randkey, 3)

    species_info, center_nodes, center_cons, winner, loser, elite_mask = update_species(
        species_rand, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config)

    pop_nodes, pop_cons, next_node_key = create_next_generation(
        reproduce_rand, pop_nodes, pop_cons, winner, loser, elite_mask, next_node_key, jit_config)

    idx2species, center_nodes, center_cons, species_info, next_species_key = speciate(
        pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, next_species_key, jit_config)

    return (randkey, pop_nodes, pop_cons, species_info, idx2species,
            center_nodes, center_cons, generation, next_node_key, next_species_key)
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
    """
    Refresh species statistics and choose the parents of the next generation.

    args:
        randkey: random key
        fitness: Array[(pop_size,), float], the fitness of each individual
        species_info: Array[(species_size, 4), float], the information of each species
            [species_key, best_score, last_update, members_count]
        idx2species: Array[(pop_size,), float], map the individual to its species key
        center_nodes: Array[(species_size, N, 5), float], the center nodes of each species
        center_cons: Array[(species_size, C, 4), float], the center connections of each species
        generation: int, current generation
        jit_config: Dict, the configuration of jit functions
    returns:
        species_info, center_nodes, center_cons (all reordered by species fitness, best first),
        and winner, loser, elite_mask from create_crossover_pair
    """
    # fitness of every species slot
    per_species_fitness = update_species_fitness(species_info, idx2species, fitness)

    # drop species that have stagnated
    per_species_fitness, species_info, center_nodes, center_cons = stagnation(
        per_species_fitness, species_info, center_nodes, center_cons, generation, jit_config)

    # order species from best to worst
    best_first = jnp.argsort(per_species_fitness)[::-1]
    species_info = species_info[best_first]
    center_nodes = center_nodes[best_first]
    center_cons = center_cons[best_first]

    # how many offspring each species gets
    spawn_number = cal_spawn_numbers(species_info, jit_config)

    # who is crossed with whom
    winner, loser, elite_mask = create_crossover_pair(
        randkey, species_info, idx2species, spawn_number, fitness, jit_config)

    return species_info, center_nodes, center_cons, winner, loser, elite_mask
def update_species_fitness(species_info, idx2species, fitness):
    """
    Compute one fitness value per species slot as the maximum fitness of its members.

    A slot with no matching individual (including NaN keys, which never compare
    equal) gets -inf.
    """
    def best_member_fitness(species_key):
        in_species = idx2species == species_key
        return jnp.max(jnp.where(in_species, fitness, -jnp.inf))

    return vmap(best_member_fitness)(species_info[:, 0])
def stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config):
    """
    Remove stagnated species.

    A species is stagnated when its fitness is not above its recorded best score and
    more than jit_config['max_stagnation'] generations have passed since its last
    improvement. The jit_config['species_elitism'] best-ranked species are never removed.
    Removed species have their info and centers set to NaN and their fitness to -inf.

    :return: species_fitness, species_info, center_nodes, center_cons
    """
    def judge(s_fitness, info):
        species_key, best_score, last_update, members_count = info
        improved = s_fitness > best_score
        # evaluated against the last_update value from before this generation's refresh
        stalled = (s_fitness <= best_score) & (generation - last_update > jit_config['max_stagnation'])
        last_update = jnp.where(improved, generation, last_update)
        best_score = jnp.where(improved, s_fitness, best_score)
        return stalled, jnp.array([species_key, best_score, last_update, members_count])

    spe_st, species_info = vmap(judge)(species_fitness, species_info)

    # elite species are protected from stagnation
    protected = rank_elements(species_fitness) < jit_config['species_elitism']
    spe_st = jnp.where(protected, False, spe_st)

    # blank out every removed species
    species_info = jnp.where(spe_st[:, None], jnp.nan, species_info)
    removed_3d = spe_st[:, None, None]
    center_nodes = jnp.where(removed_3d, jnp.nan, center_nodes)
    center_cons = jnp.where(removed_3d, jnp.nan, center_cons)
    species_fitness = jnp.where(spe_st, -jnp.inf, species_fitness)

    return species_fitness, species_info, center_nodes, center_cons
def cal_spawn_numbers(species_info, jit_config):
    """
    Decide the number of members of each species from its position in species_info.

    Rows are expected to be ordered best first (the caller sorts them), so row i has
    rank i. Linear ranking selection is used: the better the rank, the more members.
    e.g. N = 3, P = 10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]

    The result moves only part of the way from the previous size towards that target
    (controlled by jit_config['spawn_number_move_rate']), and the rounding error is added
    to the first species so that the numbers always sum to jit_config['pop_size'].

    :param species_info: (S, 4) array [key, best_score, last_update, members_count];
        empty slots are NaN
    :param jit_config: config dict; reads 'pop_size' and 'spawn_number_move_rate'
    :return: (S,) int32 array of spawn numbers
    """
    pop_size = jit_config['pop_size']

    is_species_valid = ~jnp.isnan(species_info[:, 0])
    valid_species_num = jnp.sum(is_species_valid)
    denominator = (valid_species_num + 1) * valid_species_num / 2  # obtain 3 + 2 + 1 = 6

    rank_score = valid_species_num - jnp.arange(species_info.shape[0])  # obtain [3, 2, 1]
    spawn_number_rate = rank_score / denominator  # obtain [0.5, 0.33, 0.17]
    spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0)  # set invalid species to 0

    target_spawn_number = jnp.floor(spawn_number_rate * pop_size)  # calculate member

    # Empty slots carry NaN as member count. Replace it with 0 BEFORE the integer cast,
    # so the result does not depend on how the backend converts NaN to int32.
    members_count = species_info[:, 3]
    has_count = is_species_valid & ~jnp.isnan(members_count)
    previous_size = jnp.where(has_count, members_count, 0).astype(jnp.int32)

    # Avoid too much variation of numbers in a species
    spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate']
    spawn_number = spawn_number.astype(jnp.int32)

    # must control the sum of spawn_number to be equal to pop_size
    error = pop_size - jnp.sum(spawn_number)
    return spawn_number.at[0].add(error)  # the first species absorbs the rounding error
def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config):
    """
    Choose the two parents of every individual of the next generation.

    For each species, members are sorted by fitness (best first) and parents are drawn
    uniformly, with replacement, from the best `survival_threshold` fraction of them.
    The first jit_config['genome_elitism'] slots of each species are elites: both
    parents are the same top member.

    Population slot p belongs to the species whose cumulative spawn number first exceeds p.

    :param randkey: random key, split into one key per species
    :param species_info: (S, 4) array; column 0 holds the species key
    :param idx2species: (P,) species key of every individual
    :param spawn_number: (S,) number of offspring per species
    :param fitness: (P,) fitness of every individual
    :param jit_config: config dict; reads 'genome_elitism' and 'survival_threshold'
    :return: (winner, loser, elite_mask), each of shape (P,). winner/loser are the parent
        indices ordered by fitness (winner >= loser); elite_mask marks elite slots.
    """
    species_size = species_info.shape[0]
    pop_size = fitness.shape[0]
    s_idx = jnp.arange(species_size)
    p_idx = jnp.arange(pop_size)
    elite_size = jit_config['genome_elitism']

    def pick_parents(key, idx):
        # candidate parents for every possible slot of species `idx`
        members = idx2species == species_info[idx, 0]
        members_num = jnp.sum(members)
        members_fitness = jnp.where(members, fitness, -jnp.inf)
        sorted_member_indices = jnp.argsort(members_fitness)[::-1]

        survive_size = jnp.floor(jit_config['survival_threshold'] * members_num).astype(jnp.int32)
        # A small species would otherwise get survive_size == 0, making every selection
        # probability 0 / 0 = NaN. Always keep at least the best member as a parent.
        survive_size = jnp.maximum(survive_size, 1)

        select_pro = (p_idx < survive_size) / survive_size
        fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro)

        # elite slot k has the k-th best member as both parents
        is_elite = p_idx < elite_size
        fa = jnp.where(is_elite, sorted_member_indices, fa)
        ma = jnp.where(is_elite, sorted_member_indices, ma)
        return fa, ma, is_elite

    fas, mas, elites = vmap(pick_parents)(jax.random.split(randkey, species_size), s_idx)

    spawn_number_cum = jnp.cumsum(spawn_number)

    def parents_of_slot(idx):
        # species owning this slot: first one whose cumulative spawn number exceeds idx
        loc = jnp.argmax(idx < spawn_number_cum)
        # elite genomes are at the beginning of the species
        idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
        return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]

    part1, part2, elite_mask = vmap(parents_of_slot)(p_idx)

    # the fitter parent is the "winner"
    is_part1_win = fitness[part1] >= fitness[part2]
    winner = jnp.where(is_part1_win, part1, part2)
    loser = jnp.where(is_part1_win, part2, part1)

    return winner, loser, elite_mask
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, next_node_key, jit_config):
    """
    Produce the next population by crossover followed by mutation.

    Individual p is the crossover of genomes winner[p] and loser[p]. The result is then
    mutated, except where elite_mask[p] is True, in which case the unmutated crossover
    result is kept.

    :return: (pop_nodes, pop_cons, next_node_key), where next_node_key is one above the
        largest node key present in the new population
    """
    pop_size = pop_nodes.shape[0]

    # one random key per individual for each of the two stages
    crossover_seed, mutate_seed = jax.random.split(rand_key, 2)
    crossover_rand_keys = jax.random.split(crossover_seed, pop_size)
    mutate_rand_keys = jax.random.split(mutate_seed, pop_size)

    # every individual gets its own key for a node it might add
    new_node_keys = next_node_key + jnp.arange(pop_size)

    # batch crossover
    child_nodes, child_cons = vmap(crossover)(
        crossover_rand_keys, pop_nodes[winner], pop_cons[winner], pop_nodes[loser], pop_cons[loser])

    # batch mutation (jit_config is shared, everything else is per individual)
    batch_mutate = vmap(mutate, in_axes=(0, 0, 0, 0, None))
    mutated_nodes, mutated_cons = batch_mutate(mutate_rand_keys, child_nodes, child_cons, new_node_keys, jit_config)

    # elites keep the unmutated genome
    is_elite = elite_mask[:, None, None]
    pop_nodes = jnp.where(is_elite, child_nodes, mutated_nodes)
    pop_cons = jnp.where(is_elite, child_cons, mutated_cons)

    # next free node key: one above the largest key in use (NaN rows ignored)
    used_keys = pop_nodes[:, :, 0]
    largest_key = jnp.max(jnp.where(jnp.isnan(used_keys), -jnp.inf, used_keys))
    next_node_key = largest_key + 1

    return pop_nodes, pop_cons, next_node_key
def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, next_species_key, jit_config):
    """
    Assign every genome of the population to a species.

    Works in two passes over the species slots:
      1. for each existing species, the unassigned genome closest to its old center
         becomes the new center;
      2. for each species, genomes closer than jit_config['compatibility_threshold'] to
         the center join it; while unassigned genomes remain and free slots exist, new
         species are created around the first unassigned genome.
    Genomes still unassigned afterwards are put into the species of the last slot.

    args:
        pop_nodes: (pop_size, N, 5)
        pop_cons: (pop_size, C, 4)
        species_info: (species_size, 4), [key, best_score, last_update, members_count]
        center_nodes: (species_size, N, 5)
        center_cons: (species_size, C, 4)
        generation: current generation, stored as last_update of newly created species
        next_species_key: key given to the next newly created species
        jit_config: config dict, passed to `distance`; 'compatibility_threshold' is read here
    returns:
        idx2specie, center_nodes, center_cons, species_info, next_species_key
    """
    pop_size, species_size = pop_nodes.shape[0], center_nodes.shape[0]

    # prepare distance functions
    o2p_distance_func = vmap(distance, in_axes=(None, None, 0, 0, None))  # one to population

    # idx to specie key
    idx2specie = jnp.full((pop_size,), jnp.nan)  # NaN means not assigned to any species

    # the distance between genomes to its center genomes
    o2c_distances = jnp.full((pop_size,), jnp.inf)

    # step 1: find new centers
    def cond_func(carry):
        i, i2s, cn, cc, o2c = carry
        species_key = species_info[i, 0]
        # jax.debug.print("{}, {}", i, species_key)
        return (i < species_size) & (~jnp.isnan(species_key))  # current species is existing

    def body_func(carry):
        i, i2s, cn, cc, o2c = carry
        distances = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config)

        # find the closest one (only genomes not yet assigned are considered)
        closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
        # jax.debug.print("closest_idx: {}", closest_idx)

        i2s = i2s.at[closest_idx].set(species_info[i, 0])
        cn = cn.at[i].set(pop_nodes[closest_idx])
        cc = cc.at[i].set(pop_cons[closest_idx])

        # the genome with closest_idx will become the new center, thus its distance to center is 0.
        o2c = o2c.at[closest_idx].set(0)

        return i + 1, i2s, cn, cc, o2c

    _, idx2specie, center_nodes, center_cons, o2c_distances = \
        jax.lax.while_loop(cond_func, body_func, (0, idx2specie, center_nodes, center_cons, o2c_distances))

    # jax.debug.print("species_info: \n{}", species_info)
    # jax.debug.print("idx2specie: \n{}", idx2specie)

    # part 2: assign members to each species
    # NOTE: cond_func / body_func are redefined here; the step 1 loop above has already run.
    def cond_func(carry):
        i, i2s, cn, cc, si, o2c, nsk = carry  # si is short for species_info, nsk is short for next_species_key
        # jax.debug.print("i:\n{}\ni2s:\n{}\nsi:\n{}", i, i2s, si)
        current_species_existed = ~jnp.isnan(si[i, 0])
        not_all_assigned = jnp.any(jnp.isnan(i2s))
        not_reach_species_upper_bounds = i < species_size
        return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)

    def body_func(carry):
        i, i2s, cn, cc, si, o2c, nsk = carry  # cn / cc are the species center nodes / connections

        # the index returned by the branch is discarded; i is advanced below
        _, i2s, scn, scc, si, o2c, nsk = jax.lax.cond(
            jnp.isnan(si[i, 0]),  # whether the current species is existing or not
            create_new_species,  # if not existing, create a new specie
            update_exist_specie,  # if existing, update the specie
            (i, i2s, cn, cc, si, o2c, nsk)
        )

        return i + 1, i2s, scn, scc, si, o2c, nsk

    def create_new_species(carry):
        i, i2s, cn, cc, si, o2c, nsk = carry

        # pick the first one who has not been assigned to any species
        idx = fetch_first(jnp.isnan(i2s))

        # assign it to the new species
        # [key, best score, last update generation, members_count]
        si = si.at[i].set(jnp.array([nsk, -jnp.inf, generation, 0]))
        i2s = i2s.at[idx].set(nsk)
        o2c = o2c.at[idx].set(0)

        # update center genomes
        cn = cn.at[i].set(pop_nodes[idx])
        cc = cc.at[i].set(pop_cons[idx])

        # let close-enough genomes join the species that was just created
        i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))

        # NOTE(review): the original comment said "do not change i", but i + 1 is returned;
        # the caller (body_func) ignores this first element either way.
        return i + 1, i2s, cn, cc, si, o2c, nsk + 1  # change to next new speciate key

    def update_exist_specie(carry):
        i, i2s, cn, cc, si, o2c, nsk = carry
        i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))

        # turn to next species
        return i + 1, i2s, cn, cc, si, o2c, nsk

    def speciate_by_threshold(carry):
        i, i2s, cn, cc, si, o2c = carry

        # distance between the center genome of species i and every genome of the population
        o2p_distance = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config)
        close_enough_mask = o2p_distance < jit_config['compatibility_threshold']

        # a genome may be claimed when it is unassigned, or when this center is closer
        # than the center it is currently assigned to
        cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
        # jax.debug.print("{}", o2p_distance)
        mask = close_enough_mask & cacheable_mask

        # update species info
        i2s = jnp.where(mask, si[i, 0], i2s)

        # update distance between centers
        o2c = jnp.where(mask, o2p_distance, o2c)

        return i2s, o2c

    # update idx2specie
    _, idx2specie, center_nodes, center_cons, species_info, _, next_species_key = jax.lax.while_loop(
        cond_func,
        body_func,
        (0, idx2specie, center_nodes, center_cons, species_info, o2c_distances, next_species_key)
    )

    # if there are still some pop genomes not assigned to any species, add them to the species
    # in the last slot; this can only happen when the number of species has reached its upper bound
    idx2specie = jnp.where(jnp.isnan(idx2specie), species_info[-1, 0], idx2specie)

    # update members count (NaN for empty species slots)
    def count_members(idx):
        key = species_info[idx, 0]
        count = jnp.sum(idx2specie == key)
        count = jnp.where(jnp.isnan(key), jnp.nan, count)
        return count

    species_member_counts = vmap(count_members)(jnp.arange(species_size))
    species_info = species_info.at[:, 3].set(species_member_counts)

    return idx2specie, center_nodes, center_cons, species_info, next_species_key
def argmin_with_mask(arr: Array, mask: Array) -> Array:
    """Return the index of the smallest entry of `arr` among positions where `mask` is True."""
    # positions outside the mask are replaced by +inf so they cannot win
    return jnp.argmin(jnp.where(mask, arr, jnp.inf))

View File

@@ -1,112 +0,0 @@
import jax
import numpy as np
class Genome:
    """
    Dictionary view of one genome, built from its node and connection arrays.

    nodes: {key: (bias, response, act, agg)}
    cons: {(in_key, out_key): (weight, enabled)}
    """

    def __init__(self, nodes, cons, config):
        self.config = config
        self.nodes, self.cons = array2object(nodes, cons, config)
        if config['renumber_nodes']:
            self.renumber()

    def __repr__(self):
        pieces = [
            'Genome(\n',
            f'\tinput_keys: {self.config["input_idx"]}, \n',
            f'\toutput_keys: {self.config["output_idx"]}, \n',
            '\tnodes: \n\t\t',
            f'{self.repr_nodes()} \n',
            '\tconnections: \n\t\t',
            f'{self.repr_conns()} \n)',
        ]
        return ''.join(pieces)

    def repr_nodes(self):
        """One entry per node, with activation/aggregation indices shown as their names."""
        def option_name(table, index):
            # None marks "no value"; it is shown as None
            return None if index is None else self.config[table][int(index)]

        entries = []
        for key, (bias, response, act, agg) in self.nodes.items():
            act_func = option_name('activation_option_names', act)
            agg_func = option_name('aggregation_option_names', agg)
            entries.append(f"{key}: (bias: {bias}, response: {response}, act: {act_func}, agg: {agg_func})")
        return ',\n\t\t'.join(entries)

    def repr_conns(self):
        """One entry per connection."""
        entries = [
            f"{key}: (weight: {weight}, enabled: {enabled})"
            for key, (weight, enabled) in self.cons.items()
        ]
        return ',\n\t\t'.join(entries)

    def renumber(self):
        """
        Give hidden nodes consecutive keys.

        Input and output keys are kept. Every other node, in the iteration order of
        self.nodes, receives the next key starting at len(input_idx) + len(output_idx).
        Connections are re-keyed accordingly.
        """
        input_keys = self.config['input_idx']
        output_keys = self.config['output_idx']
        next_free = len(input_keys) + len(output_keys)

        old2new = {}
        for old_key in self.nodes:
            if old_key in input_keys or old_key in output_keys:
                old2new[old_key] = old_key
            else:
                old2new[old_key] = next_free
                next_free += 1

        self.nodes = {old2new[key]: value for key, value in self.nodes.items()}
        self.cons = {
            (old2new[i_key], old2new[o_key]): value
            for (i_key, o_key), value in self.cons.items()
        }
def array2object(nodes, cons, config):
    """
    Convert a genome from arrays to dictionaries, validating it on the way.

    :param nodes: (N, 5), rows [key, bias, response, act, agg]; NaN key marks an empty row
    :param cons: (C, 4), rows [in_key, out_key, weight, enabled]; an all-NaN row is empty
    :param config: config dict; 'input_idx' and 'output_idx' are read
    :return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)]
    """
    nodes, cons = jax.device_get((nodes, cons))
    input_keys = config['input_idx']

    # ---- nodes ----
    nodes_dict = {}
    for row in nodes:
        if np.isnan(row[0]):
            continue  # empty slot
        key = int(row[0])
        attrs = row[1:]
        assert key not in nodes_dict, f"Duplicate node key: {key}!"

        if key in input_keys:
            # input nodes carry no attributes
            assert np.all(np.isnan(attrs)), f"Input node {key} must has None bias, response, act, or agg!"
            nodes_dict[key] = (None,) * 4
        else:
            assert np.all(
                ~np.isnan(attrs)), f"Normal node {key} must has non-None bias, response, act, or agg!"
            nodes_dict[key] = (row[1], row[2], row[3], row[4])

    # every input and output node must be present
    for i in input_keys:
        assert i in nodes_dict, f"Input node {i} not found in nodes_dict!"
    for o in config['output_idx']:
        assert o in nodes_dict, f"Output node {o} not found in nodes_dict!"

    # ---- connections ----
    cons_dict = {}
    for i, con in enumerate(cons):
        missing = np.isnan(con)
        if np.all(missing):
            continue  # empty slot
        if not np.any(missing):
            i_key, o_key = int(con[0]), int(con[1])
            pair = (i_key, o_key)
            assert pair not in cons_dict, f"Duplicate connection: {(i_key, o_key)}!"
            assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!"
            assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!"
            cons_dict[pair] = (con[2], con[3] == 1)
        else:
            # partially filled rows are invalid
            assert False, f"Connection {i} must has all None or all non-None!"

    return nodes_dict, cons_dict

View File

@@ -1 +0,0 @@
from .configer import Configer

View File

@@ -1,118 +0,0 @@
import os
import warnings
import configparser
import numpy as np
# Names of the configuration entries used in jit-able functions; Configer.create_jit_config
# copies exactly these keys out of the full config. Per the original author's note, changing
# the values of these entries does not cause JAX to re-compile.
jit_config_keys = [
    "input_idx",
    "output_idx",
    "compatibility_disjoint",
    "compatibility_weight",
    "conn_add_prob",
    "conn_add_trials",
    "conn_delete_prob",
    "node_add_prob",
    "node_delete_prob",
    "compatibility_threshold",
    "bias_init_mean",
    "bias_init_std",
    "bias_mutate_power",
    "bias_mutate_rate",
    "bias_replace_rate",
    "response_init_mean",
    "response_init_std",
    "response_mutate_power",
    "response_mutate_rate",
    "response_replace_rate",
    "activation_default",
    "activation_options",
    "activation_replace_rate",
    "aggregation_default",
    "aggregation_options",
    "aggregation_replace_rate",
    "weight_init_mean",
    "weight_init_std",
    "weight_mutate_power",
    "weight_mutate_rate",
    "weight_replace_rate",
    "enable_mutate_rate",
    "max_stagnation",
    "pop_size",
    "genome_elitism",
    "survival_threshold",
    "species_elitism",
    "spawn_number_move_rate"
]
class Configer:
    """Loads NEAT configuration from .ini files and derives the jit-able sub-config."""

    @classmethod
    def __load_default_config(cls):
        """Load default_config.ini located next to this module."""
        par_dir = os.path.dirname(os.path.abspath(__file__))
        default_config_path = os.path.join(par_dir, "default_config.ini")
        return cls.__load_config(default_config_path)

    @classmethod
    def __load_config(cls, config_path):
        """
        Read an .ini file into a flat dict (section names are dropped).

        Every value is evaluated as a Python expression, so strings must be quoted
        in the file.
        """
        c = configparser.ConfigParser()
        c.read(config_path)

        config = {}
        for section in c.sections():
            for key, value in c.items(section):
                # NOTE(security): eval executes arbitrary code found in the config file.
                # Only load config files from trusted sources.
                config[key] = eval(value)

        return config

    @classmethod
    def __check_redundant_config(cls, default_config, config, config_path=None):
        """
        Warn about every key of `config` that the default config does not know.

        :param config_path: optional path of the user config, only used in the warning text
        """
        where = f" in {config_path}" if config_path is not None else ""
        for key in config:
            if key not in default_config:
                # `config` is a plain dict and has no `.name`; report the path instead
                warnings.warn(f"Redundant config: {key}{where}")

    @classmethod
    def __complete_config(cls, default_config, config):
        """Fill every key missing from `config` with its default value (in place)."""
        for key, value in default_config.items():
            if key not in config:
                config[key] = value

    @staticmethod
    def __option_index(option_names, default):
        """Return the position of `default` in `option_names`, or 0 when it is not listed."""
        return option_names.index(default) if default in option_names else 0

    @classmethod
    def load_config(cls, config_path=None):
        """
        Build the full config dict.

        User values from `config_path` (if given and existing) override the defaults;
        a missing file only produces a warning. Activation/aggregation names are
        converted to indices, and 'input_idx' / 'output_idx' are derived from
        'num_inputs' / 'num_outputs'.

        :param config_path: path of the user .ini file, or None for defaults only
        :return: config dict
        """
        default_config = cls.__load_default_config()
        if config_path is None:
            config = {}
        elif not os.path.exists(config_path):
            warnings.warn(f"config file {config_path} not exist!")
            config = {}
        else:
            config = cls.__load_config(config_path)

        cls.__check_redundant_config(default_config, config, config_path)
        cls.__complete_config(default_config, config)

        cls.refactor_activation(config)
        cls.refactor_aggregation(config)

        # inputs take keys 0..num_inputs-1, outputs follow directly after them
        config['input_idx'] = np.arange(config['num_inputs'])
        config['output_idx'] = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])

        return config

    @classmethod
    def refactor_activation(cls, config):
        """
        Replace activation names by indices into 'activation_option_names' (in place).

        'activation_default' becomes the index of the configured default name
        (0 when that name is not among the options).
        """
        names = config['activation_option_names']
        config['activation_default'] = cls.__option_index(names, config.get('activation_default'))
        config['activation_options'] = np.arange(len(names))

    @classmethod
    def refactor_aggregation(cls, config):
        """
        Replace aggregation names by indices into 'aggregation_option_names' (in place).

        'aggregation_default' becomes the index of the configured default name
        (0 when that name is not among the options).
        """
        names = config['aggregation_option_names']
        config['aggregation_default'] = cls.__option_index(names, config.get('aggregation_default'))
        config['aggregation_options'] = np.arange(len(names))

    @classmethod
    def create_jit_config(cls, config):
        """Return the sub-dict of `config` restricted to jit_config_keys."""
        jit_config = {k: config[k] for k in jit_config_keys}
        return jit_config

View File

@@ -1,70 +0,0 @@
[basic]
num_inputs = 2
num_outputs = 1
maximum_nodes = 50
maximum_connections = 50
maximum_species = 10
forward_way = "pop"
batch_size = 4
random_seed = 0
[population]
fitness_threshold = 3.99999
generation_limit = 1000
fitness_criterion = "max"
pop_size = 10000
[genome]
compatibility_disjoint = 1.0
compatibility_weight = 0.5
conn_add_prob = 0.4
conn_add_trials = 1
conn_delete_prob = 0.4
node_add_prob = 0.2
node_delete_prob = 0.2
[species]
compatibility_threshold = 3.0
species_elitism = 2
max_stagnation = 15
genome_elitism = 2
survival_threshold = 0.2
min_species_size = 1
spawn_number_move_rate = 0.5
[gene-bias]
bias_init_mean = 0.0
bias_init_std = 1.0
bias_mutate_power = 0.5
bias_mutate_rate = 0.7
bias_replace_rate = 0.1
[gene-response]
response_init_mean = 1.0
response_init_std = 0.0
response_mutate_power = 0.0
response_mutate_rate = 0.0
response_replace_rate = 0.0
[gene-activation]
activation_default = "sigmoid"
activation_option_names = ["sigmoid"]
activation_replace_rate = 0.0
[gene-aggregation]
aggregation_default = "sum"
aggregation_option_names = ["sum"]
aggregation_replace_rate = 0.0
[gene-weight]
weight_init_mean = 0.0
weight_init_std = 1.0
weight_mutate_power = 0.5
weight_mutate_rate = 0.8
weight_replace_rate = 0.1
[gene-enable]
enable_mutate_rate = 0.01
[visualize]
renumber_nodes = True

View File

@@ -1,2 +0,0 @@
from .neat import NEAT
from .gym_no_distribution import Gym

View File

@@ -1,81 +0,0 @@
from typing import Callable
import gym
import jax
import jax.numpy as jnp
import numpy as np
from evox import Problem, State
class Gym(Problem):
    """
    EvoX problem that evaluates a whole population on Gym environments, one
    environment instance per individual, stepped in lock-step on the host.
    """

    def __init__(
            self,
            pop_size: int,
            policy: Callable,
            env_name: str = "CartPole-v1",
            env_options: dict = None,
            batch_policy: bool = True,
    ):
        """
        :param pop_size: number of individuals (and of environment instances)
        :param policy: callable (pop, observations) -> actions for the whole population
        :param env_name: id passed to gym.make
        :param env_options: extra keyword arguments for gym.make
        :param batch_policy: must be True; only batched policies are supported
        """
        self.pop_size = pop_size
        self.env_name = env_name
        self.policy = policy
        self.env_options = env_options or {}
        self.batch_policy = batch_policy
        assert batch_policy, "Only batch policy is supported for now"

        self.envs = [gym.make(env_name, **self.env_options) for _ in range(self.pop_size)]

        super().__init__()

    def setup(self, key):
        return State(key=key)

    def evaluate(self, state, pop):
        """
        Roll out every individual for one episode.

        :return: (negated fitnesses, state). The key in the state is not advanced.
        """
        key = state.key

        # The key is deliberately not split: the same seeds are used on every call
        # (fixed seeds for debugging).
        seeds = jax.random.randint(
            key, (self.pop_size,), 0, jnp.iinfo(jnp.int32).max
        )
        seeds = seeds.tolist()  # seed must be a python int, not numpy array

        fitnesses = self.__rollout(seeds, pop)
        print("fitnesses info: ")
        print(f"max: {np.max(fitnesses)}, min: {np.min(fitnesses)}, mean: {np.mean(fitnesses)}, std: {np.std(fitnesses)}")

        # evox uses negative fitness for minimization
        return -fitnesses, State(key=key)

    @staticmethod
    def _reset_observation(env, seed):
        """Reset `env` and return only the observation."""
        result = env.reset(seed=seed)
        # The 5-value env.step used below belongs to the Gym API in which reset returns
        # (observation, info); older versions return the observation alone.
        if isinstance(result, tuple):
            return result[0]
        return result

    def __rollout(self, seeds, pop):
        """
        Step all environments until every one is terminated or truncated.

        :return: array of shape (pop_size,) with the summed reward of every individual
        """
        observations = [self._reset_observation(env, seed) for env, seed in zip(self.envs, seeds)]
        terminates, truncates = np.zeros((2, self.pop_size), dtype=bool)
        fitnesses, rewards = np.zeros((2, self.pop_size))

        while not np.all(terminates | truncates):
            observations = np.asarray(observations)
            actions = self.policy(pop, observations)
            actions = jax.device_get(actions)

            for i, (action, terminate, truncate, env) in enumerate(zip(actions, terminates, truncates, self.envs)):
                if terminate | truncate:
                    # finished episodes are fed a zero observation and earn no reward
                    observation = np.zeros(env.observation_space.shape)
                    reward = 0
                else:
                    observation, reward, terminate, truncate, info = env.step(action)

                observations[i] = observation
                rewards[i] = reward
                terminates[i] = terminate
                truncates[i] = truncate

            fitnesses += rewards

        return fitnesses

View File

@@ -1,91 +0,0 @@
import jax.numpy as jnp
import evox
from algorithms import neat
from configs import Configer
@evox.jit_class
class NEAT(evox.Algorithm):
    """EvoX wrapper around the functional NEAT implementation (neat.initialize / neat.tell)."""

    def __init__(self, config):
        self.config = config  # global config
        self.jit_config = Configer.create_jit_config(config)

        initial_state = neat.initialize(config)
        (
            self.randkey,
            self.pop_nodes,
            self.pop_cons,
            self.species_info,
            self.idx2species,
            self.center_nodes,
            self.center_cons,
            self.generation,
            self.next_node_key,
            self.next_species_key,
        ) = initial_state

        super().__init__()

    def setup(self, key):
        """Wrap the arrays created in __init__ into an evox.State (`key` is not used)."""
        fields = dict(
            randkey=self.randkey,
            pop_nodes=self.pop_nodes,
            pop_cons=self.pop_cons,
            species_info=self.species_info,
            idx2species=self.idx2species,
            center_nodes=self.center_nodes,
            center_cons=self.center_cons,
            generation=self.generation,
            next_node_key=self.next_node_key,
            next_species_key=self.next_species_key,
            jit_config=self.jit_config,
        )
        return evox.State(**fields)

    def ask(self, state):
        """Return the population as one flat vector: all nodes followed by all connections."""
        flat_parts = [state.pop_nodes.flatten(), state.pop_cons.flatten()]
        return jnp.concatenate(flat_parts), state

    def tell(self, state, fitness):
        """Advance NEAT by one generation and return the new state."""
        # evox is a minimization framework, so we need to negate the fitness
        maximized_fitness = -fitness

        (
            randkey,
            pop_nodes,
            pop_cons,
            species_info,
            idx2species,
            center_nodes,
            center_cons,
            generation,
            next_node_key,
            next_species_key,
        ) = neat.tell(
            maximized_fitness,
            state.randkey,
            state.pop_nodes,
            state.pop_cons,
            state.species_info,
            state.idx2species,
            state.center_nodes,
            state.center_cons,
            state.generation,
            state.next_node_key,
            state.next_species_key,
            state.jit_config,
        )

        fields = dict(
            randkey=randkey,
            pop_nodes=pop_nodes,
            pop_cons=pop_cons,
            species_info=species_info,
            idx2species=idx2species,
            center_nodes=center_nodes,
            center_cons=center_cons,
            generation=generation,
            next_node_key=next_node_key,
            next_species_key=next_species_key,
            jit_config=state.jit_config,  # config is carried over unchanged
        )
        return evox.State(**fields)

View File

@@ -1,115 +0,0 @@
import pickle
import jax
from jax import numpy as jnp, jit, vmap
import numpy as np
from configs import Configer
from algorithms.neat import initialize_genomes
from algorithms.neat import tell
from algorithms.neat import unflatten_connections, topological_sort, create_forward_function
# Turn JAX's JIT compilation off for this script.
jax.config.update("jax_disable_jit", True)
# The four XOR input pairs and, row for row, their expected outputs.
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def evaluate(forward_func):
    """
    Score a population on the XOR task.

    :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
    :return: (pop_size,) fitnesses, 4 minus the summed squared error over the four XOR cases
    """
    predictions = jax.device_get(forward_func(xor_inputs))
    squared_error = np.sum((predictions - xor_outputs) ** 2, axis=(1, 2))
    return 4 - squared_error
def get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topological_sort, forward_func):
    """
    Compute the XOR fitness of every genome in the population.

    :param pop_nodes: node genes of the population
    :param pop_cons: connection genes of the population
    :param pop_unflatten_connections: callable ``(pop_nodes, pop_cons) -> unflattened connections``
    :param pop_topological_sort: callable ``(pop_nodes, unflattened connections) -> node orders``
    :param forward_func: callable ``(x, orders, pop_nodes, unflattened connections) -> outputs``
    :return: whatever ``evaluate`` returns for the bound forward function
    """
    unflattened = pop_unflatten_connections(pop_nodes, pop_cons)
    orders = pop_topological_sort(pop_nodes, unflattened)

    def population_forward(inputs):
        # bind the population data so evaluate() only has to supply the inputs
        return forward_func(inputs, orders, pop_nodes, unflattened)

    return evaluate(population_forward)
def equal(ar1, ar2):
    """
    Tell whether two arrays are identical, treating NaN as equal to NaN.

    :param ar1: first array
    :param ar2: second array
    :return: a falsy value when the shapes differ or any element differs,
        otherwise a truthy scalar boolean array
    """
    # array_equal already covers the shape check and the element-wise
    # comparison; equal_nan=True makes NaNs at the same position count as equal.
    return jnp.array_equal(ar1, ar2, equal_nan=True)
def main():
    """
    Debug loop checking that the best fitness never drops after ``tell``.

    Before every ``tell`` call its exact arguments are pickled to
    ``list.pkl`` so a failing step can be replayed. The loop has no exit
    condition of its own; it stops when the assertion fails.
    """
    # initialize
    config = Configer.load_config("xor.ini")
    jit_config = Configer.create_jit_config(config)  # config used in jit-able functions
    pop_size = config['pop_size']
    max_nodes = config['init_maximum_nodes']
    max_cons = config['init_maximum_connections']
    max_species = config['init_maximum_species']

    # fixed seeds for both jax and numpy
    randkey = jax.random.PRNGKey(6)
    np.random.seed(6)
    pop_nodes, pop_cons = initialize_genomes(max_nodes, max_cons, config)

    # only species slot 0 is filled at the start; all other slots stay NaN
    species_info = np.full((max_species, 4), np.nan)
    species_info[0] = (0, -np.inf, 0, pop_size)
    idx2species = np.zeros(pop_size, dtype=np.float32)

    # the first genome is used as the center of species 0
    center_nodes = np.full((max_species, max_nodes, 5), np.nan)
    center_cons = np.full((max_species, max_cons, 4), np.nan)
    center_nodes[0] = pop_nodes[0]
    center_cons[0] = pop_cons[0]
    generation = 0

    device_arrays = jax.device_put(
        [pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons])
    pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons = device_arrays

    pop_unflatten_connections = jit(vmap(unflatten_connections))
    pop_topological_sort = jit(vmap(topological_sort))
    forward = create_forward_function(config)

    while True:
        fitness = get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topological_sort, forward)
        last_max = np.max(fitness)

        # snapshot holds the arguments of the upcoming tell() call, in call order
        snapshot = [fitness, randkey, pop_nodes, pop_cons, species_info, idx2species,
                    center_nodes, center_cons, generation, jit_config]
        with open('list.pkl', 'wb') as f:
            # save the list with pickle's dump function
            pickle.dump(snapshot, f)

        (randkey, pop_nodes, pop_cons, species_info, idx2species,
         center_nodes, center_cons, generation) = tell(*snapshot)

        fitness = get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topological_sort, forward)
        current_max = np.max(fitness)
        print(last_max, current_max)
        assert current_max >= last_max, f"current_max: {current_max}, last_max: {last_max}"
if __name__ == '__main__':
    # Replay mode: main() (which records list.pkl) is disabled; instead the last
    # recorded snapshot is reloaded and a single tell() step is re-run on it.
    # main()
    config = Configer.load_config("xor.ini")
    pop_unflatten_connections = jit(vmap(unflatten_connections))
    pop_topological_sort = jit(vmap(topological_sort))
    forward = create_forward_function(config)
    with open('list.pkl', 'rb') as f:
        # Use pickle's load function to restore the saved list.
        # NOTE(review): pickle.load can execute arbitrary code; only replay
        # snapshots produced by this script.
        fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, i, jit_config = pickle.load(
            f)
    # best fitness before the replayed step
    print(np.max(fitness))
    randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, _ = tell(
        fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, i,
        jit_config)
    fitness = get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topological_sort, forward)
    # best fitness after the replayed step
    print(np.max(fitness))

View File

@@ -1,22 +0,0 @@
[basic]
num_inputs = 6
num_outputs = 3
maximum_nodes = 50
maximum_connections = 50
maximum_species = 10
forward_way = "single"
random_seed = 42
[population]
pop_size = 100
[gene-activation]
activation_default = "sigmoid"
activation_option_names = ['sigmoid', 'tanh', 'sin', 'gauss', 'relu', 'identity', 'inv', 'log', 'exp', 'abs', 'hat', 'square']
activation_replace_rate = 0.1
[gene-aggregation]
aggregation_default = "sum"
aggregation_option_names = ['sum', 'product', 'max', 'min', 'maxabs', 'median', 'mean']
aggregation_replace_rate = 0.1

View File

@@ -1,63 +0,0 @@
import evox
import jax
from jax import jit, vmap, numpy as jnp
from configs import Configer
from algorithms.neat import create_forward_function, topological_sort, unflatten_connections
from evox_adaptor import NEAT, Gym
if __name__ == '__main__':
    key = jax.random.PRNGKey(42)

    monitor = evox.monitors.StdSOMonitor()
    neat_config = Configer.load_config('acrobot.ini')
    origin_forward_func = create_forward_function(neat_config)

    def neat_transform(pop):
        """Split the flat population vector into the per-genome inputs of the forward function."""
        P = neat_config['pop_size']
        N = neat_config['maximum_nodes']
        C = neat_config['maximum_connections']

        # layout of pop: all node genes (P, N, 5) first, then all connection genes (P, C, 4)
        pop_nodes = pop[:P * N * 5].reshape((P, N, 5))
        pop_cons = pop[P * N * 5:].reshape((P, C, 4))
        u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons)
        pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons)
        return pop_seqs, pop_nodes, u_pop_cons

    # policy for acrobot: the action is the index of the largest network output
    def neat_forward(genome, x):
        res = origin_forward_func(x, *genome)
        out = jnp.argmax(res)  # {0, 1, 2}
        return out

    problem = Gym(
        policy=jit(vmap(neat_forward)),
        env_name="Acrobot-v1",
        env_options={"new_step_api": True},
        # taken from the config so it always matches the P used in neat_transform
        pop_size=neat_config['pop_size'],
    )

    # create a pipeline
    pipeline = evox.pipelines.StdPipeline(
        algorithm=NEAT(neat_config),
        problem=problem,
        pop_transform=jit(neat_transform),
        fitness_transform=monitor.record_fit,
    )

    # init the pipeline
    state = pipeline.init(key)

    # run the pipeline for 30 steps
    for i in range(30):
        state = pipeline.step(state)
        print(i, monitor.get_min_fitness())

    # obtain -62.0
    min_fitness = monitor.get_min_fitness()
    print(min_fitness)

View File

@@ -1,22 +0,0 @@
[basic]
num_inputs = 24
num_outputs = 4
maximum_nodes = 100
maximum_connections = 200
maximum_species = 10
forward_way = "single"
random_seed = 42
[population]
pop_size = 100
[gene-activation]
activation_default = "sigmoid"
activation_option_names = ['sigmoid', 'tanh', 'sin', 'gauss', 'relu', 'identity', 'inv', 'log', 'exp', 'abs', 'hat', 'square']
activation_replace_rate = 0.1
[gene-aggregation]
aggregation_default = "sum"
aggregation_option_names = ['sum', 'product', 'max', 'min', 'maxabs', 'median', 'mean']
aggregation_replace_rate = 0.1

View File

@@ -1,62 +0,0 @@
import evox
import jax
from jax import jit, vmap, numpy as jnp
from configs import Configer
from algorithms.neat import create_forward_function, topological_sort, unflatten_connections
from evox_adaptor import NEAT, Gym
if __name__ == '__main__':
    key = jax.random.PRNGKey(42)

    monitor = evox.monitors.StdSOMonitor()
    neat_config = Configer.load_config('bipedalwalker.ini')
    origin_forward_func = create_forward_function(neat_config)

    def neat_transform(pop):
        """Split the flat population vector into the per-genome inputs of the forward function."""
        P = neat_config['pop_size']
        N = neat_config['maximum_nodes']
        C = neat_config['maximum_connections']

        # layout of pop: all node genes (P, N, 5) first, then all connection genes (P, C, 4)
        pop_nodes = pop[:P * N * 5].reshape((P, N, 5))
        pop_cons = pop[P * N * 5:].reshape((P, C, 4))
        u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons)
        pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons)
        return pop_seqs, pop_nodes, u_pop_cons

    # policy for bipedal walker: squash every network output with tanh
    def neat_forward(genome, x):
        res = origin_forward_func(x, *genome)
        out = jnp.tanh(res)  # (-1, 1)
        return out

    problem = Gym(
        policy=jit(vmap(neat_forward)),
        env_name="BipedalWalker-v3",
        # taken from the config so it always matches the P used in neat_transform
        pop_size=neat_config['pop_size'],
    )

    # create a pipeline
    pipeline = evox.pipelines.StdPipeline(
        algorithm=NEAT(neat_config),
        problem=problem,
        pop_transform=jit(neat_transform),
        fitness_transform=monitor.record_fit,
    )

    # init the pipeline
    state = pipeline.init(key)

    # run the pipeline for 30 steps
    for i in range(30):
        state = pipeline.step(state)
        print(i, monitor.get_min_fitness())

    # obtain 98.91529684268514
    # NOTE(review): this value is identical to the one quoted in the mountain car
    # example, so it is probably a stale copy - re-measure for this environment.
    min_fitness = monitor.get_min_fitness()
    print(min_fitness)

View File

@@ -1,11 +0,0 @@
[basic]
num_inputs = 4
num_outputs = 1
maximum_nodes = 50
maximum_connections = 50
maximum_species = 10
forward_way = "single"
random_seed = 42
[population]
pop_size = 40

View File

@@ -1,63 +0,0 @@
import evox
import jax
from jax import jit, vmap, numpy as jnp
from configs import Configer
from algorithms.neat import create_forward_function, topological_sort, unflatten_connections
from evox_adaptor import NEAT, Gym
if __name__ == '__main__':
    key = jax.random.PRNGKey(42)

    monitor = evox.monitors.StdSOMonitor()
    neat_config = Configer.load_config('cartpole.ini')
    origin_forward_func = create_forward_function(neat_config)

    def neat_transform(pop):
        """Split the flat population vector into the per-genome inputs of the forward function."""
        P = neat_config['pop_size']
        N = neat_config['maximum_nodes']
        C = neat_config['maximum_connections']

        # layout of pop: all node genes (P, N, 5) first, then all connection genes (P, C, 4)
        pop_nodes = pop[:P * N * 5].reshape((P, N, 5))
        pop_cons = pop[P * N * 5:].reshape((P, C, 4))
        u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons)
        pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons)
        return pop_seqs, pop_nodes, u_pop_cons

    # special policy for cartpole: threshold the first output into action 0 or 1
    def neat_forward(genome, x):
        res = origin_forward_func(x, *genome)[0]
        out = jnp.where(res > 0.5, 1, 0)
        return out

    problem = Gym(
        policy=jit(vmap(neat_forward)),
        env_name="CartPole-v1",
        env_options={"new_step_api": True},
        # taken from the config so it always matches the P used in neat_transform
        pop_size=neat_config['pop_size'],
    )

    # create a pipeline
    pipeline = evox.pipelines.StdPipeline(
        algorithm=NEAT(neat_config),
        problem=problem,
        pop_transform=jit(neat_transform),
        fitness_transform=monitor.record_fit,
    )

    # init the pipeline
    state = pipeline.init(key)

    # run the pipeline for 10 steps
    for i in range(10):
        state = pipeline.step(state)
        print(monitor.get_min_fitness())

    # obtain 500
    min_fitness = monitor.get_min_fitness()
    print(min_fitness)

View File

@@ -1,14 +0,0 @@
import gym
# Smoke test of the gym step API: play one CartPole episode with random actions.
env = gym.make("CartPole-v1", new_step_api=True)
try:
    print(env.reset())
    obs = env.reset()
    print(obs)

    while True:
        action = env.action_space.sample()
        # with new_step_api=True, step() returns separate terminate / truncate flags
        obs, reward, terminate, truncate, info = env.step(action)
        print(obs, info)
        # logical `or` (not bitwise `|`) is the idiomatic test on two flags
        if terminate or truncate:
            break
finally:
    # release the environment's resources even if a step raises
    env.close()

View File

@@ -1,22 +0,0 @@
[basic]
num_inputs = 2
num_outputs = 1
maximum_nodes = 50
maximum_connections = 50
maximum_species = 10
forward_way = "single"
random_seed = 42
[population]
pop_size = 100
[gene-activation]
activation_default = "sigmoid"
activation_option_names = ['sigmoid', 'tanh', 'sin', 'gauss', 'relu', 'identity', 'inv', 'log', 'exp', 'abs', 'hat', 'square']
activation_replace_rate = 0.1
[gene-aggregation]
aggregation_default = "sum"
aggregation_option_names = ['sum', 'product', 'max', 'min', 'maxabs', 'median', 'mean']
aggregation_replace_rate = 0.1

View File

@@ -1,63 +0,0 @@
import evox
import jax
from jax import jit, vmap, numpy as jnp
from configs import Configer
from algorithms.neat import create_forward_function, topological_sort, unflatten_connections
from evox_adaptor import NEAT, Gym
if __name__ == '__main__':
    key = jax.random.PRNGKey(42)

    monitor = evox.monitors.StdSOMonitor()
    neat_config = Configer.load_config('mountain_car.ini')
    origin_forward_func = create_forward_function(neat_config)

    def neat_transform(pop):
        """Split the flat population vector into the per-genome inputs of the forward function."""
        P = neat_config['pop_size']
        N = neat_config['maximum_nodes']
        C = neat_config['maximum_connections']

        # layout of pop: all node genes (P, N, 5) first, then all connection genes (P, C, 4)
        pop_nodes = pop[:P * N * 5].reshape((P, N, 5))
        pop_cons = pop[P * N * 5:].reshape((P, C, 4))
        u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons)
        pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons)
        return pop_seqs, pop_nodes, u_pop_cons

    # special policy for mountain car: squash the network output with tanh
    def neat_forward(genome, x):
        res = origin_forward_func(x, *genome)
        out = jnp.tanh(res)  # (-1, 1)
        return out

    problem = Gym(
        policy=jit(vmap(neat_forward)),
        env_name="MountainCarContinuous-v0",
        env_options={"new_step_api": True},
        # taken from the config so it always matches the P used in neat_transform
        pop_size=neat_config['pop_size'],
    )

    # create a pipeline
    pipeline = evox.pipelines.StdPipeline(
        algorithm=NEAT(neat_config),
        problem=problem,
        pop_transform=jit(neat_transform),
        fitness_transform=monitor.record_fit,
    )

    # init the pipeline
    state = pipeline.init(key)

    # run the pipeline for 30 steps
    for i in range(30):
        state = pipeline.step(state)
        print(i, monitor.get_min_fitness())

    # obtain 98.91529684268514
    min_fitness = monitor.get_min_fitness()
    print(min_fitness)

View File

@@ -1,18 +0,0 @@
from functools import partial
from jax import numpy as jnp, jit
@partial(jit, static_argnames=['reverse'])
def rank_element(array, reverse=False):
    """
    Return the rank of every element of ``array``.

    Ranks are computed by applying argsort twice. With ``reverse=True`` the
    array is negated first, so ranks run from the largest value to the
    smallest instead of smallest to largest.
    """
    keys = -array if reverse else array
    order = jnp.argsort(keys)
    return jnp.argsort(order)
# Ad-hoc demo, executed whenever this module is run or imported:
# print the large-to-small rank of every element of a small array.
a = jnp.array([1, 5, 3, 5, 2, 1, 0])
print(rank_element(a, reverse=True))

14
examples/state_test.py Normal file
View File

@@ -0,0 +1,14 @@
import jax
from algorithm.state import State
@jax.jit
def func(state: State, a):
    """Return the result of ``state.update`` with field ``a`` set to the given value."""
    updated = state.update(a=a)
    return updated
# Build a State with two fields and show it.
state = State(c=1, b=2)
print(state)
# Pass the State through the jit-compiled function, which adds field `a`, and show the result.
state = func(state, 1111111)
print(state)

View File

@@ -1,5 +0,0 @@
[basic]
forward_way = "common"
[population]
fitness_threshold = 4

View File

@@ -1,31 +0,0 @@
import jax
import numpy as np
from configs import Configer
from pipeline import Pipeline
# Truth table of the 2-input XOR problem: one row per sample.
xor_inputs = np.array(
    [[0, 0], [0, 1], [1, 0], [1, 1]],
    dtype=np.float32,
)
xor_outputs = np.array(
    [[0], [1], [1], [0]],
    dtype=np.float32,
)


def evaluate(forward_func):
    """
    Score a whole population on the XOR table.

    :param forward_func: callable taking the (4, 2) XOR inputs and returning
        outputs of shape (pop_size, 4, 1)
    :return: array of shape (pop_size,) holding ``4 - summed squared error``
        for every genome
    """
    predictions = jax.device_get(forward_func(xor_inputs))
    squared_error = (predictions - xor_outputs) ** 2
    return 4 - np.sum(squared_error, axis=(1, 2))
def main():
    """Load ``xor.ini``, build a Pipeline and evolve with ``evaluate`` as the fitness function."""
    pipeline = Pipeline(Configer.load_config("xor.ini"))
    # auto_run hands back the pipeline's best genome, unpacked here as (nodes, connections)
    nodes, cons = pipeline.auto_run(evaluate)


if __name__ == '__main__':
    main()

View File

@@ -1,47 +0,0 @@
[basic]
num_inputs = 3
num_outputs = 1
maximum_nodes = 50
maximum_connections = 50
maximum_species = 10
forward_way = "common"
batch_size = 4
random_seed = 42
[population]
fitness_threshold = 8
generation_limit = 1000
fitness_criterion = "max"
pop_size = 10000
[genome]
compatibility_disjoint = 1.0
compatibility_weight = 0.5
conn_add_prob = 0.4
conn_add_trials = 1
conn_delete_prob = 0
node_add_prob = 0.2
node_delete_prob = 0
[species]
compatibility_threshold = 3
species_elitism = 1
max_stagnation = 15
genome_elitism = 2
survival_threshold = 0.2
min_species_size = 1
spawn_number_move_rate = 0.5
[gene-bias]
bias_init_mean = 0.0
bias_init_std = 1.0
bias_mutate_power = 0.5
bias_mutate_rate = 0.7
bias_replace_rate = 0.1
[gene-weight]
weight_init_mean = 0.0
weight_init_std = 1.0
weight_mutate_power = 0.5
weight_mutate_rate = 0.8
weight_replace_rate = 0.1

View File

@@ -1,31 +0,0 @@
import jax
import numpy as np
from configs import Configer
from pipeline import Pipeline
# Truth table of the 3-input XOR (parity) problem: one row per sample.
xor_inputs = np.array(
    [
        [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
        [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1],
    ],
    dtype=np.float32,
)
xor_outputs = np.array(
    [[0], [1], [1], [0], [1], [0], [0], [1]],
    dtype=np.float32,
)


def evaluate(forward_func):
    """
    Score a whole population on the 3-input XOR table.

    :param forward_func: callable taking the (8, 3) XOR inputs and returning
        outputs of shape (pop_size, 8, 1)
    :return: array of shape (pop_size,) holding ``8 - summed squared error``
        for every genome
    """
    predictions = jax.device_get(forward_func(xor_inputs))
    squared_error = (predictions - xor_outputs) ** 2
    return 8 - np.sum(squared_error, axis=(1, 2))
def main():
    """Load ``xor3d.ini``, build a Pipeline and evolve with ``evaluate`` as the fitness function."""
    pipeline = Pipeline(Configer.load_config("xor3d.ini"))
    # auto_run hands back the pipeline's best genome, unpacked here as (nodes, connections)
    nodes, cons = pipeline.auto_run(evaluate)


if __name__ == '__main__':
    main()

View File

@@ -1,158 +0,0 @@
import time
from typing import Union, Callable
import numpy as np
import jax
from jax import jit, vmap
from algorithms import neat
from configs.configer import Configer
class Pipeline:
    """
    Neat algorithm pipeline.

    Holds the evolving NEAT state as instance attributes and exposes an
    ask / tell interface plus ``auto_run``, which loops both until a fitness
    threshold or the generation limit is reached.
    """

    def __init__(self, config):
        """
        :param config: global configuration mapping; this class reads
            ``forward_way``, ``generation_limit`` and ``fitness_threshold``
        """
        self.config = config  # global config
        self.jit_config = Configer.create_jit_config(config)

        self.best_genome = None
        self.best_fitness = float('-inf')
        self.generation_timestamp = time.time()
        self.evaluate_time = 0

        # Initialize exactly once: neat_states and the unpacked attributes
        # below must describe the same initial population.
        self.neat_states = neat.initialize(config)
        (
            self.randkey,
            self.pop_nodes,
            self.pop_cons,
            self.species_info,
            self.idx2species,
            self.center_nodes,
            self.center_cons,
            self.generation,
            self.next_node_key,
            self.next_species_key,
        ) = self.neat_states

        self.forward = neat.create_forward_function(config)
        self.pop_unflatten_connections = jit(vmap(neat.unflatten_connections))
        self.pop_topological_sort = jit(vmap(neat.topological_sort))

    def _bind_forward(self, seqs, nodes, u_cons):
        """Return a forward function bound to the data of one genome."""
        # A dedicated scope per genome avoids the late-binding pitfall of
        # lambdas created inside a comprehension.
        return lambda x: self.forward(x, seqs, nodes, u_cons)

    def ask(self):
        """
        Build the forward function(s) for the current population.

        The result depends on config['forward_way']:

        single / batch:
            Returns a list with one forward function per genome.
            Each function receives (input_size, ) and returns (output_size, ),
            e.g. for RL tasks.
        pop:
            Returns a single forward function that evaluates the whole population at once.
            The function receives (pop_size, batch_size, input_size) and
            returns (pop_size, batch_size, output_size).
        common:
            Special case of pop. The population has the same inputs.
            The function receives (batch_size, input_size) and
            returns (pop_size, batch_size, output_size),
            e.g. numerical regression; Hyper-NEAT.

        :raises NotImplementedError: for any other ``forward_way`` value
        """
        u_pop_cons = self.pop_unflatten_connections(self.pop_nodes, self.pop_cons)
        pop_seqs = self.pop_topological_sort(self.pop_nodes, u_pop_cons)

        forward_way = self.config['forward_way']
        if forward_way in ('single', 'batch'):
            # carry data to cpu for fast iteration
            pop_seqs, self.pop_nodes, self.pop_cons, u_pop_cons = jax.device_get(
                (pop_seqs, self.pop_nodes, self.pop_cons, u_pop_cons))
            # forward() expects the unflattened connections, exactly as in the
            # pop / common branch below
            return [self._bind_forward(seqs, nodes, u_cons)
                    for seqs, nodes, u_cons in zip(pop_seqs, self.pop_nodes, u_pop_cons)]

        if forward_way in ('pop', 'common'):
            return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons)

        raise NotImplementedError(f"forward_way {self.config['forward_way']} is not supported")

    def tell(self, fitness):
        """
        Advance the population by one generation.

        :param fitness: fitness values of the current population, forwarded to ``neat.tell``
        """
        (
            self.randkey,
            self.pop_nodes,
            self.pop_cons,
            self.species_info,
            self.idx2species,
            self.center_nodes,
            self.center_cons,
            self.generation,
            self.next_node_key,
            self.next_species_key,
        ) = neat.tell(
            fitness,
            self.randkey,
            self.pop_nodes,
            self.pop_cons,
            self.species_info,
            self.idx2species,
            self.center_nodes,
            self.center_cons,
            self.generation,
            self.next_node_key,
            self.next_species_key,
            self.jit_config
        )

    def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
        """
        Run ask / evaluate / tell until a stop condition is met.

        :param fitness_func: callable receiving the result of ``ask`` and
            returning the fitness values of the population
        :param analysis: ``"default"`` to print statistics via
            ``default_analysis``, ``None`` to skip analysis, or a callable
            receiving the fitness values
        :return: ``self.best_genome``; it is only updated by
            ``default_analysis``, so it stays ``None`` with any other
            ``analysis`` setting
        """
        for _ in range(self.config['generation_limit']):
            forward_func = self.ask()

            # only the user's fitness function is timed
            tic = time.time()
            fitnesses = fitness_func(forward_func)
            self.evaluate_time += time.time() - tic

            # assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"

            if analysis is not None:
                if analysis == "default":
                    self.default_analysis(fitnesses)
                else:
                    assert callable(analysis), \
                        f"analysis must be 'default', None or a callable, got {analysis!r}"
                    analysis(fitnesses)

            if max(fitnesses) >= self.config['fitness_threshold']:
                print("Fitness limit reached!")
                return self.best_genome

            self.tell(fitnesses)

        print("Generation limit reached!")
        return self.best_genome

    def default_analysis(self, fitnesses):
        """
        Print per-generation statistics and remember the best genome seen so far.

        :param fitnesses: fitness values of the current population
        """
        max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)

        # wall-clock time elapsed since the previous call
        new_timestamp = time.time()
        cost_time = new_timestamp - self.generation_timestamp
        self.generation_timestamp = new_timestamp

        max_idx = np.argmax(fitnesses)
        if fitnesses[max_idx] > self.best_fitness:
            self.best_fitness = fitnesses[max_idx]
            self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])

        # column 3 of species_info is read as the member count; entries that
        # are not positive are left out of the report
        member_count = jax.device_get(self.species_info[:, 3])
        species_sizes = [int(i) for i in member_count if i > 0]

        print(f"Generation: {self.generation}",
              f"species: {len(species_sizes)}, {species_sizes}",
              f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}")