change a lot

This commit is contained in:
wls2002
2023-07-17 19:59:46 +08:00
parent f4763ebcea
commit 40cf0b6fbe
8 changed files with 248 additions and 18 deletions

View File

@@ -1,6 +1,8 @@
import jax
from algorithm.state import State from algorithm.state import State
from .gene import * from .gene import *
from .genome import initialize_genomes from .genome import initialize_genomes, create_mutate, create_distance, crossover
class NEAT: class NEAT:
@@ -11,6 +13,10 @@ class NEAT:
else: else:
raise NotImplementedError raise NotImplementedError
self.mutate = jax.jit(create_mutate(config, self.gene_type))
self.distance = jax.jit(create_distance(config, self.gene_type))
self.crossover = jax.jit(crossover)
def setup(self, randkey): def setup(self, randkey):
state = State( state = State(
@@ -25,6 +31,8 @@ class NEAT:
output_idx=self.config['output_idx'] output_idx=self.config['output_idx']
) )
state = self.gene_type.setup(state, self.config)
pop_nodes, pop_conns = initialize_genomes(state, self.gene_type) pop_nodes, pop_conns = initialize_genomes(state, self.gene_type)
next_node_key = max(*state.input_idx, *state.output_idx) + 2 next_node_key = max(*state.input_idx, *state.output_idx) + 2
state = state.update( state = state.update(

View File

@@ -26,12 +26,12 @@ class BaseGene:
return attrs return attrs
@staticmethod @staticmethod
def distance_node(state, array: Array): def distance_node(state, array1: Array, array2: Array):
return array return array1
@staticmethod @staticmethod
def distance_conn(state, array: Array): def distance_conn(state, array1: Array, array2: Array):
return array return array1
@staticmethod @staticmethod
def forward(state, array: Array): def forward(state, array: Array):

View File

@@ -1,3 +1,4 @@
import jax
from jax import Array, numpy as jnp from jax import Array, numpy as jnp
from . import BaseGene from . import BaseGene
@@ -9,32 +10,107 @@ class NormalGene(BaseGene):
@staticmethod @staticmethod
def setup(state, config): def setup(state, config):
return state return state.update(
bias_init_mean=config['bias_init_mean'],
bias_init_std=config['bias_init_std'],
bias_mutate_power=config['bias_mutate_power'],
bias_mutate_rate=config['bias_mutate_rate'],
bias_replace_rate=config['bias_replace_rate'],
response_init_mean=config['response_init_mean'],
response_init_std=config['response_init_std'],
response_mutate_power=config['response_mutate_power'],
response_mutate_rate=config['response_mutate_rate'],
response_replace_rate=config['response_replace_rate'],
activation_default=config['activation_default'],
activation_options=config['activation_options'],
activation_replace_rate=config['activation_replace_rate'],
aggregation_default=config['aggregation_default'],
aggregation_options=config['aggregation_options'],
aggregation_replace_rate=config['aggregation_replace_rate'],
weight_init_mean=config['weight_init_mean'],
weight_init_std=config['weight_init_std'],
weight_mutate_power=config['weight_mutate_power'],
weight_mutate_rate=config['weight_mutate_rate'],
weight_replace_rate=config['weight_replace_rate'],
)
@staticmethod @staticmethod
def new_node_attrs(state): def new_node_attrs(state):
return jnp.array([0, 0, 0, 0]) return jnp.array([state.bias_init_mean, state.response_init_mean,
state.activation_default, state.aggregation_default])
@staticmethod @staticmethod
def new_conn_attrs(state): def new_conn_attrs(state):
return jnp.array([0]) return jnp.array([state.weight_init_mean])
@staticmethod @staticmethod
def mutate_node(state, attrs: Array, key): def mutate_node(state, attrs: Array, key):
return attrs k1, k2, k3, k4 = jax.random.split(key, num=4)
bias = NormalGene._mutate_float(k1, attrs[0], state.bias_init_mean, state.bias_init_std,
state.bias_mutate_power, state.bias_mutate_rate, state.bias_replace_rate)
res = NormalGene._mutate_float(k2, attrs[1], state.response_init_mean, state.response_init_std,
state.response_mutate_power, state.response_mutate_rate,
state.response_replace_rate)
act = NormalGene._mutate_int(k3, attrs[2], state.activation_options, state.activation_replace_rate)
agg = NormalGene._mutate_int(k4, attrs[3], state.aggregation_options, state.aggregation_replace_rate)
return jnp.array([bias, res, act, agg])
@staticmethod @staticmethod
def mutate_conn(state, attrs: Array, key): def mutate_conn(state, attrs: Array, key):
return attrs weight = NormalGene._mutate_float(key, attrs[0], state.weight_init_mean, state.weight_init_std,
state.weight_mutate_power, state.weight_mutate_rate,
state.weight_replace_rate)
return jnp.array([weight])
@staticmethod @staticmethod
def distance_node(state, array: Array): def distance_node(state, array1: Array, array2: Array):
return array # bias + response + activation + aggregation
return jnp.abs(array1[1] - array2[1]) + jnp.abs(array1[2] - array2[2]) + \
(array1[3] != array2[3]) + (array1[4] != array2[4])
@staticmethod @staticmethod
def distance_conn(state, array: Array): def distance_conn(state, array1: Array, array2: Array):
return array return (array1[2] != array2[2]) + jnp.abs(array1[3] - array2[3]) # enable + weight
@staticmethod @staticmethod
def forward(state, array: Array): def forward(state, array: Array):
return array return array
@staticmethod
def _mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate):
k1, k2, k3 = jax.random.split(key, num=3)
noise = jax.random.normal(k1, ()) * mutate_power
replace = jax.random.normal(k2, ()) * init_std + init_mean
r = jax.random.uniform(k3, ())
val = jnp.where(
r < mutate_rate,
val + noise,
jnp.where(
(mutate_rate < r) & (r < mutate_rate + replace_rate),
replace,
val
)
)
return val
@staticmethod
def _mutate_int(key, val, options, replace_rate):
k1, k2 = jax.random.split(key, num=2)
r = jax.random.uniform(k1, ())
val = jnp.where(
r < replace_rate,
jax.random.choice(k2, options),
val
)
return val

View File

@@ -1,2 +1,4 @@
from .basic import initialize_genomes from .basic import initialize_genomes
from .mutate import create_mutate from .mutate import create_mutate
from .distance import create_distance
from .crossover import crossover

View File

@@ -0,0 +1,68 @@
from typing import Tuple
import jax
from jax import jit, Array, numpy as jnp
def crossover(state, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array):
"""
use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
"""
randkey_1, randkey_2, key= jax.random.split(state.randkey, 3)
# crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
# make homologous genes align in nodes2 align with nodes1
nodes2 = align_array(keys1, keys2, nodes2, False)
# For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
# crossover connections
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
cons2 = align_array(con_keys1, con_keys2, cons2, True)
new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
return state.update(randkey=key), new_nodes, new_cons
def align_array(seq1: Array, seq2: Array, ar2: Array, is_conn: bool) -> Array:
"""
After I review this code, I found that it is the most difficult part of the code. Please never change it!
make ar2 align with ar1.
:param seq1:
:param seq2:
:param ar2:
:param is_conn:
:return:
align means to intersect part of ar2 will be at the same position as ar1,
non-intersect part of ar2 will be set to Nan
"""
seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
mask = (seq1 == seq2) & (~jnp.isnan(seq1))
if is_conn:
mask = jnp.all(mask, axis=2)
intersect_mask = mask.any(axis=1)
idx = jnp.arange(0, len(seq1))
idx_fixed = jnp.dot(mask, idx)
refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan)
return refactor_ar2
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
"""
crossover two genes
:param rand_key:
:param g1:
:param g2:
:return:
only gene with the same key will be crossover, thus don't need to consider change key
"""
r = jax.random.uniform(rand_key, shape=g1.shape)
return jnp.where(r > 0.5, g1, g2)

