From 40cf0b6fbe9e00e9a93431d731966dcec35e709e Mon Sep 17 00:00:00 2001 From: wls2002 Date: Mon, 17 Jul 2023 19:59:46 +0800 Subject: [PATCH] change a lot --- algorithm/neat/NEAT.py | 10 +++- algorithm/neat/gene/base.py | 8 +-- algorithm/neat/gene/normal.py | 94 +++++++++++++++++++++++++++--- algorithm/neat/genome/__init__.py | 2 + algorithm/neat/genome/crossover.py | 68 +++++++++++++++++++++ algorithm/neat/genome/distance.py | 76 ++++++++++++++++++++++++ algorithm/neat/genome/mutate.py | 1 - examples/xor_test.py | 7 ++- 8 files changed, 248 insertions(+), 18 deletions(-) diff --git a/algorithm/neat/NEAT.py b/algorithm/neat/NEAT.py index ff5baa9..f9a9343 100644 --- a/algorithm/neat/NEAT.py +++ b/algorithm/neat/NEAT.py @@ -1,6 +1,8 @@ +import jax + from algorithm.state import State from .gene import * -from .genome import initialize_genomes +from .genome import initialize_genomes, create_mutate, create_distance, crossover class NEAT: @@ -11,6 +13,10 @@ class NEAT: else: raise NotImplementedError + self.mutate = jax.jit(create_mutate(config, self.gene_type)) + self.distance = jax.jit(create_distance(config, self.gene_type)) + self.crossover = jax.jit(crossover) + def setup(self, randkey): state = State( @@ -25,6 +31,8 @@ class NEAT: output_idx=self.config['output_idx'] ) + state = self.gene_type.setup(state, self.config) + pop_nodes, pop_conns = initialize_genomes(state, self.gene_type) next_node_key = max(*state.input_idx, *state.output_idx) + 2 state = state.update( diff --git a/algorithm/neat/gene/base.py b/algorithm/neat/gene/base.py index 40d2c10..7a710f2 100644 --- a/algorithm/neat/gene/base.py +++ b/algorithm/neat/gene/base.py @@ -26,12 +26,12 @@ class BaseGene: return attrs @staticmethod - def distance_node(state, array: Array): - return array + def distance_node(state, array1: Array, array2: Array): + return array1 @staticmethod - def distance_conn(state, array: Array): - return array + def distance_conn(state, array1: Array, array2: Array): + return array1 @staticmethod def forward(state, array: Array): diff --git a/algorithm/neat/gene/normal.py b/algorithm/neat/gene/normal.py index 799012e..7468671 100644 --- a/algorithm/neat/gene/normal.py +++ b/algorithm/neat/gene/normal.py @@ -1,3 +1,4 @@ +import jax from jax import Array, numpy as jnp from . import BaseGene @@ -9,32 +10,107 @@ class NormalGene(BaseGene): @staticmethod def setup(state, config): - return state + return state.update( + bias_init_mean=config['bias_init_mean'], + bias_init_std=config['bias_init_std'], + bias_mutate_power=config['bias_mutate_power'], + bias_mutate_rate=config['bias_mutate_rate'], + bias_replace_rate=config['bias_replace_rate'], + + response_init_mean=config['response_init_mean'], + response_init_std=config['response_init_std'], + response_mutate_power=config['response_mutate_power'], + response_mutate_rate=config['response_mutate_rate'], + response_replace_rate=config['response_replace_rate'], + + activation_default=config['activation_default'], + activation_options=config['activation_options'], + activation_replace_rate=config['activation_replace_rate'], + + aggregation_default=config['aggregation_default'], + aggregation_options=config['aggregation_options'], + aggregation_replace_rate=config['aggregation_replace_rate'], + + weight_init_mean=config['weight_init_mean'], + weight_init_std=config['weight_init_std'], + weight_mutate_power=config['weight_mutate_power'], + weight_mutate_rate=config['weight_mutate_rate'], + weight_replace_rate=config['weight_replace_rate'], + ) @staticmethod def new_node_attrs(state): - return jnp.array([0, 0, 0, 0]) + return jnp.array([state.bias_init_mean, state.response_init_mean, + state.activation_default, state.aggregation_default]) @staticmethod def new_conn_attrs(state): - return jnp.array([0]) + return jnp.array([state.weight_init_mean]) @staticmethod def mutate_node(state, attrs: Array, key): - return attrs + k1, k2, k3, k4 = jax.random.split(key, num=4) + + bias = NormalGene._mutate_float(k1, attrs[0], state.bias_init_mean, state.bias_init_std, + state.bias_mutate_power, state.bias_mutate_rate, state.bias_replace_rate) + res = NormalGene._mutate_float(k2, attrs[1], state.response_init_mean, state.response_init_std, + state.response_mutate_power, state.response_mutate_rate, + state.response_replace_rate) + act = NormalGene._mutate_int(k3, attrs[2], state.activation_options, state.activation_replace_rate) + agg = NormalGene._mutate_int(k4, attrs[3], state.aggregation_options, state.aggregation_replace_rate) + + return jnp.array([bias, res, act, agg]) @staticmethod def mutate_conn(state, attrs: Array, key): - return attrs + weight = NormalGene._mutate_float(key, attrs[0], state.weight_init_mean, state.weight_init_std, + state.weight_mutate_power, state.weight_mutate_rate, + state.weight_replace_rate) + + return jnp.array([weight]) @staticmethod - def distance_node(state, array: Array): - return array + def distance_node(state, array1: Array, array2: Array): + # bias + response + activation + aggregation + return jnp.abs(array1[1] - array2[1]) + jnp.abs(array1[2] - array2[2]) + \ + (array1[3] != array2[3]) + (array1[4] != array2[4]) @staticmethod - def distance_conn(state, array: Array): - return array + def distance_conn(state, array1: Array, array2: Array): + return (array1[2] != array2[2]) + jnp.abs(array1[3] - array2[3]) # enable + weight @staticmethod def forward(state, array: Array): return array + + @staticmethod + def _mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate): + k1, k2, k3 = jax.random.split(key, num=3) + noise = jax.random.normal(k1, ()) * mutate_power + replace = jax.random.normal(k2, ()) * init_std + init_mean + r = jax.random.uniform(k3, ()) + + val = jnp.where( + r < mutate_rate, + val + noise, + jnp.where( + (mutate_rate < r) & (r < mutate_rate + replace_rate), + replace, + val + ) + ) + + return val + + @staticmethod + def _mutate_int(key, val, options, replace_rate): + k1, k2 = jax.random.split(key, num=2) + r = jax.random.uniform(k1, ()) + + val = jnp.where( + r < replace_rate, + jax.random.choice(k2, options), + val + ) + + return val diff --git a/algorithm/neat/genome/__init__.py b/algorithm/neat/genome/__init__.py index 1cc30c1..1b859ae 100644 --- a/algorithm/neat/genome/__init__.py +++ b/algorithm/neat/genome/__init__.py @@ -1,2 +1,4 @@ from .basic import initialize_genomes from .mutate import create_mutate +from .distance import create_distance +from .crossover import crossover diff --git a/algorithm/neat/genome/crossover.py b/algorithm/neat/genome/crossover.py index e69de29..44ce594 100644 --- a/algorithm/neat/genome/crossover.py +++ b/algorithm/neat/genome/crossover.py @@ -0,0 +1,68 @@ +from typing import Tuple + +import jax +from jax import jit, Array, numpy as jnp + + +def crossover(state, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array): + """ + use genome1 and genome2 to generate a new genome + notice that genome1 should have higher fitness than genome2 (genome1 is winner!) + """ + randkey_1, randkey_2, key= jax.random.split(state.randkey, 3) + + # crossover nodes + keys1, keys2 = nodes1[:, 0], nodes2[:, 0] + # make homologous genes align in nodes2 align with nodes1 + nodes2 = align_array(keys1, keys2, nodes2, False) + + # For not homologous genes, use the value of nodes1(winner) + # For homologous genes, use the crossover result between nodes1 and nodes2 + new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2)) + + # crossover connections + con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2] + cons2 = align_array(con_keys1, con_keys2, cons2, True) + new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2)) + + return state.update(randkey=key), new_nodes, new_cons + + +def align_array(seq1: Array, seq2: Array, ar2: Array, is_conn: bool) -> Array: + """ + After I review this code, I found that it is the most difficult part of the code. Please never change it! + make ar2 align with ar1. + :param seq1: + :param seq2: + :param ar2: + :param is_conn: + :return: + align means to intersect part of ar2 will be at the same position as ar1, + non-intersect part of ar2 will be set to Nan + """ + seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :] + mask = (seq1 == seq2) & (~jnp.isnan(seq1)) + + if is_conn: + mask = jnp.all(mask, axis=2) + + intersect_mask = mask.any(axis=1) + idx = jnp.arange(0, len(seq1)) + idx_fixed = jnp.dot(mask, idx) + + refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan) + + return refactor_ar2 + + +def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array: + """ + crossover two genes + :param rand_key: + :param g1: + :param g2: + :return: + only gene with the same key will be crossover, thus don't need to consider change key + """ + r = jax.random.uniform(rand_key, shape=g1.shape) + return jnp.where(r > 0.5, g1, g2) diff --git a/algorithm/neat/genome/distance.py b/algorithm/neat/genome/distance.py index e69de29..0bd4e5a 100644 --- a/algorithm/neat/genome/distance.py +++ b/algorithm/neat/genome/distance.py @@ -0,0 +1,76 @@ +from typing import Dict, Type + +from jax import Array, numpy as jnp, vmap + +from ..gene import BaseGene + + +def create_distance(config: Dict, gene_type: Type[BaseGene]): + def node_distance(state, nodes1: Array, nodes2: Array): + """ + Calculate the distance between nodes of two genomes. + """ + # statistics nodes count of two genomes + node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0])) + node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0])) + max_cnt = jnp.maximum(node_cnt1, node_cnt2) + + # align homologous nodes + # this process is similar to np.intersect1d. + nodes = jnp.concatenate((nodes1, nodes2), axis=0) + keys = nodes[:, 0] + sorted_indices = jnp.argsort(keys, axis=0) + nodes = nodes[sorted_indices] + nodes = jnp.concatenate([nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0) # add a nan row to the end + fr, sr = nodes[:-1], nodes[1:] # first row, second row + + # flag location of homologous nodes + intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0]) + + # calculate the count of non_homologous of two genomes + non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask) + + # calculate the distance of homologous nodes + hnd = vmap(gene_type.distance_node, in_axes=(None, 0, 0))(state, fr, sr) + hnd = jnp.where(jnp.isnan(hnd), 0, hnd) + homologous_distance = jnp.sum(hnd * intersect_mask) + + val = non_homologous_cnt * config['compatibility_disjoint'] + homologous_distance * config[ + 'compatibility_weight'] + + return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division + + def connection_distance(state, cons1: Array, cons2: Array): + """ + Calculate the distance between connections of two genomes. + Similar process as node_distance. + """ + con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0])) + con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0])) + max_cnt = jnp.maximum(con_cnt1, con_cnt2) + + cons = jnp.concatenate((cons1, cons2), axis=0) + keys = cons[:, :2] + sorted_indices = jnp.lexsort(keys.T[::-1]) + cons = cons[sorted_indices] + cons = jnp.concatenate([cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0) # add a nan row to the end + fr, sr = cons[:-1], cons[1:] # first row, second row + + # both genome has such connection + intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0]) + + non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask) + hcd = vmap(gene_type.distance_conn, in_axes=(None, 0, 0))(state, fr, sr) + hcd = jnp.where(jnp.isnan(hcd), 0, hcd) + homologous_distance = jnp.sum(hcd * intersect_mask) + + val = non_homologous_cnt * config['compatibility_disjoint'] + homologous_distance * config[ + 'compatibility_weight'] + + return jnp.where(max_cnt == 0, 0, val / max_cnt) + + def distance(state, nodes1, conns1, nodes2, conns2): + return node_distance(state, nodes1, nodes2) + connection_distance(state, conns1, conns2) + + return distance + diff --git a/algorithm/neat/genome/mutate.py b/algorithm/neat/genome/mutate.py index 748f8be..98e9d93 100644 --- a/algorithm/neat/genome/mutate.py +++ b/algorithm/neat/genome/mutate.py @@ -1,6 +1,5 @@ from typing import Dict, Tuple, Type -import numpy as np import jax from jax import Array, numpy as jnp, vmap diff --git a/examples/xor_test.py b/examples/xor_test.py index 46dcf8e..bb2931a 100644 --- a/examples/xor_test.py +++ b/examples/xor_test.py @@ -2,15 +2,16 @@ import jax from algorithm.config import Configer from algorithm.neat import NEAT -from algorithm.neat.genome import create_mutate if __name__ == '__main__': config = Configer.load_config() neat = NEAT(config) randkey = jax.random.PRNGKey(42) state = neat.setup(randkey) - mutate_func = jax.jit(create_mutate(config, neat.gene_type)) - state = mutate_func(state) + state = neat.mutate(state) print(state) + pop_nodes, pop_conns = state.pop_nodes, state.pop_conns + print(neat.distance(state, pop_nodes[0], pop_conns[0], pop_nodes[1], pop_conns[1])) + print(neat.crossover(state, pop_nodes[0], pop_conns[0], pop_nodes[1], pop_conns[1]))