diff --git a/README.md b/README.md deleted file mode 100644 index 4389438..0000000 --- a/README.md +++ /dev/null @@ -1,12 +0,0 @@ -# NEATAX: Tensorized NEAT Implementation for Parallel Hardware Accelaration - -NEATAX is a powerful tool that utilizes JAX to implement the NEAT (NeuroEvolution of Augmenting Topologies) algorithm. It provides support for parallel execution of tasks such as forward network computation, mutation, and crossover at the population level. - -## Performance - -One of the standout features of NEATAX is its speed. Compared to traditional CPU implementations, NEATAX significantly improves the efficiency of the NEAT algorithm. It has been observed to boost the running speed of the NEAT algorithm by more than 10 times, offering considerable advantage in larger-scale and time-sensitive applications. - -## Installization -by git clone -need JAX environment - diff --git a/algorithms/__init__.py b/algorithm/__init__.py similarity index 100% rename from algorithms/__init__.py rename to algorithm/__init__.py diff --git a/examples/evox_/__init__.py b/algorithm/hyperneat/__init__.py similarity index 100% rename from examples/evox_/__init__.py rename to algorithm/hyperneat/__init__.py diff --git a/algorithm/neat/__init__.py b/algorithm/neat/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/algorithm/state.py b/algorithm/state.py new file mode 100644 index 0000000..d20774c --- /dev/null +++ b/algorithm/state.py @@ -0,0 +1,29 @@ +from jax.tree_util import register_pytree_node_class, tree_map + + +@register_pytree_node_class +class State: + + def __init__(self, **kwargs): + self.__dict__['state_dict'] = kwargs + + def update(self, **kwargs): + return State(**{**self.state_dict, **kwargs}) + + def __getattr__(self, name): + return self.state_dict[name] + + def __setattr__(self, name, value): + raise AttributeError("State is immutable") + + def __repr__(self): + return f"State ({self.state_dict})" + + def tree_flatten(self): + children = list(self.state_dict.values()) + aux_data = list(self.state_dict.keys()) + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls(**dict(zip(aux_data, children))) diff --git a/algorithms/neat/__init__.py b/algorithms/neat/__init__.py deleted file mode 100644 index 37f1924..0000000 --- a/algorithms/neat/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -contains operations on a single genome. e.g. forward, mutate, crossover, etc. -""" -from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes -from .population import update_species, create_next_generation, speciate, tell, initialize - -from .genome.activations import act_name2func -from .genome.aggregations import agg_name2func - -from .visualize import Genome diff --git a/algorithms/neat/genome/__init__.py b/algorithms/neat/genome/__init__.py deleted file mode 100644 index b98155f..0000000 --- a/algorithms/neat/genome/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .mutate import mutate -from .distance import distance -from .crossover import crossover -from .graph import topological_sort, check_cycles -from .utils import unflatten_connections, I_INT, fetch_first, rank_elements -from .forward import create_forward_function -from .genome import initialize_genomes diff --git a/algorithms/neat/genome/activations.py b/algorithms/neat/genome/activations.py deleted file mode 100644 index 3cd828e..0000000 --- a/algorithms/neat/genome/activations.py +++ /dev/null @@ -1,106 +0,0 @@ -import jax.numpy as jnp - - -def sigmoid_act(z): - z = jnp.clip(z * 5, -60, 60) - return 1 / (1 + jnp.exp(-z)) - - -def tanh_act(z): - z = jnp.clip(z * 2.5, -60, 60) - return jnp.tanh(z) - - -def sin_act(z): - z = jnp.clip(z * 5, -60, 60) - return jnp.sin(z) - - -def gauss_act(z): - z = jnp.clip(z * 5, -3.4, 3.4) - return jnp.exp(-z ** 2) - - -def relu_act(z): - return jnp.maximum(z, 0) - - -def elu_act(z): - return jnp.where(z > 0, z, jnp.exp(z) - 1) - - -def lelu_act(z): - leaky = 0.005 - return jnp.where(z > 0, z, leaky * z) - - -def selu_act(z): - lam = 1.0507009873554804934193349852946 - alpha = 1.6732632423543772848170429916717 - return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1)) - - -def softplus_act(z): - z = jnp.clip(z * 5, -60, 60) - return 0.2 * jnp.log(1 + jnp.exp(z)) - - -def identity_act(z): - return z - - -def clamped_act(z): - return jnp.clip(z, -1, 1) - - -def inv_act(z): - z = jnp.maximum(z, 1e-7) - return 1 / z - - -def log_act(z): - z = jnp.maximum(z, 1e-7) - return jnp.log(z) - - -def exp_act(z): - z = jnp.clip(z, -60, 60) - return jnp.exp(z) - - -def abs_act(z): - return jnp.abs(z) - - -def hat_act(z): - return jnp.maximum(0, 1 - jnp.abs(z)) - - -def square_act(z): - return z ** 2 - - -def cube_act(z): - return z ** 3 - - -act_name2func = { - 'sigmoid': sigmoid_act, - 'tanh': tanh_act, - 'sin': sin_act, - 'gauss': gauss_act, - 'relu': relu_act, - 'elu': elu_act, - 'lelu': lelu_act, - 'selu': selu_act, - 'softplus': softplus_act, - 'identity': identity_act, - 'clamped': clamped_act, - 'inv': inv_act, - 'log': log_act, - 'exp': exp_act, - 'abs': abs_act, - 'hat': hat_act, - 'square': square_act, - 'cube': cube_act, -} diff --git a/algorithms/neat/genome/aggregations.py b/algorithms/neat/genome/aggregations.py deleted file mode 100644 index 81c61c9..0000000 --- a/algorithms/neat/genome/aggregations.py +++ /dev/null @@ -1,59 +0,0 @@ -import jax.numpy as jnp - - -def sum_agg(z): - z = jnp.where(jnp.isnan(z), 0, z) - return jnp.sum(z, axis=0) - - -def product_agg(z): - z = jnp.where(jnp.isnan(z), 1, z) - return jnp.prod(z, axis=0) - - -def max_agg(z): - z = jnp.where(jnp.isnan(z), -jnp.inf, z) - return jnp.max(z, axis=0) - - -def min_agg(z): - z = jnp.where(jnp.isnan(z), jnp.inf, z) - return jnp.min(z, axis=0) - - -def maxabs_agg(z): - z = jnp.where(jnp.isnan(z), 0, z) - abs_z = jnp.abs(z) - max_abs_index = jnp.argmax(abs_z) - return z[max_abs_index] - - -def median_agg(z): - non_nan_mask = ~jnp.isnan(z) - n = jnp.sum(non_nan_mask, axis=0) - - z = jnp.sort(z) # sort - - idx1, idx2 = (n - 1) // 2, n // 2 - median = (z[idx1] + z[idx2]) / 2 - - return median - - -def mean_agg(z): - non_zero_mask = ~jnp.isnan(z) - valid_values_sum = sum_agg(z) - valid_values_count = jnp.sum(non_zero_mask, axis=0) - mean_without_zeros = valid_values_sum / valid_values_count - return mean_without_zeros - - -agg_name2func = { - 'sum': sum_agg, - 'product': product_agg, - 'max': max_agg, - 'min': min_agg, - 'maxabs': maxabs_agg, - 'median': median_agg, - 'mean': mean_agg, -} diff --git a/algorithms/neat/genome/crossover.py b/algorithms/neat/genome/crossover.py deleted file mode 100644 index c27fa9b..0000000 --- a/algorithms/neat/genome/crossover.py +++ /dev/null @@ -1,80 +0,0 @@ -""" -Crossover two genomes to generate a new genome. -The calculation method is the same as the crossover operation in NEAT-python. -See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.configure_crossover -""" -from typing import Tuple - -import jax -from jax import jit, Array, numpy as jnp - - -@jit -def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) -> Tuple[Array, Array]: - """ - use genome1 and genome2 to generate a new genome - notice that genome1 should have higher fitness than genome2 (genome1 is winner!) - :param randkey: - :param nodes1: - :param cons1: - :param nodes2: - :param cons2: - :return: - """ - randkey_1, randkey_2 = jax.random.split(randkey) - - # crossover nodes - keys1, keys2 = nodes1[:, 0], nodes2[:, 0] - # make homologous genes align in nodes2 align with nodes1 - nodes2 = align_array(keys1, keys2, nodes2, 'node') - - # For not homologous genes, use the value of nodes1(winner) - # For homologous genes, use the crossover result between nodes1 and nodes2 - new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2)) - - # crossover connections - con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2] - cons2 = align_array(con_keys1, con_keys2, cons2, 'connection') - new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2)) - - return new_nodes, new_cons - - -def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array: - """ - After I review this code, I found that it is the most difficult part of the code. Please never change it! - make ar2 align with ar1. - :param seq1: - :param seq2: - :param ar2: - :param gene_type: - :return: - align means to intersect part of ar2 will be at the same position as ar1, - non-intersect part of ar2 will be set to Nan - """ - seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :] - mask = (seq1 == seq2) & (~jnp.isnan(seq1)) - - if gene_type == 'connection': - mask = jnp.all(mask, axis=2) - - intersect_mask = mask.any(axis=1) - idx = jnp.arange(0, len(seq1)) - idx_fixed = jnp.dot(mask, idx) - - refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan) - - return refactor_ar2 - - -def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array: - """ - crossover two genes - :param rand_key: - :param g1: - :param g2: - :return: - only gene with the same key will be crossover, thus don't need to consider change key - """ - r = jax.random.uniform(rand_key, shape=g1.shape) - return jnp.where(r > 0.5, g1, g2) diff --git a/algorithms/neat/genome/distance.py b/algorithms/neat/genome/distance.py deleted file mode 100644 index 4eacae4..0000000 --- a/algorithms/neat/genome/distance.py +++ /dev/null @@ -1,118 +0,0 @@ -""" -Calculate the distance between two genomes. -The calculation method is the same as the distance calculation in NEAT-python. -See https://github.com/CodeReclaimers/neat-python/blob/master/neat/genome.py -""" -from typing import Dict - -from jax import jit, vmap, Array, numpy as jnp - -from .utils import EMPTY_NODE, EMPTY_CON - - -@jit -def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, jit_config: Dict) -> Array: - """ - Calculate the distance between two genomes. - args: - nodes1: Array(N, 5) - cons1: Array(C, 4) - nodes2: Array(N, 5) - cons2: Array(C, 4) - returns: - distance: Array(, ) - """ - nd = node_distance(nodes1, nodes2, jit_config) # node distance - cd = connection_distance(cons1, cons2, jit_config) # connection distance - return nd + cd - - -@jit -def node_distance(nodes1: Array, nodes2: Array, jit_config: Dict): - """ - Calculate the distance between nodes of two genomes. - """ - # statistics nodes count of two genomes - node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0])) - node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0])) - max_cnt = jnp.maximum(node_cnt1, node_cnt2) - - # align homologous nodes - # this process is similar to np.intersect1d. - nodes = jnp.concatenate((nodes1, nodes2), axis=0) - keys = nodes[:, 0] - sorted_indices = jnp.argsort(keys, axis=0) - nodes = nodes[sorted_indices] - nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end - fr, sr = nodes[:-1], nodes[1:] # first row, second row - - # flag location of homologous nodes - intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0]) - - # calculate the count of non_homologous of two genomes - non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask) - - # calculate the distance of homologous nodes - hnd = vmap(homologous_node_distance)(fr, sr) - hnd = jnp.where(jnp.isnan(hnd), 0, hnd) - homologous_distance = jnp.sum(hnd * intersect_mask) - - val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[ - 'compatibility_weight'] - - return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division - - -@jit -def connection_distance(cons1: Array, cons2: Array, jit_config: Dict): - """ - Calculate the distance between connections of two genomes. - Similar process as node_distance. - """ - con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0])) - con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0])) - max_cnt = jnp.maximum(con_cnt1, con_cnt2) - - cons = jnp.concatenate((cons1, cons2), axis=0) - keys = cons[:, :2] - sorted_indices = jnp.lexsort(keys.T[::-1]) - cons = cons[sorted_indices] - cons = jnp.concatenate([cons, EMPTY_CON], axis=0) # add a nan row to the end - fr, sr = cons[:-1], cons[1:] # first row, second row - - # both genome has such connection - intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0]) - - non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask) - hcd = vmap(homologous_connection_distance)(fr, sr) - hcd = jnp.where(jnp.isnan(hcd), 0, hcd) - homologous_distance = jnp.sum(hcd * intersect_mask) - - val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[ - 'compatibility_weight'] - - return jnp.where(max_cnt == 0, 0, val / max_cnt) - - -@jit -def homologous_node_distance(n1: Array, n2: Array): - """ - Calculate the distance between two homologous nodes. - """ - d = 0 - d += jnp.abs(n1[1] - n2[1]) # bias - d += jnp.abs(n1[2] - n2[2]) # response - d += n1[3] != n2[3] # activation - d += n1[4] != n2[4] # aggregation - return d - - -@jit -def homologous_connection_distance(c1: Array, c2: Array): - """ - Calculate the distance between two homologous connections. - """ - d = 0 - d += jnp.abs(c1[2] - c2[2]) # weight - d += c1[3] != c2[3] # enable - return d diff --git a/algorithms/neat/genome/forward.py b/algorithms/neat/genome/forward.py deleted file mode 100644 index 2d95973..0000000 --- a/algorithms/neat/genome/forward.py +++ /dev/null @@ -1,108 +0,0 @@ -import jax -from jax import Array, numpy as jnp, jit, vmap - -from .utils import I_INT -from .activations import act_name2func -from .aggregations import agg_name2func - - -def create_forward_function(config): - """ - meta method to create forward function - """ - config['activation_funcs'] = [act_name2func[name] for name in config['activation_option_names']] - config['aggregation_funcs'] = [agg_name2func[name] for name in config['aggregation_option_names']] - - def act(idx, z): - """ - calculate activation function for each node - """ - idx = jnp.asarray(idx, dtype=jnp.int32) - # change idx from float to int - res = jax.lax.switch(idx, config['activation_funcs'], z) - return res - - def agg(idx, z): - """ - calculate activation function for inputs of node - """ - idx = jnp.asarray(idx, dtype=jnp.int32) - - def all_nan(): - return 0. - - def not_all_nan(): - return jax.lax.switch(idx, config['aggregation_funcs'], z) - - return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan) - - def forward(inputs: Array, cal_seqs: Array, nodes: Array, cons: Array) -> Array: - """ - jax forward for single input shaped (input_num, ) - nodes, connections are a single genome - - :argument inputs: (input_num, ) - :argument cal_seqs: (N, ) - :argument nodes: (N, 5) - :argument connections: (2, N, N) - - :return (output_num, ) - """ - - input_idx = config['input_idx'] - output_idx = config['output_idx'] - - N = nodes.shape[0] - ini_vals = jnp.full((N,), jnp.nan) - ini_vals = ini_vals.at[input_idx].set(inputs) - - weights = jnp.where(jnp.isnan(cons[1, :, :]), jnp.nan, cons[0, :, :]) # enabled - - def cond_fun(carry): - values, idx = carry - return (idx < N) & (cal_seqs[idx] != I_INT) - - def body_func(carry): - values, idx = carry - i = cal_seqs[idx] - - def hit(): - ins = values * weights[:, i] - z = agg(nodes[i, 4], ins) # z = agg(ins) - z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias - z = act(nodes[i, 3], z) # z = act(z) - - new_values = values.at[i].set(z) - return new_values - - def miss(): - return values - - # the val of input nodes is obtained by the task, not by calculation - values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit) - return values, idx + 1 - - vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0)) - - return vals[output_idx] - - # (batch_size, inputs_nums) -> (batch_size, outputs_nums) - batch_forward = vmap(forward, in_axes=(0, None, None, None)) - - # (pop_size, batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums) - pop_batch_forward = vmap(batch_forward, in_axes=(0, 0, 0, 0)) - - # (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums) - common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0)) - - if config['forward_way'] == 'single': - return jit(forward) - - if config['forward_way'] == 'batch': - return jit(batch_forward) - - elif config['forward_way'] == 'pop': - return jit(pop_batch_forward) - - elif config['forward_way'] == 'common': - return jit(common_forward) diff --git a/algorithms/neat/genome/genome.py b/algorithms/neat/genome/genome.py deleted file mode 100644 index 67aac68..0000000 --- a/algorithms/neat/genome/genome.py +++ /dev/null @@ -1,132 +0,0 @@ -""" -Vectorization of genome representation. - -Utilizes Tuple[nodes: Array(N, 5), connections: Array(C, 4)] to encode the genome, where: -nodes: [key, bias, response, act, agg] -connections: [in_key, out_key, weight, enable] -N: Maximum number of nodes in the network. -C: Maximum number of connections in the network. -""" - -from typing import Tuple, Dict - -import numpy as np -from numpy.typing import NDArray -from jax import jit, numpy as jnp - -from .utils import fetch_first - - -def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]: - """ - Initialize genomes with default values. - - Args: - N (int): Maximum number of nodes in the network. - C (int): Maximum number of connections in the network. - config (Dict): Configuration dictionary. - - Returns: - Tuple[NDArray, NDArray, NDArray, NDArray]: pop_nodes, pop_connections, input_idx, and output_idx arrays. - """ - # Reserve one row for potential mutation adding an extra node - assert config['num_inputs'] + config['num_outputs'] + 1 <= N, \ - f"Too small N: {N} for input_size: {config['num_inputs']} and output_size: {config['num_inputs']}!" - - assert config['num_inputs'] * config['num_outputs'] + 1 <= C, \ - f"Too small C: {C} for input_size: {config['num_inputs']} and output_size: {config['num_outputs']}!" - - pop_nodes = np.full((config['pop_size'], N, 5), np.nan, dtype=np.float32) - pop_cons = np.full((config['pop_size'], C, 4), np.nan, dtype=np.float32) - input_idx = config['input_idx'] - output_idx = config['output_idx'] - - pop_nodes[:, input_idx, 0] = input_idx - pop_nodes[:, output_idx, 0] = output_idx - - # pop_nodes[:, output_idx, 1] = config['bias_init_mean'] - pop_nodes[:, output_idx, 1] = np.random.normal(loc=config['bias_init_mean'], scale=config['bias_init_std'], - size=(config['pop_size'], 1)) - pop_nodes[:, output_idx, 2] = np.random.normal(loc=config['response_init_mean'], scale=config['response_init_std'], - size=(config['pop_size'], 1)) - pop_nodes[:, output_idx, 3] = np.random.choice(config['activation_options'], size=(config['pop_size'], 1)) - pop_nodes[:, output_idx, 4] = np.random.choice(config['aggregation_options'], size=(config['pop_size'], 1)) - - grid_a, grid_b = np.meshgrid(input_idx, output_idx) - grid_a, grid_b = grid_a.flatten(), grid_b.flatten() - - p = config['num_inputs'] * config['num_outputs'] - pop_cons[:, :p, 0] = grid_a - pop_cons[:, :p, 1] = grid_b - pop_cons[:, :p, 2] = np.random.normal(loc=config['weight_init_mean'], scale=config['weight_init_std'], - size=(config['pop_size'], p)) - pop_cons[:, :p, 3] = 1 - - return pop_nodes, pop_cons - - -@jit -def add_node(nodes: NDArray, cons: NDArray, new_key: int, - bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[NDArray, NDArray]: - """ - Add a new node to the genome. - The new node will place at the first NaN row. - """ - exist_keys = nodes[:, 0] - idx = fetch_first(jnp.isnan(exist_keys)) - nodes = nodes.at[idx].set(jnp.array([new_key, bias, response, act, agg])) - return nodes, cons - - -@jit -def delete_node(nodes: NDArray, cons: NDArray, node_key: int) -> Tuple[NDArray, NDArray]: - """ - Delete a node from the genome. Only delete the node, regardless of connections. - Delete the node by its key. - """ - node_keys = nodes[:, 0] - idx = fetch_first(node_keys == node_key) - return delete_node_by_idx(nodes, cons, idx) - - -@jit -def delete_node_by_idx(nodes: NDArray, cons: NDArray, idx: int) -> Tuple[NDArray, NDArray]: - """ - Delete a node from the genome. Only delete the node, regardless of connections. - Delete the node by its idx. - """ - nodes = nodes.at[idx].set(np.nan) - return nodes, cons - - -@jit -def add_connection(nodes: NDArray, cons: NDArray, i_key: int, o_key: int, - weight: float = 1.0, enabled: bool = True) -> Tuple[NDArray, NDArray]: - """ - Add a new connection to the genome. - The new connection will place at the first NaN row. - """ - con_keys = cons[:, 0] - idx = fetch_first(jnp.isnan(con_keys)) - cons = cons.at[idx].set(jnp.array([i_key, o_key, weight, enabled])) - return nodes, cons - - -@jit -def delete_connection(nodes: NDArray, cons: NDArray, i_key: int, o_key: int) -> Tuple[NDArray, NDArray]: - """ - Delete a connection from the genome. - Delete the connection by its input and output node keys. - """ - idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key)) - return delete_connection_by_idx(nodes, cons, idx) - - -@jit -def delete_connection_by_idx(nodes: NDArray, cons: NDArray, idx: int) -> Tuple[NDArray, NDArray]: - """ - Delete a connection from the genome. - Delete the connection by its idx. - """ - cons = cons.at[idx].set(np.nan) - return nodes, cons diff --git a/algorithms/neat/genome/graph.py b/algorithms/neat/genome/graph.py deleted file mode 100644 index 0dda1ee..0000000 --- a/algorithms/neat/genome/graph.py +++ /dev/null @@ -1,167 +0,0 @@ -""" -Some graph algorithm implemented in jax. -Only used in feed-forward networks. -""" - -import jax -from jax import jit, Array, numpy as jnp - -from algorithms.neat.genome.utils import fetch_first, I_INT - - -@jit -def topological_sort(nodes: Array, connections: Array) -> Array: - """ - a jit-able version of topological_sort! that's crazy! - :param nodes: nodes array - :param connections: connections array - :return: topological sorted sequence - - Example: - nodes = jnp.array([ - [0], - [1], - [2], - [3] - ]) - connections = jnp.array([ - [ - [0, 0, 1, 0], - [0, 0, 1, 1], - [0, 0, 0, 1], - [0, 0, 0, 0] - ], - [ - [0, 0, 1, 0], - [0, 0, 1, 1], - [0, 0, 0, 1], - [0, 0, 0, 0] - ] - ]) - - topological_sort(nodes, connections) -> [0, 1, 2, 3] - """ - connections_enable = connections[1, :, :] == 1 # forward function. thus use enable - in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0)) - res = jnp.full(in_degree.shape, I_INT) - - def cond_fun(carry): - res_, idx_, in_degree_ = carry - i = fetch_first(in_degree_ == 0.) - return i != I_INT - - def body_func(carry): - res_, idx_, in_degree_ = carry - i = fetch_first(in_degree_ == 0.) - - # add to res and flag it is already in it - res_ = res_.at[idx_].set(i) - in_degree_ = in_degree_.at[i].set(-1) - - # decrease in_degree of all its children - children = connections_enable[i, :] - in_degree_ = jnp.where(children, in_degree_ - 1, in_degree_) - return res_, idx_ + 1, in_degree_ - - res, _, _ = jax.lax.while_loop(cond_fun, body_func, (res, 0, in_degree)) - return res - - -@jit -def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array: - """ - Check whether a new connection (from_idx -> to_idx) will cause a cycle. - - :param nodes: JAX array - The array of nodes. - :param connections: JAX array - The array of connections. - :param from_idx: int - The index of the starting node. - :param to_idx: int - The index of the ending node. - :return: JAX array - An array indicating if there is a cycle caused by the new connection. - - Example: - nodes = jnp.array([ - [0], - [1], - [2], - [3] - ]) - connections = jnp.array([ - [ - [0, 0, 1, 0], - [0, 0, 1, 1], - [0, 0, 0, 1], - [0, 0, 0, 0] - ], - [ - [0, 0, 1, 0], - [0, 0, 1, 1], - [0, 0, 0, 1], - [0, 0, 0, 0] - ] - ]) - - check_cycles(nodes, connections, 3, 2) -> True - check_cycles(nodes, connections, 2, 3) -> False - check_cycles(nodes, connections, 0, 3) -> False - check_cycles(nodes, connections, 1, 0) -> False - """ - - connections_enable = ~jnp.isnan(connections[0, :, :]) - connections_enable = connections_enable.at[from_idx, to_idx].set(True) - - visited = jnp.full(nodes.shape[0], False) - new_visited = visited.at[to_idx].set(True) - - def cond_func(carry): - visited_, new_visited_ = carry - end_cond1 = jnp.all(visited_ == new_visited_) # no new nodes been visited - end_cond2 = new_visited_[from_idx] # the starting node has been visited - return jnp.logical_not(end_cond1 | end_cond2) - - def body_func(carry): - _, visited_ = carry - new_visited_ = jnp.dot(visited_, connections_enable) - new_visited_ = jnp.logical_or(visited_, new_visited_) - return visited_, new_visited_ - - _, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited)) - return visited[from_idx] - - -if __name__ == '__main__': - nodes = jnp.array([ - [0], - [1], - [2], - [3], - [jnp.nan] - ]) - connections = jnp.array([ - [ - [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan], - [jnp.nan, jnp.nan, 1, 1, jnp.nan], - [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan], - [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan], - [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan] - ], - [ - [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan], - [jnp.nan, jnp.nan, 1, 1, jnp.nan], - [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan], - [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan], - [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan] - ] - ] - ) - - print(topological_sort(nodes, connections)) - - print(check_cycles(nodes, connections, 3, 2)) - print(check_cycles(nodes, connections, 2, 3)) - print(check_cycles(nodes, connections, 0, 3)) - print(check_cycles(nodes, connections, 1, 0)) diff --git a/algorithms/neat/genome/mutate.py b/algorithms/neat/genome/mutate.py deleted file mode 100644 index 3fe8a70..0000000 --- a/algorithms/neat/genome/mutate.py +++ /dev/null @@ -1,349 +0,0 @@ -""" -Mutate a genome. -The calculation method is the same as the mutation operation in NEAT-python. -See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate -""" -from typing import Tuple, Dict - -import jax -from jax import numpy as jnp, jit, Array - -from .utils import fetch_random, fetch_first, I_INT, unflatten_connections -from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection -from .graph import check_cycles - - -@jit -def mutate(rand_key: Array, nodes: Array, connections: Array, new_node_key: int, jit_config: Dict): - """ - :param rand_key: - :param nodes: (N, 5) - :param connections: (2, N, N) - :param new_node_key: - :param jit_config: - :return: - """ - r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5) - - # structural mutations - # mutate add node - r = rand(r1) - aux_nodes, aux_connections = mutate_add_node(r1, nodes, connections, new_node_key, jit_config) - nodes = jnp.where(r < jit_config['node_add_prob'], aux_nodes, nodes) - connections = jnp.where(r < jit_config['node_add_prob'], aux_connections, connections) - - # mutate add connection - r = rand(r2) - aux_nodes, aux_connections = mutate_add_connection(r3, nodes, connections, jit_config) - nodes = jnp.where(r < jit_config['conn_add_prob'], aux_nodes, nodes) - connections = jnp.where(r < jit_config['conn_add_prob'], aux_connections, connections) - - # mutate delete node - r = rand(r3) - aux_nodes, aux_connections = mutate_delete_node(r2, nodes, connections, jit_config) - nodes = jnp.where(r < jit_config['node_delete_prob'], aux_nodes, nodes) - connections = jnp.where(r < jit_config['node_delete_prob'], aux_connections, connections) - - # mutate delete connection - r = rand(r4) - aux_nodes, aux_connections = mutate_delete_connection(r4, nodes, connections) - nodes = jnp.where(r < jit_config['conn_delete_prob'], aux_nodes, nodes) - connections = jnp.where(r < jit_config['conn_delete_prob'], aux_connections, connections) - - # value mutations - nodes, connections = mutate_values(rand_key, nodes, connections, jit_config) - - return nodes, connections - - -def mutate_values(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]: - """ - Mutate values of nodes and connections. - - Args: - rand_key: A random key for generating random values. - nodes: A 2D array representing nodes. - cons: A 3D array representing connections. - jit_config: A dict containing configuration for jit-able functions. - - Returns: - A tuple containing mutated nodes and connections. - """ - - k1, k2, k3, k4, k5, rand_key = jax.random.split(rand_key, num=6) - - # bias - bias_new = mutate_float_values(k1, nodes[:, 1], jit_config['bias_init_mean'], jit_config['bias_init_std'], - jit_config['bias_mutate_power'], jit_config['bias_mutate_rate'], - jit_config['bias_replace_rate']) - - # response - response_new = mutate_float_values(k2, nodes[:, 2], jit_config['response_init_mean'], - jit_config['response_init_std'], jit_config['response_mutate_power'], - jit_config['response_mutate_rate'], jit_config['response_replace_rate']) - - # weight - weight_new = mutate_float_values(k3, cons[:, 2], jit_config['weight_init_mean'], jit_config['weight_init_std'], - jit_config['weight_mutate_power'], jit_config['weight_mutate_rate'], - jit_config['weight_replace_rate']) - - # activation - act_new = mutate_int_values(k4, nodes[:, 3], jit_config['activation_options'], - jit_config['activation_replace_rate']) - - # aggregation - agg_new = mutate_int_values(k5, nodes[:, 4], jit_config['aggregation_options'], - jit_config['aggregation_replace_rate']) - - # enabled - r = jax.random.uniform(rand_key, cons[:, 3].shape) - enabled_new = jnp.where(r < jit_config['enable_mutate_rate'], 1 - cons[:, 3], cons[:, 3]) - - # merge - nodes = jnp.column_stack([nodes[:, 0], bias_new, response_new, act_new, agg_new]) - cons = jnp.column_stack([cons[:, 0], cons[:, 1], weight_new, enabled_new]) - - return nodes, cons - - -def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float, - mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array: - """ - Mutate float values of a given array. - - Args: - rand_key: A random key for generating random values. - old_vals: A 1D array of float values to be mutated. - mean: Mean of the values. - std: Standard deviation of the values. - mutate_strength: Strength of the mutation. - mutate_rate: Rate of the mutation. - replace_rate: Rate of the replacement. - - Returns: - A mutated 1D array of float values. - """ - k1, k2, k3, rand_key = jax.random.split(rand_key, num=4) - noise = jax.random.normal(k1, old_vals.shape) * mutate_strength - replace = jax.random.normal(k2, old_vals.shape) * std + mean - - r = jax.random.uniform(k3, old_vals.shape) - - # default - new_vals = old_vals - - # r in [0, mutate_rate), mutate - new_vals = jnp.where(r < mutate_rate, new_vals + noise, new_vals) - - # r in [mutate_rate, mutate_rate + replace_rate), replace - new_vals = jnp.where( - (mutate_rate < r) & (r < mutate_rate + replace_rate), - replace + new_vals * 0.0, # in case of nan replace to values - new_vals - ) - - new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan) - return new_vals - - -def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array: - """ - Mutate integer values (act, agg) of a given array. - - Args: - rand_key: A random key for generating random values. - old_vals: A 1D array of integer values to be mutated. - val_list: List of the integer values. - replace_rate: Rate of the replacement. - - Returns: - A mutated 1D array of integer values. - """ - k1, k2, rand_key = jax.random.split(rand_key, num=3) - replace_val = jax.random.choice(k1, val_list, old_vals.shape) - r = jax.random.uniform(k2, old_vals.shape) - new_vals = jnp.where(r < replace_rate, replace_val + old_vals * 0.0, old_vals) # in case of nan replace to values - - return new_vals - - -def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: int, - jit_config: Dict) -> Tuple[Array, Array]: - """ - Randomly add a new node from splitting a connection. - :param rand_key: - :param new_node_key: - :param nodes: - :param cons: - :param jit_config: - :return: - """ - # randomly choose a connection - i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons) - - def nothing(): # there is no connection to split - return nodes, cons - - def successful_add_node(): - # disable the connection - new_nodes, new_cons = nodes, cons - - # set enable to false - new_cons = new_cons.at[idx, 3].set(False) - - # add a new node - new_nodes, new_cons = add_node(new_nodes, new_cons, new_node_key, bias=0, response=1, - act=jit_config['activation_default'], agg=jit_config['aggregation_default']) - - # add two new connections - w = new_cons[idx, 2] - new_nodes, new_cons = add_connection(new_nodes, new_cons, i_key, new_node_key, weight=1, enabled=True) - new_nodes, new_cons = add_connection(new_nodes, new_cons, new_node_key, o_key, weight=w, enabled=True) - return new_nodes, new_cons - - # if from_idx == I_INT, that means no connection exist, do nothing - nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_add_node) - - return nodes, cons - - -# TODO: Do we really need to delete a node? -def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]: - """ - Randomly delete a node. Input and output nodes are not allowed to be deleted. - :param rand_key: - :param nodes: - :param cons: - :param jit_config: - :return: - """ - # randomly choose a node - key, idx = choice_node_key(rand_key, nodes, jit_config['input_idx'], jit_config['output_idx'], - allow_input_keys=False, allow_output_keys=False) - - def nothing(): - return nodes, cons - - def successful_delete_node(): - # delete the node - aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, idx) - - # delete all connections - aux_cons = jnp.where(((aux_cons[:, 0] == key) | (aux_cons[:, 1] == key))[:, None], - jnp.nan, aux_cons) - - return aux_nodes, aux_cons - - nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_delete_node) - - return nodes, cons - - -def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]: - """ - Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks, - cycles are not allowed. - :param rand_key: - :param nodes: - :param cons: - :param jit_config: - :return: - """ - # randomly choose two nodes - k1, k2 = jax.random.split(rand_key, num=2) - i_key, from_idx = choice_node_key(k1, nodes, jit_config['input_idx'], jit_config['output_idx'], - allow_input_keys=True, allow_output_keys=True) - o_key, to_idx = choice_node_key(k2, nodes, jit_config['input_idx'], jit_config['output_idx'], - allow_input_keys=False, allow_output_keys=True) - - con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key)) - - def successful(): - new_nodes, new_cons = add_connection(nodes, cons, i_key, o_key, weight=1, enabled=True) - return new_nodes, new_cons - - def already_exist(): - new_cons = cons.at[con_idx, 3].set(True) - return nodes, new_cons - - def cycle(): - return nodes, cons - - is_already_exist = con_idx != I_INT - - u_cons = unflatten_connections(nodes, cons) - is_cycle = check_cycles(nodes, u_cons, from_idx, to_idx) - - choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2)) - nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful]) - return nodes, cons - - -def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array): - """ - Randomly delete a connection. - :param rand_key: - :param nodes: - :param cons: - :return: - """ - # randomly choose a connection - i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons) - - def nothing(): - return nodes, cons - - def successfully_delete_connection(): - return delete_connection_by_idx(nodes, cons, idx) - - nodes, cons = jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection) - - return nodes, cons - - -def choice_node_key(rand_key: Array, nodes: Array, - input_keys: Array, output_keys: Array, - allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]: - """ - Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node. - :param rand_key: - :param nodes: - :param input_keys: - :param output_keys: - :param allow_input_keys: - :param allow_output_keys: - :return: return its key and position(idx) - """ - - node_keys = nodes[:, 0] - mask = ~jnp.isnan(node_keys) - - if not allow_input_keys: - mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_keys)) - - if not allow_output_keys: - mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_keys)) - - idx = fetch_random(rand_key, mask) - key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan) - return key, idx - - -def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]: - """ - Randomly choose a connection key from the given connections. - :param rand_key: - :param nodes: - :param cons: - :return: i_key, o_key, idx - """ - - idx = fetch_random(rand_key, ~jnp.isnan(cons[:, 0])) - i_key = jnp.where(idx != I_INT, cons[idx, 0], jnp.nan) - o_key = jnp.where(idx != I_INT, cons[idx, 1], jnp.nan) - - return i_key, o_key, idx - - -def rand(rand_key): - return jax.random.uniform(rand_key, ()) diff --git a/algorithms/neat/genome/utils.py b/algorithms/neat/genome/utils.py deleted file mode 100644 index 673d662..0000000 --- a/algorithms/neat/genome/utils.py +++ /dev/null @@ -1,71 +0,0 @@ -from functools import partial - -import numpy as np -import jax -from jax import numpy as jnp, Array, jit, vmap - -I_INT = np.iinfo(jnp.int32).max # infinite int -EMPTY_NODE = np.full((1, 5), jnp.nan) -EMPTY_CON = np.full((1, 4), jnp.nan) - - -@jit -def unflatten_connections(nodes: Array, cons: Array): - """ - transform the (C, 4) connections to (2, N, N) - :param nodes: (N, 5) - :param cons: (C, 4) - :return: - """ - N = nodes.shape[0] - node_keys = nodes[:, 0] - i_keys, o_keys = cons[:, 0], cons[:, 1] - i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys) - o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys) - res = jnp.full((2, N, N), jnp.nan) - - # Is interesting that jax use clip when attach data in array - # however, it will do nothing set values in an array - res = res.at[0, i_idxs, o_idxs].set(cons[:, 2]) - res = res.at[1, i_idxs, o_idxs].set(cons[:, 3]) - - return res - - -def key_to_indices(key, keys): - return fetch_first(key == keys) - - -@jit -def fetch_first(mask, default=I_INT) -> Array: - """ - fetch the first True index - :param mask: array of bool - :param default: the default value if no element satisfying the condition - :return: the index of the first element satisfying the condition. if no element satisfying the condition, return default value - """ - idx = jnp.argmax(mask) - return jnp.where(mask[idx], idx, default) - - -@jit -def fetch_random(rand_key, mask, default=I_INT) -> Array: - """ - similar to fetch_first, but fetch a random True index - """ - true_cnt = jnp.sum(mask) - cumsum = jnp.cumsum(mask) - target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1) - mask = jnp.where(true_cnt == 0, False, cumsum >= target) - return fetch_first(mask, default) - - -@partial(jit, static_argnames=['reverse']) -def rank_elements(array, reverse=False): - """ - rank the element in the array. - if reverse is True, the rank is from small to large. default large to small - """ - if not reverse: - array = -array - return jnp.argsort(jnp.argsort(array)) diff --git a/algorithms/neat/population.py b/algorithms/neat/population.py deleted file mode 100644 index 61cab13..0000000 --- a/algorithms/neat/population.py +++ /dev/null @@ -1,441 +0,0 @@ -""" -Contains operations on the population: creating the next generation and population speciation. -The value tuple (P, N, C, S) is determined when the algorithm is initialized. - P: population size - N: maximum number of nodes in any genome - C: maximum number of connections in any genome - S: maximum number of species in NEAT - -These arrays are used in the algorithm: - fitness: Array[(P,), float], the fitness of each individual - randkey: Array[2, uint], the random key - pop_nodes: Array[(P, N, 5), float], nodes part of the population. [key, bias, response, act, agg] - pop_cons: Array[(P, C, 4), float], connections part of the population. [in_node, out_node, weight, enabled] - species_info: Array[(S, 4), float], the information of each species. [key, best_score, last_update, members_count] - idx2species: Array[(P,), float], map the individual to its species keys - center_nodes: Array[(S, N, 5), float], the center nodes of each species - center_cons: Array[(S, C, 4), float], the center connections of each species - generation: int, the current generation - next_node_key: float, the next of the next node - next_species_key: float, the next of the next species - jit_config: Configer, the config used in jit-able functions -""" - -# TODO: Complete python doc - -import numpy as np -import jax -from jax import jit, vmap, Array, numpy as jnp - -from .genome import initialize_genomes, distance, mutate, crossover, fetch_first, rank_elements - - -def initialize(config): - """ - initialize the states of NEAT. - """ - - P = config['pop_size'] - N = config['maximum_nodes'] - C = config['maximum_connections'] - S = config['maximum_species'] - - randkey = jax.random.PRNGKey(config['random_seed']) - np.random.seed(config['random_seed']) - pop_nodes, pop_cons = initialize_genomes(N, C, config) - species_info = np.full((S, 4), np.nan, dtype=np.float32) - species_info[0, :] = 0, -np.inf, 0, P - idx2species = np.zeros(P, dtype=np.float32) - center_nodes = np.full((S, N, 5), np.nan, dtype=np.float32) - center_cons = np.full((S, C, 4), np.nan, dtype=np.float32) - center_nodes[0, :, :] = pop_nodes[0, :, :] - center_cons[0, :, :] = pop_cons[0, :, :] - generation = np.asarray(0, dtype=np.int32) - next_node_key = np.asarray(config['num_inputs'] + config['num_outputs'], dtype=np.float32) - next_species_key = np.asarray(1, dtype=np.float32) - - return jax.device_put([ - randkey, - pop_nodes, - pop_cons, - species_info, - idx2species, - center_nodes, - center_cons, - generation, - next_node_key, - next_species_key, - ]) - -@jit -def tell(fitness, - randkey, - pop_nodes, - pop_cons, - species_info, - idx2species, - center_nodes, - center_cons, - generation, - next_node_key, - next_species_key, - jit_config): - """ - Main update function in NEAT. - """ - generation += 1 - - k1, k2, randkey = jax.random.split(randkey, 3) - - species_info, center_nodes, center_cons, winner, loser, elite_mask = \ - update_species(k1, fitness, species_info, idx2species, center_nodes, - center_cons, generation, jit_config) - - pop_nodes, pop_cons, next_node_key = create_next_generation(k2, pop_nodes, pop_cons, winner, loser, - elite_mask, next_node_key, jit_config) - - idx2species, center_nodes, center_cons, species_info, next_species_key = speciate( - pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, next_species_key, jit_config) - - return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation, next_node_key, next_species_key - - -def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config): - """ - args: - randkey: random key - fitness: Array[(pop_size,), float], the fitness of each individual - species_keys: Array[(species_size, 4), float], the information of each species - [species_key, best_score, last_update, members_count] - idx2species: Array[(pop_size,), int], map the individual to its species - center_nodes: Array[(species_size, N, 4), float], the center nodes of each species - center_cons: Array[(species_size, C, 4), float], the center connections of each species - generation: int, current generation - jit_config: Dict, the configuration of jit functions - """ - - # update the fitness of each species - species_fitness = update_species_fitness(species_info, idx2species, fitness) - - # stagnation species - species_fitness, species_info, center_nodes, center_cons = \ - stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config) - - # sort species_info by their fitness. (push nan to the end) - sort_indices = jnp.argsort(species_fitness)[::-1] - species_info = species_info[sort_indices] - center_nodes, center_cons = center_nodes[sort_indices], center_cons[sort_indices] - - # decide the number of members of each species by their fitness - spawn_number = cal_spawn_numbers(species_info, jit_config) - # jax.debug.print("spawn_number: {}", spawn_number) - # crossover info - winner, loser, elite_mask = \ - create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config) - - return species_info, center_nodes, center_cons, winner, loser, elite_mask - - -def update_species_fitness(species_info, idx2species, fitness): - """ - obtain the fitness of the species by the fitness of each individual. - use max criterion. - """ - - def aux_func(idx): - species_key = species_info[idx, 0] - s_fitness = jnp.where(idx2species == species_key, fitness, -jnp.inf) - f = jnp.max(s_fitness) - return f - - return vmap(aux_func)(jnp.arange(species_info.shape[0])) - - -def stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config): - """ - stagnation species. - those species whose fitness is not better than the best fitness of the species for a long time will be stagnation. - elitism species never stagnation - """ - - def aux_func(idx): - s_fitness = species_fitness[idx] - species_key, best_score, last_update, members_count = species_info[idx] - st = (s_fitness <= best_score) & (generation - last_update > jit_config['max_stagnation']) - last_update = jnp.where(s_fitness > best_score, generation, last_update) - best_score = jnp.where(s_fitness > best_score, s_fitness, best_score) - # stagnation condition - return st, jnp.array([species_key, best_score, last_update, members_count]) - - spe_st, species_info = vmap(aux_func)(jnp.arange(species_info.shape[0])) - - # elite species will not be stagnation - species_rank = rank_elements(species_fitness) - spe_st = jnp.where(species_rank < jit_config['species_elitism'], False, spe_st) # elitism never stagnation - - # set stagnation species to nan - species_info = jnp.where(spe_st[:, None], jnp.nan, species_info) - center_nodes = jnp.where(spe_st[:, None, None], jnp.nan, center_nodes) - center_cons = jnp.where(spe_st[:, None, None], jnp.nan, center_cons) - species_fitness = jnp.where(spe_st, -jnp.inf, species_fitness) - - return species_fitness, species_info, center_nodes, center_cons - - -def cal_spawn_numbers(species_info, jit_config): - """ - decide the number of members of each species by their fitness rank. - the species with higher fitness will have more members - Linear ranking selection - e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2] - """ - - is_species_valid = ~jnp.isnan(species_info[:, 0]) - valid_species_num = jnp.sum(is_species_valid) - denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6 - - rank_score = valid_species_num - jnp.arange(species_info.shape[0]) # obtain [3, 2, 1] - spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17] - spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0 - - target_spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']) # calculate member - # jax.debug.print("denominator: {}, spawn_number_rate: {}, target_spawn_number: {}", denominator, spawn_number_rate, target_spawn_number) - - # Avoid too much variation of numbers in a species - previous_size = species_info[:, 3].astype(jnp.int32) - spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate'] - # jax.debug.print("previous_size: {}, spawn_number: {}", previous_size, spawn_number) - spawn_number = spawn_number.astype(jnp.int32) - - # spawn_number = target_spawn_number.astype(jnp.int32) - - # must control the sum of spawn_number to be equal to pop_size - error = jit_config['pop_size'] - jnp.sum(spawn_number) - spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number - - return spawn_number - - -def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config): - species_size = species_info.shape[0] - pop_size = fitness.shape[0] - s_idx = jnp.arange(species_size) - p_idx = jnp.arange(pop_size) - - # def aux_func(key, idx): - def aux_func(key, idx): - members = idx2species == species_info[idx, 0] - members_num = jnp.sum(members) - - members_fitness = jnp.where(members, fitness, -jnp.inf) - sorted_member_indices = jnp.argsort(members_fitness)[::-1] - - elite_size = jit_config['genome_elitism'] - survive_size = jnp.floor(jit_config['survival_threshold'] * members_num).astype(jnp.int32) - - select_pro = (p_idx < survive_size) / survive_size - fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro) - - # elite - fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa) - ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma) - elite = jnp.where(p_idx < elite_size, True, False) - return fa, ma, elite - - # fas, mas, elites = jax.lax.max(aux_func, (jax.random.split(randkey, species_size), s_idx)) - fas, mas, elites = vmap(aux_func)(jax.random.split(randkey, species_size), s_idx) - - spawn_number_cum = jnp.cumsum(spawn_number) - - def aux_func(idx): - loc = jnp.argmax(idx < spawn_number_cum) - - # elite genomes are at the beginning of the species - idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx) - return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species] - - part1, part2, elite_mask = vmap(aux_func)(p_idx) - - is_part1_win = fitness[part1] >= fitness[part2] - winner = jnp.where(is_part1_win, part1, part2) - loser = jnp.where(is_part1_win, part2, part1) - - return winner, loser, elite_mask - - -def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, next_node_key, jit_config): - # prepare random keys - pop_size = pop_nodes.shape[0] - new_node_keys = jnp.arange(pop_size) + next_node_key - - k1, k2 = jax.random.split(rand_key, 2) - crossover_rand_keys = jax.random.split(k1, pop_size) - mutate_rand_keys = jax.random.split(k2, pop_size) - - # batch crossover - wpn, wpc = pop_nodes[winner], pop_cons[winner] # winner pop nodes, winner pop connections - lpn, lpc = pop_nodes[loser], pop_cons[loser] # loser pop nodes, loser pop connections - npn, npc = vmap(crossover)(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections - - # batch mutation - mutate_func = vmap(mutate, in_axes=(0, 0, 0, 0, None)) - m_npn, m_npc = mutate_func(mutate_rand_keys, npn, npc, new_node_keys, jit_config) # mutate_new_pop_nodes - - # elitism don't mutate - pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn) - pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc) - - # update next node key - all_nodes_keys = pop_nodes[:, :, 0] - max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys)) - next_node_key = max_node_key + 1 - - return pop_nodes, pop_cons, next_node_key - - -def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, next_species_key, jit_config): - """ - args: - pop_nodes: (pop_size, N, 5) - pop_cons: (pop_size, C, 4) - spe_center_nodes: (species_size, N, 5) - spe_center_cons: (species_size, C, 4) - """ - pop_size, species_size = pop_nodes.shape[0], center_nodes.shape[0] - - # prepare distance functions - o2p_distance_func = vmap(distance, in_axes=(None, None, 0, 0, None)) # one to population - - # idx to specie key - idx2specie = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species - - # the distance between genomes to its center genomes - o2c_distances = jnp.full((pop_size,), jnp.inf) - - # step 1: find new centers - def cond_func(carry): - i, i2s, cn, cc, o2c = carry - species_key = species_info[i, 0] - # jax.debug.print("{}, {}", i, species_key) - return (i < species_size) & (~jnp.isnan(species_key)) # current species is existing - - def body_func(carry): - i, i2s, cn, cc, o2c = carry - distances = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config) - - # find the closest one - closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s)) - # jax.debug.print("closest_idx: {}", closest_idx) - - i2s = i2s.at[closest_idx].set(species_info[i, 0]) - cn = cn.at[i].set(pop_nodes[closest_idx]) - cc = cc.at[i].set(pop_cons[closest_idx]) - - # the genome with closest_idx will become the new center, thus its distance to center is 0. - o2c = o2c.at[closest_idx].set(0) - - return i + 1, i2s, cn, cc, o2c - - _, idx2specie, center_nodes, center_cons, o2c_distances = \ - jax.lax.while_loop(cond_func, body_func, (0, idx2specie, center_nodes, center_cons, o2c_distances)) - - # jax.debug.print("species_info: \n{}", species_info) - # jax.debug.print("idx2specie: \n{}", idx2specie) - - # part 2: assign members to each species - def cond_func(carry): - i, i2s, cn, cc, si, o2c, nsk = carry # si is short for species_info, nsk is short for next_species_key - # jax.debug.print("i:\n{}\ni2s:\n{}\nsi:\n{}", i, i2s, si) - current_species_existed = ~jnp.isnan(si[i, 0]) - not_all_assigned = jnp.any(jnp.isnan(i2s)) - not_reach_species_upper_bounds = i < species_size - return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned) - - def body_func(carry): - i, i2s, cn, cc, si, o2c, nsk = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons - - _, i2s, scn, scc, si, o2c, nsk = jax.lax.cond( - jnp.isnan(si[i, 0]), # whether the current species is existing or not - create_new_species, # if not existing, create a new specie - update_exist_specie, # if existing, update the specie - (i, i2s, cn, cc, si, o2c, nsk) - ) - - return i + 1, i2s, scn, scc, si, o2c, nsk - - def create_new_species(carry): - i, i2s, cn, cc, si, o2c, nsk = carry - - # pick the first one who has not been assigned to any species - idx = fetch_first(jnp.isnan(i2s)) - - # assign it to the new species - # [key, best score, last update generation, members_count] - si = si.at[i].set(jnp.array([nsk, -jnp.inf, generation, 0])) - i2s = i2s.at[idx].set(nsk) - o2c = o2c.at[idx].set(0) - - # update center genomes - cn = cn.at[i].set(pop_nodes[idx]) - cc = cc.at[i].set(pop_cons[idx]) - - i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c)) - - # when a new species is created, it needs to be updated, thus do not change i - return i + 1, i2s, cn, cc, si, o2c, nsk + 1 # change to next new speciate key - - def update_exist_specie(carry): - i, i2s, cn, cc, si, o2c, nsk = carry - i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c)) - - # turn to next species - return i + 1, i2s, cn, cc, si, o2c, nsk - - def speciate_by_threshold(carry): - i, i2s, cn, cc, si, o2c = carry - - # distance between such center genome and ppo genomes - o2p_distance = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config) - close_enough_mask = o2p_distance < jit_config['compatibility_threshold'] - - # when a genome is not assigned or the distance between its current center is bigger than this center - cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c) - # jax.debug.print("{}", o2p_distance) - mask = close_enough_mask & cacheable_mask - - # update species info - i2s = jnp.where(mask, si[i, 0], i2s) - - # update distance between centers - o2c = jnp.where(mask, o2p_distance, o2c) - - return i2s, o2c - - # update idx2specie - _, idx2specie, center_nodes, center_cons, species_info, _, next_species_key = jax.lax.while_loop( - cond_func, - body_func, - (0, idx2specie, center_nodes, center_cons, species_info, o2c_distances, next_species_key) - ) - - # if there are still some pop genomes not assigned to any species, add them to the last genome - # this condition can only happen when the number of species is reached species upper bounds - idx2specie = jnp.where(jnp.isnan(idx2specie), species_info[-1, 0], idx2specie) - - # update members count - def count_members(idx): - key = species_info[idx, 0] - count = jnp.sum(idx2specie == key) - count = jnp.where(jnp.isnan(key), jnp.nan, count) - return count - - species_member_counts = vmap(count_members)(jnp.arange(species_size)) - species_info = species_info.at[:, 3].set(species_member_counts) - - return idx2specie, center_nodes, center_cons, species_info, next_species_key - - -def argmin_with_mask(arr: Array, mask: Array) -> Array: - masked_arr = jnp.where(mask, arr, jnp.inf) - min_idx = jnp.argmin(masked_arr) - return min_idx diff --git a/algorithms/neat/visualize.py b/algorithms/neat/visualize.py deleted file mode 100644 index 15e5dec..0000000 --- a/algorithms/neat/visualize.py +++ /dev/null @@ -1,112 +0,0 @@ -import jax -import numpy as np - - -class Genome: - def __init__(self, nodes, cons, config): - self.config = config - self.nodes, self.cons = array2object(nodes, cons, config) - if config['renumber_nodes']: - self.renumber() - - def __repr__(self): - return f'Genome(\n' \ - f'\tinput_keys: {self.config["input_idx"]}, \n' \ - f'\toutput_keys: {self.config["output_idx"]}, \n' \ - f'\tnodes: \n\t\t' \ - f'{self.repr_nodes()} \n' \ - f'\tconnections: \n\t\t' \ - f'{self.repr_conns()} \n)' - - def repr_nodes(self): - nodes_info = [] - for key, value in self.nodes.items(): - bias, response, act, agg = value - act_func = self.config['activation_option_names'][int(act)] if act is not None else None - agg_func = self.config['aggregation_option_names'][int(agg)] if agg is not None else None - s = f"{key}: (bias: {bias}, response: {response}, act: {act_func}, agg: {agg_func})" - nodes_info.append(s) - return ',\n\t\t'.join(nodes_info) - - def repr_conns(self): - conns_info = [] - for key, value in self.cons.items(): - weight, enabled = value - s = f"{key}: (weight: {weight}, enabled: {enabled})" - conns_info.append(s) - return ',\n\t\t'.join(conns_info) - - def renumber(self): - nodes2new_nodes = {} - new_id = len(self.config['input_idx']) + len(self.config['output_idx']) - for key in self.nodes.keys(): - if key in self.config['input_idx'] or key in self.config['output_idx']: - nodes2new_nodes[key] = key - else: - nodes2new_nodes[key] = new_id - new_id += 1 - - new_nodes, new_cons = {}, {} - for key, value in self.nodes.items(): - new_nodes[nodes2new_nodes[key]] = value - for key, value in self.cons.items(): - i_key, o_key = key - new_cons[(nodes2new_nodes[i_key], nodes2new_nodes[o_key])] = value - self.nodes = new_nodes - self.cons = new_cons - - -def array2object(nodes, cons, config): - """ - Convert a genome from array to dict. - :param nodes: (N, 5) - :param cons: (C, 4) - :return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)] - """ - nodes, cons = jax.device_get((nodes, cons)) - # update nodes_dict - nodes_dict = {} - for i, node in enumerate(nodes): - if np.isnan(node[0]): - continue - key = int(node[0]) - assert key not in nodes_dict, f"Duplicate node key: {key}!" - - if key in config['input_idx']: - assert np.all(np.isnan(node[1:])), f"Input node {key} must has None bias, response, act, or agg!" - nodes_dict[key] = (None,) * 4 - else: - assert np.all( - ~np.isnan(node[1:])), f"Normal node {key} must has non-None bias, response, act, or agg!" - bias = node[1] - response = node[2] - act = node[3] - agg = node[4] - nodes_dict[key] = (bias, response, act, agg) - - # check nodes_dict - for i in config['input_idx']: - assert i in nodes_dict, f"Input node {i} not found in nodes_dict!" - - for o in config['output_idx']: - assert o in nodes_dict, f"Output node {o} not found in nodes_dict!" - - # update connections - cons_dict = {} - for i, con in enumerate(cons): - if np.all(np.isnan(con)): - pass - elif np.all(~np.isnan(con)): - i_key = int(con[0]) - o_key = int(con[1]) - if (i_key, o_key) in cons_dict: - assert False, f"Duplicate connection: {(i_key, o_key)}!" - assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!" - assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!" - weight = con[2] - enabled = (con[3] == 1) - cons_dict[(i_key, o_key)] = (weight, enabled) - else: - assert False, f"Connection {i} must has all None or all non-None!" - - return nodes_dict, cons_dict diff --git a/configs/__init__.py b/configs/__init__.py deleted file mode 100644 index dd1d67d..0000000 --- a/configs/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .configer import Configer diff --git a/configs/configer.py b/configs/configer.py deleted file mode 100644 index 4b8946b..0000000 --- a/configs/configer.py +++ /dev/null @@ -1,118 +0,0 @@ -import os -import warnings -import configparser - -import numpy as np - -# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX. -jit_config_keys = [ - "input_idx", - "output_idx", - "compatibility_disjoint", - "compatibility_weight", - "conn_add_prob", - "conn_add_trials", - "conn_delete_prob", - "node_add_prob", - "node_delete_prob", - "compatibility_threshold", - "bias_init_mean", - "bias_init_std", - "bias_mutate_power", - "bias_mutate_rate", - "bias_replace_rate", - "response_init_mean", - "response_init_std", - "response_mutate_power", - "response_mutate_rate", - "response_replace_rate", - "activation_default", - "activation_options", - "activation_replace_rate", - "aggregation_default", - "aggregation_options", - "aggregation_replace_rate", - "weight_init_mean", - "weight_init_std", - "weight_mutate_power", - "weight_mutate_rate", - "weight_replace_rate", - "enable_mutate_rate", - "max_stagnation", - "pop_size", - "genome_elitism", - "survival_threshold", - "species_elitism", - "spawn_number_move_rate" -] - - -class Configer: - - @classmethod - def __load_default_config(cls): - par_dir = os.path.dirname(os.path.abspath(__file__)) - default_config_path = os.path.join(par_dir, "default_config.ini") - return cls.__load_config(default_config_path) - - @classmethod - def __load_config(cls, config_path): - c = configparser.ConfigParser() - c.read(config_path) - config = {} - - for section in c.sections(): - for key, value in c.items(section): - config[key] = eval(value) - - return config - - @classmethod - def __check_redundant_config(cls, default_config, config): - for key in config: - if key not in default_config: - warnings.warn(f"Redundant config: {key} in {config.name}") - - @classmethod - def __complete_config(cls, default_config, config): - for key in default_config: - if key not in config: - config[key] = default_config[key] - - @classmethod - def load_config(cls, config_path=None): - default_config = cls.__load_default_config() - if config_path is None: - config = {} - elif not os.path.exists(config_path): - warnings.warn(f"config file {config_path} not exist!") - config = {} - else: - config = cls.__load_config(config_path) - - cls.__check_redundant_config(default_config, config) - cls.__complete_config(default_config, config) - - cls.refactor_activation(config) - cls.refactor_aggregation(config) - - config['input_idx'] = np.arange(config['num_inputs']) - config['output_idx'] = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs']) - - return config - - @classmethod - def refactor_activation(cls, config): - config['activation_default'] = 0 - config['activation_options'] = np.arange(len(config['activation_option_names'])) - - @classmethod - def refactor_aggregation(cls, config): - config['aggregation_default'] = 0 - config['aggregation_options'] = np.arange(len(config['aggregation_option_names'])) - - @classmethod - def create_jit_config(cls, config): - jit_config = {k: config[k] for k in jit_config_keys} - - return jit_config diff --git a/configs/default_config.ini b/configs/default_config.ini deleted file mode 100644 index e8c3be4..0000000 --- a/configs/default_config.ini +++ /dev/null @@ -1,70 +0,0 @@ -[basic] -num_inputs = 2 -num_outputs = 1 -maximum_nodes = 50 -maximum_connections = 50 -maximum_species = 10 -forward_way = "pop" -batch_size = 4 -random_seed = 0 - -[population] -fitness_threshold = 3.99999 -generation_limit = 1000 -fitness_criterion = "max" -pop_size = 10000 - -[genome] -compatibility_disjoint = 1.0 -compatibility_weight = 0.5 -conn_add_prob = 0.4 -conn_add_trials = 1 -conn_delete_prob = 0.4 -node_add_prob = 0.2 -node_delete_prob = 0.2 - -[species] -compatibility_threshold = 3.0 -species_elitism = 2 -max_stagnation = 15 -genome_elitism = 2 -survival_threshold = 0.2 -min_species_size = 1 -spawn_number_move_rate = 0.5 - -[gene-bias] -bias_init_mean = 0.0 -bias_init_std = 1.0 -bias_mutate_power = 0.5 -bias_mutate_rate = 0.7 -bias_replace_rate = 0.1 - -[gene-response] -response_init_mean = 1.0 -response_init_std = 0.0 -response_mutate_power = 0.0 -response_mutate_rate = 0.0 -response_replace_rate = 0.0 - -[gene-activation] -activation_default = "sigmoid" -activation_option_names = ["sigmoid"] -activation_replace_rate = 0.0 - -[gene-aggregation] -aggregation_default = "sum" -aggregation_option_names = ["sum"] -aggregation_replace_rate = 0.0 - -[gene-weight] -weight_init_mean = 0.0 -weight_init_std = 1.0 -weight_mutate_power = 0.5 -weight_mutate_rate = 0.8 -weight_replace_rate = 0.1 - -[gene-enable] -enable_mutate_rate = 0.01 - -[visualize] -renumber_nodes = True \ No newline at end of file diff --git a/evox_adaptor/__init__.py b/evox_adaptor/__init__.py deleted file mode 100644 index 43d3342..0000000 --- a/evox_adaptor/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .neat import NEAT -from .gym_no_distribution import Gym diff --git a/evox_adaptor/gym_no_distribution.py b/evox_adaptor/gym_no_distribution.py deleted file mode 100644 index 4a30b0c..0000000 --- a/evox_adaptor/gym_no_distribution.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import Callable - -import gym -import jax -import jax.numpy as jnp -import numpy as np - -from evox import Problem, State - - -class Gym(Problem): - def __init__( - self, - pop_size: int, - policy: Callable, - env_name: str = "CartPole-v1", - env_options: dict = None, - batch_policy: bool = True, - ): - self.pop_size = pop_size - self.env_name = env_name - self.policy = policy - self.env_options = env_options or {} - self.batch_policy = batch_policy - assert batch_policy, "Only batch policy is supported for now" - - self.envs = [gym.make(env_name, **self.env_options) for _ in range(self.pop_size)] - - super().__init__() - - def setup(self, key): - return State(key=key) - - def evaluate(self, state, pop): - key = state.key - # key, subkey = jax.random.split(state.key) - - # generate a list of seeds for gym - # seeds = jax.random.randint( - # subkey, (self.pop_size,), 0, jnp.iinfo(jnp.int32).max - # ) - - # currently use fixed seed for debugging - seeds = jax.random.randint( - key, (self.pop_size,), 0, jnp.iinfo(jnp.int32).max - ) - - seeds = seeds.tolist() # seed must be a python int, not numpy array - - fitnesses = self.__rollout(seeds, pop) - print("fitnesses info: ") - print(f"max: {np.max(fitnesses)}, min: {np.min(fitnesses)}, mean: {np.mean(fitnesses)}, std: {np.std(fitnesses)}") - - # evox uses negative fitness for minimization - return -fitnesses, State(key=key) - - def __rollout(self, seeds, pop): - observations = [env.reset(seed=seed) for env, seed in zip(self.envs, seeds)] - terminates, truncates = np.zeros((2, self.pop_size), dtype=bool) - fitnesses, rewards = np.zeros((2, self.pop_size)) - - while not np.all(terminates | truncates): - observations = np.asarray(observations) - actions = self.policy(pop, observations) - actions = jax.device_get(actions) - - for i, (action, terminate, truncate, env) in enumerate(zip(actions, terminates, truncates, self.envs)): - if terminate | truncate: - observation = np.zeros(env.observation_space.shape) - reward = 0 - else: - observation, reward, terminate, truncate, info = env.step(action) - - observations[i] = observation - rewards[i] = reward - terminates[i] = terminate - truncates[i] = truncate - - fitnesses += rewards - - return fitnesses diff --git a/evox_adaptor/neat.py b/evox_adaptor/neat.py deleted file mode 100644 index 55a6f12..0000000 --- a/evox_adaptor/neat.py +++ /dev/null @@ -1,91 +0,0 @@ -import jax.numpy as jnp - -import evox -from algorithms import neat -from configs import Configer - - -@evox.jit_class -class NEAT(evox.Algorithm): - def __init__(self, config): - self.config = config # global config - self.jit_config = Configer.create_jit_config(config) - ( - self.randkey, - self.pop_nodes, - self.pop_cons, - self.species_info, - self.idx2species, - self.center_nodes, - self.center_cons, - self.generation, - self.next_node_key, - self.next_species_key, - ) = neat.initialize(config) - super().__init__() - - def setup(self, key): - return evox.State( - randkey=self.randkey, - pop_nodes=self.pop_nodes, - pop_cons=self.pop_cons, - species_info=self.species_info, - idx2species=self.idx2species, - center_nodes=self.center_nodes, - center_cons=self.center_cons, - generation=self.generation, - next_node_key=self.next_node_key, - next_species_key=self.next_species_key, - jit_config=self.jit_config - ) - - def ask(self, state): - flatten_pop_nodes = state.pop_nodes.flatten() - flatten_pop_cons = state.pop_cons.flatten() - pop = jnp.concatenate([flatten_pop_nodes, flatten_pop_cons]) - return pop, state - - def tell(self, state, fitness): - - # evox is a minimization framework, so we need to negate the fitness - fitness = -fitness - - ( - randkey, - pop_nodes, - pop_cons, - species_info, - idx2species, - center_nodes, - center_cons, - generation, - next_node_key, - next_species_key - ) = neat.tell( - fitness, - state.randkey, - state.pop_nodes, - state.pop_cons, - state.species_info, - state.idx2species, - state.center_nodes, - state.center_cons, - state.generation, - state.next_node_key, - state.next_species_key, - state.jit_config - ) - - return evox.State( - randkey=randkey, - pop_nodes=pop_nodes, - pop_cons=pop_cons, - species_info=species_info, - idx2species=idx2species, - center_nodes=center_nodes, - center_cons=center_cons, - generation=generation, - next_node_key=next_node_key, - next_species_key=next_species_key, - jit_config=state.jit_config - ) diff --git a/examples/debug.py b/examples/debug.py deleted file mode 100644 index 1a9f14a..0000000 --- a/examples/debug.py +++ /dev/null @@ -1,115 +0,0 @@ -import pickle - -import jax -from jax import numpy as jnp, jit, vmap - -import numpy as np - -from configs import Configer -from algorithms.neat import initialize_genomes -from algorithms.neat import tell -from algorithms.neat import unflatten_connections, topological_sort, create_forward_function - -jax.config.update("jax_disable_jit", True) - -xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) -xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) - -def evaluate(forward_func): - """ - :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) - :return: - """ - outs = forward_func(xor_inputs) - outs = jax.device_get(outs) - fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) - return fitnesses - - -def get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topological_sort, forward_func): - u_pop_cons = pop_unflatten_connections(pop_nodes, pop_cons) - pop_seqs = pop_topological_sort(pop_nodes, u_pop_cons) - func = lambda x: forward_func(x, pop_seqs, pop_nodes, u_pop_cons) - - return evaluate(func) - - -def equal(ar1, ar2): - if ar1.shape != ar2.shape: - return False - - nan_mask1 = jnp.isnan(ar1) - nan_mask2 = jnp.isnan(ar2) - - return jnp.all((ar1 == ar2) | (nan_mask1 & nan_mask2)) - -def main(): - # initialize - config = Configer.load_config("xor.ini") - jit_config = Configer.create_jit_config(config) # config used in jit-able functions - - P = config['pop_size'] - N = config['init_maximum_nodes'] - C = config['init_maximum_connections'] - S = config['init_maximum_species'] - randkey = jax.random.PRNGKey(6) - np.random.seed(6) - - pop_nodes, pop_cons = initialize_genomes(N, C, config) - species_info = np.full((S, 4), np.nan) - species_info[0, :] = 0, -np.inf, 0, P - idx2species = np.zeros(P, dtype=np.float32) - center_nodes = np.full((S, N, 5), np.nan) - center_cons = np.full((S, C, 4), np.nan) - center_nodes[0, :, :] = pop_nodes[0, :, :] - center_cons[0, :, :] = pop_cons[0, :, :] - generation = 0 - - pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons = jax.device_put( - [pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons]) - - pop_unflatten_connections = jit(vmap(unflatten_connections)) - pop_topological_sort = jit(vmap(topological_sort)) - forward = create_forward_function(config) - - - while True: - fitness = get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topological_sort, forward) - - last_max = np.max(fitness) - - info = [fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation, - jit_config] - - with open('list.pkl', 'wb') as f: - # 使用pickle模块的dump函数来保存list - pickle.dump(info, f) - - randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation = tell( - fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation, - jit_config) - - fitness = get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topological_sort, forward) - current_max = np.max(fitness) - print(last_max, current_max) - assert current_max >= last_max, f"current_max: {current_max}, last_max: {last_max}" - - -if __name__ == '__main__': - # main() - config = Configer.load_config("xor.ini") - pop_unflatten_connections = jit(vmap(unflatten_connections)) - pop_topological_sort = jit(vmap(topological_sort)) - forward = create_forward_function(config) - - with open('list.pkl', 'rb') as f: - # 使用pickle模块的dump函数来保存list - fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, i, jit_config = pickle.load( - f) - - print(np.max(fitness)) - randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, _ = tell( - fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, i, - jit_config) - fitness = get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topological_sort, forward) - print(np.max(fitness)) diff --git a/examples/evox_/acrobot.ini b/examples/evox_/acrobot.ini deleted file mode 100644 index f80e61d..0000000 --- a/examples/evox_/acrobot.ini +++ /dev/null @@ -1,22 +0,0 @@ -[basic] -num_inputs = 6 -num_outputs = 3 -maximum_nodes = 50 -maximum_connections = 50 -maximum_species = 10 -forward_way = "single" -random_seed = 42 - -[population] -pop_size = 100 - -[gene-activation] -activation_default = "sigmoid" -activation_option_names = ['sigmoid', 'tanh', 'sin', 'gauss', 'relu', 'identity', 'inv', 'log', 'exp', 'abs', 'hat', 'square'] -activation_replace_rate = 0.1 - -[gene-aggregation] -aggregation_default = "sum" -aggregation_option_names = ['sum', 'product', 'max', 'min', 'maxabs', 'median', 'mean'] -aggregation_replace_rate = 0.1 - diff --git a/examples/evox_/acrobot.py b/examples/evox_/acrobot.py deleted file mode 100644 index d0b7fa8..0000000 --- a/examples/evox_/acrobot.py +++ /dev/null @@ -1,63 +0,0 @@ -import evox -import jax -from jax import jit, vmap, numpy as jnp - -from configs import Configer -from algorithms.neat import create_forward_function, topological_sort, unflatten_connections -from evox_adaptor import NEAT, Gym - -if __name__ == '__main__': - batch_policy = True - key = jax.random.PRNGKey(42) - - monitor = evox.monitors.StdSOMonitor() - neat_config = Configer.load_config('acrobot.ini') - origin_forward_func = create_forward_function(neat_config) - - - def neat_transform(pop): - P = neat_config['pop_size'] - N = neat_config['maximum_nodes'] - C = neat_config['maximum_connections'] - - pop_nodes = pop[:P * N * 5].reshape((P, N, 5)) - pop_cons = pop[P * N * 5:].reshape((P, C, 4)) - - u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons) - pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons) - return pop_seqs, pop_nodes, u_pop_cons - - # special policy for mountain car - def neat_forward(genome, x): - res = origin_forward_func(x, *genome) - out = jnp.argmax(res) # {0, 1, 2} - return out - - - forward_func = lambda pop, x: origin_forward_func(x, *pop) - - problem = Gym( - policy=jit(vmap(neat_forward)), - env_name="Acrobot-v1", - env_options={"new_step_api": True}, - pop_size=100, - ) - - # create a pipeline - pipeline = evox.pipelines.StdPipeline( - algorithm=NEAT(neat_config), - problem=problem, - pop_transform=jit(neat_transform), - fitness_transform=monitor.record_fit, - ) - # init the pipeline - state = pipeline.init(key) - - # run the pipeline for 10 steps - for i in range(30): - state = pipeline.step(state) - print(i, monitor.get_min_fitness()) - - # obtain -62.0 - min_fitness = monitor.get_min_fitness() - print(min_fitness) diff --git a/examples/evox_/bipedalwalker.ini b/examples/evox_/bipedalwalker.ini deleted file mode 100644 index 9f271b5..0000000 --- a/examples/evox_/bipedalwalker.ini +++ /dev/null @@ -1,22 +0,0 @@ -[basic] -num_inputs = 24 -num_outputs = 4 -maximum_nodes = 100 -maximum_connections = 200 -maximum_species = 10 -forward_way = "single" -random_seed = 42 - -[population] -pop_size = 100 - -[gene-activation] -activation_default = "sigmoid" -activation_option_names = ['sigmoid', 'tanh', 'sin', 'gauss', 'relu', 'identity', 'inv', 'log', 'exp', 'abs', 'hat', 'square'] -activation_replace_rate = 0.1 - -[gene-aggregation] -aggregation_default = "sum" -aggregation_option_names = ['sum', 'product', 'max', 'min', 'maxabs', 'median', 'mean'] -aggregation_replace_rate = 0.1 - diff --git a/examples/evox_/bipedalwalker.py b/examples/evox_/bipedalwalker.py deleted file mode 100644 index 4abf1f3..0000000 --- a/examples/evox_/bipedalwalker.py +++ /dev/null @@ -1,62 +0,0 @@ -import evox -import jax -from jax import jit, vmap, numpy as jnp - -from configs import Configer -from algorithms.neat import create_forward_function, topological_sort, unflatten_connections -from evox_adaptor import NEAT, Gym - -if __name__ == '__main__': - batch_policy = True - key = jax.random.PRNGKey(42) - - monitor = evox.monitors.StdSOMonitor() - neat_config = Configer.load_config('bipedalwalker.ini') - origin_forward_func = create_forward_function(neat_config) - - - def neat_transform(pop): - P = neat_config['pop_size'] - N = neat_config['maximum_nodes'] - C = neat_config['maximum_connections'] - - pop_nodes = pop[:P * N * 5].reshape((P, N, 5)) - pop_cons = pop[P * N * 5:].reshape((P, C, 4)) - - u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons) - pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons) - return pop_seqs, pop_nodes, u_pop_cons - - # special policy for mountain car - def neat_forward(genome, x): - res = origin_forward_func(x, *genome) - out = jnp.tanh(res) # (-1, 1) - return out - - - forward_func = lambda pop, x: origin_forward_func(x, *pop) - - problem = Gym( - policy=jit(vmap(neat_forward)), - env_name="BipedalWalker-v3", - pop_size=100, - ) - - # create a pipeline - pipeline = evox.pipelines.StdPipeline( - algorithm=NEAT(neat_config), - problem=problem, - pop_transform=jit(neat_transform), - fitness_transform=monitor.record_fit, - ) - # init the pipeline - state = pipeline.init(key) - - # run the pipeline for 10 steps - for i in range(30): - state = pipeline.step(state) - print(i, monitor.get_min_fitness()) - - # obtain 98.91529684268514 - min_fitness = monitor.get_min_fitness() - print(min_fitness) diff --git a/examples/evox_/cartpole.ini b/examples/evox_/cartpole.ini deleted file mode 100644 index a5ba7e4..0000000 --- a/examples/evox_/cartpole.ini +++ /dev/null @@ -1,11 +0,0 @@ -[basic] -num_inputs = 4 -num_outputs = 1 -maximum_nodes = 50 -maximum_connections = 50 -maximum_species = 10 -forward_way = "single" -random_seed = 42 - -[population] -pop_size = 40 \ No newline at end of file diff --git a/examples/evox_/cartpole.py b/examples/evox_/cartpole.py deleted file mode 100644 index 54c73ec..0000000 --- a/examples/evox_/cartpole.py +++ /dev/null @@ -1,63 +0,0 @@ -import evox -import jax -from jax import jit, vmap, numpy as jnp - -from configs import Configer -from algorithms.neat import create_forward_function, topological_sort, unflatten_connections -from evox_adaptor import NEAT, Gym - -if __name__ == '__main__': - batch_policy = True - key = jax.random.PRNGKey(42) - - monitor = evox.monitors.StdSOMonitor() - neat_config = Configer.load_config('cartpole.ini') - origin_forward_func = create_forward_function(neat_config) - - - def neat_transform(pop): - P = neat_config['pop_size'] - N = neat_config['maximum_nodes'] - C = neat_config['maximum_connections'] - - pop_nodes = pop[:P * N * 5].reshape((P, N, 5)) - pop_cons = pop[P * N * 5:].reshape((P, C, 4)) - - u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons) - pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons) - return pop_seqs, pop_nodes, u_pop_cons - - # special policy for cartpole - def neat_forward(genome, x): - res = origin_forward_func(x, *genome)[0] - out = jnp.where(res > 0.5, 1, 0) - return out - - - forward_func = lambda pop, x: origin_forward_func(x, *pop) - - problem = Gym( - policy=jit(vmap(neat_forward)), - env_name="CartPole-v1", - env_options={"new_step_api": True}, - pop_size=40, - ) - - # create a pipeline - pipeline = evox.pipelines.StdPipeline( - algorithm=NEAT(neat_config), - problem=problem, - pop_transform=jit(neat_transform), - fitness_transform=monitor.record_fit, - ) - # init the pipeline - state = pipeline.init(key) - - # run the pipeline for 10 steps - for i in range(10): - state = pipeline.step(state) - print(monitor.get_min_fitness()) - - # obtain 500 - min_fitness = monitor.get_min_fitness() - print(min_fitness) diff --git a/examples/evox_/gym_env_test.py b/examples/evox_/gym_env_test.py deleted file mode 100644 index ef5cefa..0000000 --- a/examples/evox_/gym_env_test.py +++ /dev/null @@ -1,14 +0,0 @@ -import gym - -env = gym.make("CartPole-v1", new_step_api=True) -print(env.reset()) -obs = env.reset() - -print(obs) -while True: - action = env.action_space.sample() - obs, reward, terminate, truncate, info = env.step(action) - print(obs, info) - if terminate | truncate: - break - diff --git a/examples/evox_/mountain_car.ini b/examples/evox_/mountain_car.ini deleted file mode 100644 index 21cb7d8..0000000 --- a/examples/evox_/mountain_car.ini +++ /dev/null @@ -1,22 +0,0 @@ -[basic] -num_inputs = 2 -num_outputs = 1 -maximum_nodes = 50 -maximum_connections = 50 -maximum_species = 10 -forward_way = "single" -random_seed = 42 - -[population] -pop_size = 100 - -[gene-activation] -activation_default = "sigmoid" -activation_option_names = ['sigmoid', 'tanh', 'sin', 'gauss', 'relu', 'identity', 'inv', 'log', 'exp', 'abs', 'hat', 'square'] -activation_replace_rate = 0.1 - -[gene-aggregation] -aggregation_default = "sum" -aggregation_option_names = ['sum', 'product', 'max', 'min', 'maxabs', 'median', 'mean'] -aggregation_replace_rate = 0.1 - diff --git a/examples/evox_/mountain_car.py b/examples/evox_/mountain_car.py deleted file mode 100644 index 9d8bea1..0000000 --- a/examples/evox_/mountain_car.py +++ /dev/null @@ -1,63 +0,0 @@ -import evox -import jax -from jax import jit, vmap, numpy as jnp - -from configs import Configer -from algorithms.neat import create_forward_function, topological_sort, unflatten_connections -from evox_adaptor import NEAT, Gym - -if __name__ == '__main__': - batch_policy = True - key = jax.random.PRNGKey(42) - - monitor = evox.monitors.StdSOMonitor() - neat_config = Configer.load_config('mountain_car.ini') - origin_forward_func = create_forward_function(neat_config) - - - def neat_transform(pop): - P = neat_config['pop_size'] - N = neat_config['maximum_nodes'] - C = neat_config['maximum_connections'] - - pop_nodes = pop[:P * N * 5].reshape((P, N, 5)) - pop_cons = pop[P * N * 5:].reshape((P, C, 4)) - - u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons) - pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons) - return pop_seqs, pop_nodes, u_pop_cons - - # special policy for mountain car - def neat_forward(genome, x): - res = origin_forward_func(x, *genome) - out = jnp.tanh(res) # (-1, 1) - return out - - - forward_func = lambda pop, x: origin_forward_func(x, *pop) - - problem = Gym( - policy=jit(vmap(neat_forward)), - env_name="MountainCarContinuous-v0", - env_options={"new_step_api": True}, - pop_size=100, - ) - - # create a pipeline - pipeline = evox.pipelines.StdPipeline( - algorithm=NEAT(neat_config), - problem=problem, - pop_transform=jit(neat_transform), - fitness_transform=monitor.record_fit, - ) - # init the pipeline - state = pipeline.init(key) - - # run the pipeline for 10 steps - for i in range(30): - state = pipeline.step(state) - print(i, monitor.get_min_fitness()) - - # obtain 98.91529684268514 - min_fitness = monitor.get_min_fitness() - print(min_fitness) diff --git a/examples/jax_playground.py b/examples/jax_playground.py deleted file mode 100644 index fc22005..0000000 --- a/examples/jax_playground.py +++ /dev/null @@ -1,18 +0,0 @@ -from functools import partial - -from jax import numpy as jnp, jit - - -@partial(jit, static_argnames=['reverse']) -def rank_element(array, reverse=False): - """ - rank the element in the array. - if reverse is True, the rank is from large to small. - """ - if reverse: - array = -array - return jnp.argsort(jnp.argsort(array)) - - -a = jnp.array([1, 5, 3, 5, 2, 1, 0]) -print(rank_element(a, reverse=True)) diff --git a/examples/state_test.py b/examples/state_test.py new file mode 100644 index 0000000..0a28590 --- /dev/null +++ b/examples/state_test.py @@ -0,0 +1,14 @@ +import jax +from algorithm.state import State + +@jax.jit +def func(state: State, a): + return state.update(a=a) + + +state = State(c=1, b=2) +print(state) + +state = func(state, 1111111) + +print(state) diff --git a/examples/xor.ini b/examples/xor.ini deleted file mode 100644 index 893fff7..0000000 --- a/examples/xor.ini +++ /dev/null @@ -1,5 +0,0 @@ -[basic] -forward_way = "common" - -[population] -fitness_threshold = 4 \ No newline at end of file diff --git a/examples/xor.py b/examples/xor.py deleted file mode 100644 index 1d88081..0000000 --- a/examples/xor.py +++ /dev/null @@ -1,31 +0,0 @@ -import jax -import numpy as np - -from configs import Configer -from pipeline import Pipeline - -xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) -xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) - - -def evaluate(forward_func): - """ - :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) - :return: - """ - outs = forward_func(xor_inputs) - outs = jax.device_get(outs) - fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) - return fitnesses - - -def main(): - config = Configer.load_config("xor.ini") - pipeline = Pipeline(config) - nodes, cons = pipeline.auto_run(evaluate) - # g = Genome(nodes, cons, config) - # print(g) - - -if __name__ == '__main__': - main() diff --git a/examples/xor3d.ini b/examples/xor3d.ini deleted file mode 100644 index e883a52..0000000 --- a/examples/xor3d.ini +++ /dev/null @@ -1,47 +0,0 @@ -[basic] -num_inputs = 3 -num_outputs = 1 -maximum_nodes = 50 -maximum_connections = 50 -maximum_species = 10 -forward_way = "common" -batch_size = 4 -random_seed = 42 - -[population] -fitness_threshold = 8 -generation_limit = 1000 -fitness_criterion = "max" -pop_size = 10000 - -[genome] -compatibility_disjoint = 1.0 -compatibility_weight = 0.5 -conn_add_prob = 0.4 -conn_add_trials = 1 -conn_delete_prob = 0 -node_add_prob = 0.2 -node_delete_prob = 0 - -[species] -compatibility_threshold = 3 -species_elitism = 1 -max_stagnation = 15 -genome_elitism = 2 -survival_threshold = 0.2 -min_species_size = 1 -spawn_number_move_rate = 0.5 - -[gene-bias] -bias_init_mean = 0.0 -bias_init_std = 1.0 -bias_mutate_power = 0.5 -bias_mutate_rate = 0.7 -bias_replace_rate = 0.1 - -[gene-weight] -weight_init_mean = 0.0 -weight_init_std = 1.0 -weight_mutate_power = 0.5 -weight_mutate_rate = 0.8 -weight_replace_rate = 0.1 diff --git a/examples/xor3d.py b/examples/xor3d.py deleted file mode 100644 index 1fa7c43..0000000 --- a/examples/xor3d.py +++ /dev/null @@ -1,31 +0,0 @@ -import jax -import numpy as np - -from configs import Configer -from pipeline import Pipeline - -xor_inputs = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], dtype=np.float32) -xor_outputs = np.array([[0], [1], [1], [0], [1], [0], [0], [1]], dtype=np.float32) - - -def evaluate(forward_func): - """ - :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) - :return: - """ - outs = forward_func(xor_inputs) - outs = jax.device_get(outs) - fitnesses = 8 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) - return fitnesses - - -def main(): - config = Configer.load_config("xor3d.ini") - pipeline = Pipeline(config) - nodes, cons = pipeline.auto_run(evaluate) - # g = Genome(nodes, cons, config) - # print(g) - - -if __name__ == '__main__': - main() diff --git a/pipeline.py b/pipeline.py deleted file mode 100644 index 6d0b18a..0000000 --- a/pipeline.py +++ /dev/null @@ -1,158 +0,0 @@ -import time -from typing import Union, Callable - -import numpy as np -import jax -from jax import jit, vmap - -from algorithms import neat -from configs.configer import Configer - - -class Pipeline: - """ - Neat algorithm pipeline. - """ - - def __init__(self, config): - self.config = config # global config - self.jit_config = Configer.create_jit_config(config) - - self.best_genome = None - - self.neat_states = neat.initialize(config) - - self.best_fitness = float('-inf') - self.generation_timestamp = time.time() - - self.evaluate_time = 0 - - ( - self.randkey, - self.pop_nodes, - self.pop_cons, - self.species_info, - self.idx2species, - self.center_nodes, - self.center_cons, - self.generation, - self.next_node_key, - self.next_species_key, - ) = neat.initialize(config) - - self.forward = neat.create_forward_function(config) - self.pop_unflatten_connections = jit(vmap(neat.unflatten_connections)) - self.pop_topological_sort = jit(vmap(neat.topological_sort)) - - def ask(self): - """ - Creates a function that receives a genome and returns a forward function. - There are 3 types of config['forward_way']: {'single', 'pop', 'common'} - - single: - Create pop_size number of forward functions. - Each function receive (input_size, ) and returns (output_size, ) - e.g. RL task - - batch: - Create pop_size number of forward functions. - Each function receive (input_size, ) and returns (output_size, ) - some task need to calculate the fitness of a batch of inputs - - pop: - Create a single forward function, which use only once calculation for the population. - The function receives (pop_size, batch_size, input_size) and returns (pop_size, batch_size, output_size) - - common: - Special case of pop. The population has the same inputs. - The function receives (batch_size, input_size) and returns (pop_size, batch_size, output_size) - e.g. numerical regression; Hyper-NEAT - - """ - u_pop_cons = self.pop_unflatten_connections(self.pop_nodes, self.pop_cons) - pop_seqs = self.pop_topological_sort(self.pop_nodes, u_pop_cons) - - # only common mode is supported currently - if self.config['forward_way'] == 'single' or self.config['forward_way'] == 'batch': - # carry data to cpu for fast iteration - pop_seqs, self.pop_nodes, self.pop_cons = jax.device_get((pop_seqs, self.pop_nodes, self.pop_cons)) - funcs = [lambda x: self.forward(x, seqs, nodes, u_cons) - for seqs, nodes, u_cons in zip(pop_seqs, self.pop_nodes, self.pop_cons)] - return funcs - - elif self.config['forward_way'] == 'pop' or self.config['forward_way'] == 'common': - return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons) - - else: - raise NotImplementedError(f"forward_way {self.config['forward_way']} is not supported") - - def tell(self, fitness): - ( - self.randkey, - self.pop_nodes, - self.pop_cons, - self.species_info, - self.idx2species, - self.center_nodes, - self.center_cons, - self.generation, - self.next_node_key, - self.next_species_key, - ) = neat.tell( - fitness, - self.randkey, - self.pop_nodes, - self.pop_cons, - self.species_info, - self.idx2species, - self.center_nodes, - self.center_cons, - self.generation, - self.next_node_key, - self.next_species_key, - self.jit_config - ) - - def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): - for _ in range(self.config['generation_limit']): - forward_func = self.ask() - - tic = time.time() - fitnesses = fitness_func(forward_func) - self.evaluate_time += time.time() - tic - - # assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!" - - if analysis is not None: - if analysis == "default": - self.default_analysis(fitnesses) - else: - assert callable(analysis), f"What the fuck you passed in? A {analysis}?" - analysis(fitnesses) - - if max(fitnesses) >= self.config['fitness_threshold']: - print("Fitness limit reached!") - return self.best_genome - - self.tell(fitnesses) - print("Generation limit reached!") - return self.best_genome - - def default_analysis(self, fitnesses): - max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) - - new_timestamp = time.time() - cost_time = new_timestamp - self.generation_timestamp - self.generation_timestamp = new_timestamp - - max_idx = np.argmax(fitnesses) - if fitnesses[max_idx] > self.best_fitness: - self.best_fitness = fitnesses[max_idx] - self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx]) - - member_count = jax.device_get(self.species_info[:, 3]) - species_sizes = [int(i) for i in member_count if i > 0] - - print(f"Generation: {self.generation}", - f"species: {len(species_sizes)}, {species_sizes}", - f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}")