View File

@@ -0,0 +1,76 @@
from typing import Dict, Type
from jax import Array, numpy as jnp, vmap
from ..gene import BaseGene
def create_distance(config: Dict, gene_type: Type[BaseGene]):
def node_distance(state, nodes1: Array, nodes2: Array):
"""
Calculate the distance between nodes of two genomes.
"""
# statistics nodes count of two genomes
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
# align homologous nodes
# this process is similar to np.intersect1d.
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0]
sorted_indices = jnp.argsort(keys, axis=0)
nodes = nodes[sorted_indices]
nodes = jnp.concatenate([nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row
# flag location of homologous nodes
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
# calculate the count of non_homologous of two genomes
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
# calculate the distance of homologous nodes
hnd = vmap(gene_type.distance_node, in_axes=(None, 0, 0))(state, fr, sr)
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
homologous_distance = jnp.sum(hnd * intersect_mask)
val = non_homologous_cnt * config['compatibility_disjoint'] + homologous_distance * config[
'compatibility_weight']
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
def connection_distance(state, cons1: Array, cons2: Array):
"""
Calculate the distance between connections of two genomes.
Similar process as node_distance.
"""
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
cons = jnp.concatenate((cons1, cons2), axis=0)
keys = cons[:, :2]
sorted_indices = jnp.lexsort(keys.T[::-1])
cons = cons[sorted_indices]
cons = jnp.concatenate([cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
fr, sr = cons[:-1], cons[1:] # first row, second row
# both genome has such connection
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
hcd = vmap(gene_type.distance_conn, in_axes=(None, 0, 0))(state, fr, sr)
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
homologous_distance = jnp.sum(hcd * intersect_mask)
val = non_homologous_cnt * config['compatibility_disjoint'] + homologous_distance * config[
'compatibility_weight']
return jnp.where(max_cnt == 0, 0, val / max_cnt)
def distance(state, nodes1, conns1, nodes2, conns2):
return node_distance(state, nodes1, nodes2) + connection_distance(state, conns1, conns2)
return distance

View File

@@ -1,6 +1,5 @@
from typing import Dict, Tuple, Type from typing import Dict, Tuple, Type
import numpy as np
import jax import jax
from jax import Array, numpy as jnp, vmap from jax import Array, numpy as jnp, vmap

View File

@@ -2,15 +2,16 @@ import jax
from algorithm.config import Configer from algorithm.config import Configer
from algorithm.neat import NEAT from algorithm.neat import NEAT
from algorithm.neat.genome import create_mutate
if __name__ == '__main__': if __name__ == '__main__':
config = Configer.load_config() config = Configer.load_config()
neat = NEAT(config) neat = NEAT(config)
randkey = jax.random.PRNGKey(42) randkey = jax.random.PRNGKey(42)
state = neat.setup(randkey) state = neat.setup(randkey)
mutate_func = jax.jit(create_mutate(config, neat.gene_type)) state = neat.mutate(state)
state = mutate_func(state)
print(state) print(state)
pop_nodes, pop_conns = state.pop_nodes, state.pop_conns
print(neat.distance(state, pop_nodes[0], pop_conns[0], pop_nodes[1], pop_conns[1]))
print(neat.crossover(state, pop_nodes[0], pop_conns[0], pop_nodes[1], pop_conns[1]))