modifying

This commit is contained in:
wls2002
2023-06-19 17:32:34 +08:00
parent 5cbe3c14bb
commit 35b095ba74
6 changed files with 428 additions and 42 deletions

View File

@@ -2,11 +2,15 @@ import os
import warnings import warnings
import configparser import configparser
import numpy as np
from .activations import refactor_act from .activations import refactor_act
from .aggregations import refactor_agg from .aggregations import refactor_agg
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX. # Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
jit_config_keys = [ jit_config_keys = [
"input_idx",
"output_idx",
"compatibility_disjoint", "compatibility_disjoint",
"compatibility_weight", "compatibility_weight",
"conn_add_prob", "conn_add_prob",
@@ -88,10 +92,14 @@ class Configer:
refactor_act(config) refactor_act(config)
refactor_agg(config) refactor_agg(config)
input_idx = np.arange(config['num_inputs'])
output_idx = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
config['input_idx'] = input_idx
config['output_idx'] = output_idx
return config return config
@classmethod @classmethod
def create_jit_config(cls, config): def create_jit_config(cls, config):
jit_config = {k: config[k] for k in jit_config_keys} jit_config = {k: config[k] for k in jit_config_keys}
return jit_config return jit_config

View File

@@ -1,14 +1,17 @@
from functools import partial """
Crossover two genomes to generate a new genome.
The calculation method is the same as the crossover operation in NEAT-python.
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.configure_crossover
"""
from typing import Tuple from typing import Tuple
import jax import jax
from jax import jit, vmap, Array from jax import jit, Array
from jax import numpy as jnp from jax import numpy as jnp
@jit @jit
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) \ def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) -> Tuple[Array, Array]:
-> Tuple[Array, Array]:
""" """
use genome1 and genome2 to generate a new genome use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!) notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
@@ -23,7 +26,11 @@ def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2:
# crossover nodes # crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0] keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
# make homologous genes align in nodes2 align with nodes1
nodes2 = align_array(keys1, keys2, nodes2, 'node') nodes2 = align_array(keys1, keys2, nodes2, 'node')
# For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2)) new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
# crossover connections # crossover connections
@@ -34,7 +41,6 @@ def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2:
return new_nodes, new_cons return new_nodes, new_cons
# @partial(jit, static_argnames=['gene_type'])
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array: def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
""" """
After I review this code, I found that it is the most difficult part of the code. Please never change it! After I review this code, I found that it is the most difficult part of the code. Please never change it!
@@ -62,7 +68,6 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
return refactor_ar2 return refactor_ar2
# @jit
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array: def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
""" """
crossover two genes crossover two genes

View File

