change a lot a lot a lot!!!!!!!
This commit is contained in:
2
algorithm/neat/ga/__init__.py
Normal file
2
algorithm/neat/ga/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .crossover import crossover
|
||||
from .mutate import create_mutate
|
||||
70
algorithm/neat/ga/crossover.py
Normal file
70
algorithm/neat/ga/crossover.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import jax
|
||||
from jax import Array, numpy as jnp
|
||||
|
||||
from core import Genome
|
||||
|
||||
|
||||
def crossover(randkey, genome1: Genome, genome2: Genome):
|
||||
"""
|
||||
use genome1 and genome2 to generate a new genome
|
||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||
"""
|
||||
randkey_1, randkey_2, key= jax.random.split(randkey, 3)
|
||||
|
||||
# crossover nodes
|
||||
keys1, keys2 = genome1.nodes[:, 0], genome2.nodes[:, 0]
|
||||
# make homologous genes align in nodes2 align with nodes1
|
||||
nodes2 = align_array(keys1, keys2, genome2.nodes, False)
|
||||
nodes1 = genome1.nodes
|
||||
# For not homologous genes, use the value of nodes1(winner)
|
||||
# For homologous genes, use the crossover result between nodes1 and nodes2
|
||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
|
||||
|
||||
# crossover connections
|
||||
con_keys1, con_keys2 = genome1.conns[:, :2], genome2.conns[:, :2]
|
||||
conns2 = align_array(con_keys1, con_keys2, genome2.conns, True)
|
||||
conns1 = genome1.conns
|
||||
|
||||
new_cons = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, crossover_gene(randkey_2, conns1, conns2))
|
||||
|
||||
return genome1.update(new_nodes, new_cons)
|
||||
|
||||
|
||||
def align_array(seq1: Array, seq2: Array, ar2: Array, is_conn: bool) -> Array:
|
||||
"""
|
||||
After I review this code, I found that it is the most difficult part of the code. Please never change it!
|
||||
make ar2 align with ar1.
|
||||
:param seq1:
|
||||
:param seq2:
|
||||
:param ar2:
|
||||
:param is_conn:
|
||||
:return:
|
||||
align means to intersect part of ar2 will be at the same position as ar1,
|
||||
non-intersect part of ar2 will be set to Nan
|
||||
"""
|
||||
seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
|
||||
mask = (seq1 == seq2) & (~jnp.isnan(seq1))
|
||||
|
||||
if is_conn:
|
||||
mask = jnp.all(mask, axis=2)
|
||||
|
||||
intersect_mask = mask.any(axis=1)
|
||||
idx = jnp.arange(0, len(seq1))
|
||||
idx_fixed = jnp.dot(mask, idx)
|
||||
|
||||
refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan)
|
||||
|
||||
return refactor_ar2
|
||||
|
||||
|
||||
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
||||
"""
|
||||
crossover two genes
|
||||
:param rand_key:
|
||||
:param g1:
|
||||
:param g2:
|
||||
:return:
|
||||
only gene with the same key will be crossover, thus don't need to consider change key
|
||||
"""
|
||||
r = jax.random.uniform(rand_key, shape=g1.shape)
|
||||
return jnp.where(r > 0.5, g1, g2)
|
||||
189
algorithm/neat/ga/mutate.py
Normal file
189
algorithm/neat/ga/mutate.py
Normal file
@@ -0,0 +1,189 @@
|
||||
from typing import Tuple, Type
|
||||
|
||||
import jax
|
||||
from jax import Array, numpy as jnp, vmap
|
||||
|
||||
from config import NeatConfig
|
||||
from core import State, Gene, Genome
|
||||
from utils import check_cycles, fetch_random, fetch_first, I_INT, unflatten_conns
|
||||
|
||||
|
||||
def create_mutate(config: NeatConfig, gene_type: Type[Gene]):
|
||||
"""
|
||||
Create function to mutate a single genome
|
||||
"""
|
||||
|
||||
def mutate_structure(state: State, randkey, genome: Genome, new_node_key):
|
||||
|
||||
def mutate_add_node(key_, genome_: Genome):
|
||||
i_key, o_key, idx = choice_connection_key(key_, genome_.conns)
|
||||
|
||||
def nothing():
|
||||
return genome_
|
||||
|
||||
def successful_add_node():
|
||||
# disable the connection
|
||||
new_genome = genome_.update_conns(genome_.conns.at[idx, 2].set(False))
|
||||
|
||||
# add a new node
|
||||
new_genome = new_genome.add_node(new_node_key, gene_type.new_node_attrs(state))
|
||||
|
||||
# add two new connections
|
||||
new_genome = new_genome.add_conn(i_key, new_node_key, True, gene_type.new_conn_attrs(state))
|
||||
new_genome = new_genome.add_conn(new_node_key, o_key, True, gene_type.new_conn_attrs(state))
|
||||
|
||||
return new_genome
|
||||
|
||||
# if from_idx == I_INT, that means no connection exist, do nothing
|
||||
return jax.lax.cond(idx == I_INT, nothing, successful_add_node)
|
||||
|
||||
def mutate_delete_node(key_, genome_: Genome):
|
||||
# TODO: Do we really need to delete a node?
|
||||
# randomly choose a node
|
||||
key, idx = choice_node_key(key_, genome_.nodes, state.input_idx, state.output_idx,
|
||||
allow_input_keys=False, allow_output_keys=False)
|
||||
def nothing():
|
||||
return genome_
|
||||
|
||||
def successful_delete_node():
|
||||
# delete the node
|
||||
new_genome = genome_.delete_node_by_pos(idx)
|
||||
|
||||
# delete all connections
|
||||
new_conns = jnp.where(((new_genome.conns[:, 0] == key) | (new_genome.conns[:, 1] == key))[:, None],
|
||||
jnp.nan, new_genome.conns)
|
||||
|
||||
return new_genome.update_conns(new_conns)
|
||||
|
||||
return jax.lax.cond(idx == I_INT, nothing, successful_delete_node)
|
||||
|
||||
def mutate_add_conn(key_, genome_: Genome):
|
||||
# randomly choose two nodes
|
||||
k1_, k2_ = jax.random.split(key_, num=2)
|
||||
i_key, from_idx = choice_node_key(k1_, genome_.nodes, state.input_idx, state.output_idx,
|
||||
allow_input_keys=True, allow_output_keys=True)
|
||||
o_key, to_idx = choice_node_key(k2_, genome_.nodes, state.input_idx, state.output_idx,
|
||||
allow_input_keys=False, allow_output_keys=True)
|
||||
|
||||
conn_pos = fetch_first((genome_.conns[:, 0] == i_key) & (genome_.conns[:, 1] == o_key))
|
||||
|
||||
def nothing():
|
||||
return genome_
|
||||
|
||||
def successful():
|
||||
return genome_.add_conn(i_key, o_key, True, gene_type.new_conn_attrs(state))
|
||||
|
||||
def already_exist():
|
||||
return genome_.update_conns(genome_.conns.at[conn_pos, 2].set(True))
|
||||
|
||||
|
||||
is_already_exist = conn_pos != I_INT
|
||||
|
||||
if config.network_type == 'feedforward':
|
||||
u_cons = unflatten_conns(genome_.nodes, genome_.conns)
|
||||
cons_exist = jnp.where(~jnp.isnan(u_cons[0, :, :]), True, False)
|
||||
is_cycle = check_cycles(genome_.nodes, cons_exist, from_idx, to_idx)
|
||||
|
||||
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
|
||||
return jax.lax.switch(choice, [already_exist, nothing, successful])
|
||||
|
||||
elif config.network_type == 'recurrent':
|
||||
return jax.lax.cond(is_already_exist, already_exist, successful)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid network type: {config.network_type}")
|
||||
|
||||
def mutate_delete_conn(key_, genome_: Genome):
|
||||
# randomly choose a connection
|
||||
i_key, o_key, idx = choice_connection_key(key_, genome_.conns)
|
||||
|
||||
def nothing():
|
||||
return genome_
|
||||
|
||||
def successfully_delete_connection():
|
||||
return genome_.delete_conn_by_pos(idx)
|
||||
|
||||
return jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection)
|
||||
|
||||
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
|
||||
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
|
||||
|
||||
def no(k, g):
|
||||
return g
|
||||
|
||||
genome = jax.lax.cond(r1 < config.node_add, mutate_add_node, no, k1, genome)
|
||||
genome = jax.lax.cond(r2 < config.node_delete, mutate_delete_node, no, k2, genome)
|
||||
genome = jax.lax.cond(r3 < config.conn_add, mutate_add_conn, no, k3, genome)
|
||||
genome = jax.lax.cond(r4 < config.conn_delete, mutate_delete_conn, no, k4, genome)
|
||||
|
||||
return genome
|
||||
|
||||
def mutate_values(state: State, randkey, genome: Genome):
|
||||
k1, k2 = jax.random.split(randkey, num=2)
|
||||
nodes_keys = jax.random.split(k1, num=genome.nodes.shape[0])
|
||||
conns_keys = jax.random.split(k2, num=genome.conns.shape[0])
|
||||
|
||||
nodes_attrs, conns_attrs = genome.nodes[:, 1:], genome.conns[:, 3:]
|
||||
|
||||
new_nodes_attrs = vmap(gene_type.mutate_node, in_axes=(None, 0, 0))(state, nodes_attrs, nodes_keys)
|
||||
new_conns_attrs = vmap(gene_type.mutate_conn, in_axes=(None, 0, 0))(state, conns_attrs, conns_keys)
|
||||
|
||||
# nan nodes not changed
|
||||
new_nodes_attrs = jnp.where(jnp.isnan(nodes_attrs), jnp.nan, new_nodes_attrs)
|
||||
new_conns_attrs = jnp.where(jnp.isnan(conns_attrs), jnp.nan, new_conns_attrs)
|
||||
|
||||
new_nodes = genome.nodes.at[:, 1:].set(new_nodes_attrs)
|
||||
new_conns = genome.conns.at[:, 3:].set(new_conns_attrs)
|
||||
|
||||
return genome.update(new_nodes, new_conns)
|
||||
|
||||
def mutate(state, randkey, genome: Genome, new_node_key):
|
||||
k1, k2 = jax.random.split(randkey)
|
||||
|
||||
genome = mutate_structure(state, k1, genome, new_node_key)
|
||||
genome = mutate_values(state, k2, genome)
|
||||
|
||||
return genome
|
||||
|
||||
return mutate
|
||||
|
||||
|
||||
def choice_node_key(rand_key: Array, nodes: Array,
|
||||
input_keys: Array, output_keys: Array,
|
||||
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node.
|
||||
:param rand_key:
|
||||
:param nodes:
|
||||
:param input_keys:
|
||||
:param output_keys:
|
||||
:param allow_input_keys:
|
||||
:param allow_output_keys:
|
||||
:return: return its key and position(idx)
|
||||
"""
|
||||
|
||||
node_keys = nodes[:, 0]
|
||||
mask = ~jnp.isnan(node_keys)
|
||||
|
||||
if not allow_input_keys:
|
||||
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_keys))
|
||||
|
||||
if not allow_output_keys:
|
||||
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_keys))
|
||||
|
||||
idx = fetch_random(rand_key, mask)
|
||||
key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
|
||||
return key, idx
|
||||
|
||||
|
||||
def choice_connection_key(rand_key: Array, conns: Array):
|
||||
"""
|
||||
Randomly choose a connection key from the given connections.
|
||||
:return: i_key, o_key, idx
|
||||
"""
|
||||
|
||||
idx = fetch_random(rand_key, ~jnp.isnan(conns[:, 0]))
|
||||
i_key = jnp.where(idx != I_INT, conns[idx, 0], jnp.nan)
|
||||
o_key = jnp.where(idx != I_INT, conns[idx, 1], jnp.nan)
|
||||
|
||||
return i_key, o_key, idx
|
||||
Reference in New Issue
Block a user