@@ -1,6 +1,7 @@
""" """
Calculate the distance between two genomes. Calculate the distance between two genomes.
The calculation method is the same as the distance calculation in NEAT-python. The calculation method is the same as the distance calculation in NEAT-python.
See https://github.com/CodeReclaimers/neat-python/blob/master/neat/genome.py
""" """
from typing import Dict from typing import Dict
@@ -14,6 +15,13 @@ from .utils import EMPTY_NODE, EMPTY_CON
def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, jit_config: Dict) -> Array: def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, jit_config: Dict) -> Array:
""" """
Calculate the distance between two genomes. Calculate the distance between two genomes.
args:
nodes1: Array(N, 5)
cons1: Array(C, 4)
nodes2: Array(N, 5)
cons2: Array(C, 4)
returns:
distance: Array(, )
""" """
nd = node_distance(nodes1, nodes2, jit_config) # node distance nd = node_distance(nodes1, nodes2, jit_config) # node distance
cd = connection_distance(cons1, cons2, jit_config) # connection distance cd = connection_distance(cons1, cons2, jit_config) # connection distance
@@ -23,13 +31,15 @@ def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, jit_confi
@jit @jit
def node_distance(nodes1: Array, nodes2: Array, jit_config: Dict): def node_distance(nodes1: Array, nodes2: Array, jit_config: Dict):
""" """
Calculate the distance between two nodes. Calculate the distance between nodes of two genomes.
""" """
# statistics nodes count of two genomes
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0])) node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0])) node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2) max_cnt = jnp.maximum(node_cnt1, node_cnt2)
# align homologous nodes
# this process is similar to np.intersect1d.
nodes = jnp.concatenate((nodes1, nodes2), axis=0) nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0] keys = nodes[:, 0]
sorted_indices = jnp.argsort(keys, axis=0) sorted_indices = jnp.argsort(keys, axis=0)
@@ -37,21 +47,28 @@ def node_distance(nodes1: Array, nodes2: Array, jit_config: Dict):
nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row fr, sr = nodes[:-1], nodes[1:] # first row, second row
# flag location of homologous nodes
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0]) intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
# calculate the count of non_homologous of two genomes
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask) non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
nd = batch_homologous_node_distance(fr, sr)
nd = jnp.where(jnp.isnan(nd), 0, nd)
homologous_distance = jnp.sum(nd * intersect_mask)
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe # calculate the distance of homologous nodes
return jnp.where(max_cnt == 0, 0, val / max_cnt) hnd = vmap(homologous_node_distance)(fr, sr)
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
homologous_distance = jnp.sum(hnd * intersect_mask)
val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[
'compatibility_weight']
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
@jit @jit
def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5): def connection_distance(cons1: Array, cons2: Array, jit_config: Dict):
""" """
Calculate the distance between two connections. Calculate the distance between connections of two genomes.
Similar process as node_distance.
""" """
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0])) con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0])) con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
@@ -68,37 +85,34 @@ def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0]) intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask) non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
cd = batch_homologous_connection_distance(fr, sr) hcd = vmap(homologous_connection_distance)(fr, sr)
cd = jnp.where(jnp.isnan(cd), 0, cd) hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
homologous_distance = jnp.sum(cd * intersect_mask) homologous_distance = jnp.sum(hcd * intersect_mask)
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[
'compatibility_weight']
return jnp.where(max_cnt == 0, 0, val / max_cnt) return jnp.where(max_cnt == 0, 0, val / max_cnt)
@vmap
def batch_homologous_node_distance(b_n1, b_n2):
return homologous_node_distance(b_n1, b_n2)
@vmap
def batch_homologous_connection_distance(b_c1, b_c2):
return homologous_connection_distance(b_c1, b_c2)
@jit @jit
def homologous_node_distance(n1, n2): def homologous_node_distance(n1: Array, n2: Array):
"""
Calculate the distance between two homologous nodes.
"""
d = 0 d = 0
d += jnp.abs(n1[1] - n2[1]) # bias d += jnp.abs(n1[1] - n2[1]) # bias
d += jnp.abs(n1[2] - n2[2]) # response d += jnp.abs(n1[2] - n2[2]) # response
d += n1[3] != n2[3] # activation d += n1[3] != n2[3] # activation
d += n1[4] != n2[4] d += n1[4] != n2[4] # aggregation
return d return d
@jit @jit
def homologous_connection_distance(c1, c2): def homologous_connection_distance(c1: Array, c2: Array):
"""
Calculate the distance between two homologous connections.
"""
d = 0 d = 0
d += jnp.abs(c1[2] - c2[2]) # weight d += jnp.abs(c1[2] - c2[2]) # weight
d += c1[3] != c2[3] # enable d += c1[3] != c2[3] # enable

View File

@@ -17,10 +17,7 @@ from jax import jit, numpy as jnp
from .utils import fetch_first from .utils import fetch_first
def initialize_genomes(N: int, def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
C: int,
config: Dict) \
-> Tuple[NDArray, NDArray, NDArray, NDArray]:
""" """
Initialize genomes with default values. Initialize genomes with default values.
@@ -41,8 +38,8 @@ def initialize_genomes(N: int,
pop_nodes = np.full((config['pop_size'], N, 5), np.nan) pop_nodes = np.full((config['pop_size'], N, 5), np.nan)
pop_cons = np.full((config['pop_size'], C, 4), np.nan) pop_cons = np.full((config['pop_size'], C, 4), np.nan)
input_idx = np.arange(config['num_inputs']) input_idx = config['input_idx']
output_idx = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs']) output_idx = config['output_idx']
pop_nodes[:, input_idx, 0] = input_idx pop_nodes[:, input_idx, 0] = input_idx
pop_nodes[:, output_idx, 0] = output_idx pop_nodes[:, output_idx, 0] = output_idx
@@ -61,7 +58,7 @@ def initialize_genomes(N: int,
pop_cons[:, :p, 2] = config['weight_init_mean'] pop_cons[:, :p, 2] = config['weight_init_mean']
pop_cons[:, :p, 3] = 1 pop_cons[:, :p, 3] = 1
return pop_nodes, pop_cons, input_idx, output_idx return pop_nodes, pop_cons
def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]: def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:

362
neat/genome/mutate_.py Normal file
View File

@@ -0,0 +1,362 @@
"""
Mutate a genome.
The calculation method is the same as the mutation operation in NEAT-python.
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate
"""
from typing import Tuple, Dict
from functools import partial
import jax
from jax import numpy as jnp
from jax import jit, Array
from .utils import fetch_random, fetch_first, I_INT, unflatten_connections
from .genome_ import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
from .graph import check_cycles
@jit
def mutate(rand_key: Array, nodes: Array, connections: Array, new_node_key: int, jit_config: Dict):
"""
:param rand_key:
:param nodes: (N, 5)
:param connections: (2, N, N)
:param new_node_key:
:param jit_config:
:return:
"""
def m_add_node(rk, n, c):
return mutate_add_node(rk, n, c, new_node_key, jit_config['bias_init_mean'], jit_config['response_init_mean'],
jit_config['activation_default'], jit_config['aggregation_default'])
def m_add_connection(rk, n, c):
return mutate_add_connection(rk, n, c, jit_config['input_idx'], jit_config['output_idx'])
def m_delete_node(rk, n, c):
return mutate_delete_node(rk, n, c, jit_config['input_idx'], jit_config['output_idx'])
def m_delete_connection(rk, n, c):
return mutate_delete_connection(rk, n, c)
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
# structural mutations
# mutate add node
r = rand(r1)
aux_nodes, aux_connections = m_add_node(r1, nodes, connections)
nodes = jnp.where(r < jit_config['node_add_prob'], aux_nodes, nodes)
connections = jnp.where(r < jit_config['node_add_prob'], aux_connections, connections)
# mutate add connection
r = rand(r2)
aux_nodes, aux_connections = m_add_connection(r3, nodes, connections)
nodes = jnp.where(r < jit_config['conn_add_prob'], aux_nodes, nodes)
connections = jnp.where(r < jit_config['conn_add_prob'], aux_connections, connections)
# mutate delete node
r = rand(r3)
aux_nodes, aux_connections = m_delete_node(r2, nodes, connections)
nodes = jnp.where(r < jit_config['node_delete_prob'], aux_nodes, nodes)
connections = jnp.where(r < jit_config['node_delete_prob'], aux_connections, connections)
# mutate delete connection
r = rand(r4)
aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections)
nodes = jnp.where(r < jit_config['conn_delete_prob'], aux_nodes, nodes)
connections = jnp.where(r < jit_config['conn_delete_prob'], aux_connections, connections)
# value mutations
nodes, connections = mutate_values(rand_key, nodes, connections, jit_config)
return nodes, connections
@jit
def mutate_values(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
"""
Mutate values of nodes and connections.
Args:
rand_key: A random key for generating random values.
nodes: A 2D array representing nodes.
cons: A 3D array representing connections.
jit_config: A dict containing configuration for jit-able functions.
Returns:
A tuple containing mutated nodes and connections.
"""
k1, k2, k3, k4, k5, rand_key = jax.random.split(rand_key, num=6)
bias_new = mutate_float_values(k1, nodes[:, 1], bias_mean, bias_std,
bias_mutate_strength, bias_mutate_rate, bias_replace_rate)
response_new = mutate_float_values(k2, nodes[:, 2], response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate)
weight_new = mutate_float_values(k3, cons[:, 2], weight_mean, weight_std,
weight_mutate_strength, weight_mutate_rate, weight_replace_rate)
act_new = mutate_int_values(k4, nodes[:, 3], act_list, act_replace_rate)
agg_new = mutate_int_values(k5, nodes[:, 4], agg_list, agg_replace_rate)
# mutate enabled
r = jax.random.uniform(rand_key, cons[:, 3].shape)
enabled_new = jnp.where(r < enabled_reverse_rate, 1 - cons[:, 3], cons[:, 3])
enabled_new = jnp.where(~jnp.isnan(cons[:, 3]), enabled_new, jnp.nan)
nodes = nodes.at[:, 1].set(bias_new)
nodes = nodes.at[:, 2].set(response_new)
nodes = nodes.at[:, 3].set(act_new)
nodes = nodes.at[:, 4].set(agg_new)
cons = cons.at[:, 2].set(weight_new)
cons = cons.at[:, 3].set(enabled_new)
return nodes, cons
@jit
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
"""
Mutate float values of a given array.
Args:
rand_key: A random key for generating random values.
old_vals: A 1D array of float values to be mutated.
mean: Mean of the values.
std: Standard deviation of the values.
mutate_strength: Strength of the mutation.
mutate_rate: Rate of the mutation.
replace_rate: Rate of the replacement.
Returns:
A mutated 1D array of float values.
"""
k1, k2, k3, rand_key = jax.random.split(rand_key, num=4)
noise = jax.random.normal(k1, old_vals.shape) * mutate_strength
replace = jax.random.normal(k2, old_vals.shape) * std + mean
r = jax.random.uniform(k3, old_vals.shape)
new_vals = old_vals
new_vals = jnp.where(r < mutate_rate, new_vals + noise, new_vals)
new_vals = jnp.where(
jnp.logical_and(mutate_rate < r, r < mutate_rate + replace_rate),
replace,
new_vals
)
new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan)
return new_vals
@jit
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
"""
Mutate integer values (act, agg) of a given array.
Args:
rand_key: A random key for generating random values.
old_vals: A 1D array of integer values to be mutated.
val_list: List of the integer values.
replace_rate: Rate of the replacement.
Returns:
A mutated 1D array of integer values.
"""
k1, k2, rand_key = jax.random.split(rand_key, num=3)
replace_val = jax.random.choice(k1, val_list, old_vals.shape)
r = jax.random.uniform(k2, old_vals.shape)
new_vals = old_vals
new_vals = jnp.where(r < replace_rate, replace_val, new_vals)
new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan)
return new_vals
@jit
def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: int,
default_bias: float = 0, default_response: float = 1,
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
"""
Randomly add a new node from splitting a connection.
:param rand_key:
:param new_node_key:
:param nodes:
:param cons:
:param default_bias:
:param default_response:
:param default_act:
:param default_agg:
:return:
"""
# randomly choose a connection
i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons)
def nothing(): # there is no connection to split
return nodes, cons
def successful_add_node():
# disable the connection
new_nodes, new_cons = nodes, cons
new_cons = new_cons.at[idx, 3].set(False)
# add a new node
new_nodes, new_cons = \
add_node(new_nodes, new_cons, new_node_key,
bias=default_bias, response=default_response, act=default_act, agg=default_agg)
# add two new connections
w = new_cons[idx, 2]
new_nodes, new_cons = add_connection(new_nodes, new_cons, i_key, new_node_key, weight=1, enabled=True)
new_nodes, new_cons = add_connection(new_nodes, new_cons, new_node_key, o_key, weight=w, enabled=True)
return new_nodes, new_cons
# if from_idx == I_INT, that means no connection exist, do nothing
nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_add_node)
return nodes, cons
# TODO: Need we really need to delete a node?
@jit
def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
"""
Randomly delete a node. Input and output nodes are not allowed to be deleted.
:param rand_key:
:param nodes:
:param cons:
:param input_keys:
:param output_keys:
:return:
"""
# randomly choose a node
node_key, node_idx = choice_node_key(rand_key, nodes, input_keys, output_keys,
allow_input_keys=False, allow_output_keys=False)
def nothing():
return nodes, cons
def successful_delete_node():
# delete the node
aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, node_idx)
# delete all connections
aux_cons = jnp.where(((aux_cons[:, 0] == node_key) | (aux_cons[:, 1] == node_key))[:, jnp.newaxis],
jnp.nan, aux_cons)
return aux_nodes, aux_cons
nodes, cons = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node)
return nodes, cons
@jit
def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
"""
Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks,
cycles are not allowed.
:param rand_key:
:param nodes:
:param cons:
:param input_keys:
:param output_keys:
:return:
"""
# randomly choose two nodes
k1, k2 = jax.random.split(rand_key, num=2)
i_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys,
allow_input_keys=True, allow_output_keys=True)
o_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys,
allow_input_keys=False, allow_output_keys=True)
con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
def successful():
new_nodes, new_cons = add_connection(nodes, cons, i_key, o_key, weight=1, enabled=True)
return new_nodes, new_cons
def already_exist():
new_cons = cons.at[con_idx, 3].set(True)
return nodes, new_cons
def cycle():
return nodes, cons
is_already_exist = con_idx != I_INT
unflattened = unflatten_connections(nodes, cons)
is_cycle = check_cycles(nodes, unflattened, from_idx, to_idx)
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful])
return nodes, cons
@jit
def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array):
"""
Randomly delete a connection.
:param rand_key:
:param nodes:
:param cons:
:return:
"""
# randomly choose a connection
i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons)
def nothing():
return nodes, cons
def successfully_delete_connection():
return delete_connection_by_idx(nodes, cons, idx)
nodes, cons = jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection)
return nodes, cons
@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
def choice_node_key(rand_key: Array, nodes: Array,
input_keys: Array, output_keys: Array,
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
"""
Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node.
:param rand_key:
:param nodes:
:param input_keys:
:param output_keys:
:param allow_input_keys:
:param allow_output_keys:
:return: return its key and position(idx)
"""
node_keys = nodes[:, 0]
mask = ~jnp.isnan(node_keys)
if not allow_input_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_keys))
if not allow_output_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_keys))
idx = fetch_random(rand_key, mask)
key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
return key, idx
@jit
def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]:
"""
Randomly choose a connection key from the given connections.
:param rand_key:
:param nodes:
:param cons:
:return: i_key, o_key, idx
"""
idx = fetch_random(rand_key, ~jnp.isnan(cons[:, 0]))
i_key = jnp.where(idx != I_INT, cons[idx, 0], jnp.nan)
o_key = jnp.where(idx != I_INT, cons[idx, 1], jnp.nan)
return i_key, o_key, idx
@jit
def rand(rand_key):
return jax.random.uniform(rand_key, ())

View File

@@ -21,7 +21,7 @@ class Pipeline:
self.generation = 0 self.generation = 0
self.best_genome = None self.best_genome = None
self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx = initialize_genomes(self.N, self.C, self.config) self.pop_nodes, self.pop_cons = initialize_genomes(self.N, self.C, self.config)
print(self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx, sep='\n') print(self.pop_nodes, self.pop_cons, sep='\n')
print(self.jit_config) print(self.jit_config)