odify genome for the official release
This commit is contained in:
3
tensorneat/algorithm/neat/genome/operations/__init__.py
Normal file
3
tensorneat/algorithm/neat/genome/operations/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .crossover import BaseCrossover, DefaultCrossover
|
||||
from .mutation import BaseMutation, DefaultMutation
|
||||
from .distance import BaseDistance, DefaultDistance
|
||||
@@ -0,0 +1,2 @@
|
||||
from .base import BaseCrossover
|
||||
from .default import DefaultCrossover
|
||||
@@ -0,0 +1,12 @@
|
||||
from tensorneat.common import StatefulBaseClass, State
|
||||
|
||||
|
||||
class BaseCrossover(StatefulBaseClass):
|
||||
|
||||
def setup(self, state=State(), genome = None):
|
||||
assert genome is not None, "genome should not be None"
|
||||
self.genome = genome
|
||||
return state
|
||||
|
||||
def __call__(self, state, randkey, nodes1, nodes2, conns1, conns2):
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,87 @@
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
|
||||
from .base import BaseCrossover
|
||||
from ...utils import (
|
||||
extract_node_attrs,
|
||||
extract_conn_attrs,
|
||||
set_node_attrs,
|
||||
set_conn_attrs,
|
||||
)
|
||||
|
||||
|
||||
class DefaultCrossover(BaseCrossover):
|
||||
def __call__(self, state, randkey, nodes1, conns1, nodes2, conns2):
|
||||
"""
|
||||
use genome1 and genome2 to generate a new genome
|
||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||
"""
|
||||
randkey1, randkey2 = jax.random.split(randkey, 2)
|
||||
randkeys1 = jax.random.split(randkey1, self.genome.max_nodes)
|
||||
randkeys2 = jax.random.split(randkey2, self.genome.max_conns)
|
||||
|
||||
# crossover nodes
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
# make homologous genes align in nodes2 align with nodes1
|
||||
nodes2 = self.align_array(keys1, keys2, nodes2, is_conn=False)
|
||||
|
||||
# For not homologous genes, use the value of nodes1(winner)
|
||||
# For homologous genes, use the crossover result between nodes1 and nodes2
|
||||
node_attrs1 = vmap(extract_node_attrs)(nodes1)
|
||||
node_attrs2 = vmap(extract_node_attrs)(nodes2)
|
||||
|
||||
new_node_attrs = jnp.where(
|
||||
jnp.isnan(node_attrs1) | jnp.isnan(node_attrs2), # one of them is nan
|
||||
node_attrs1, # not homologous genes or both nan, use the value of nodes1(winner)
|
||||
vmap(self.genome.node_gene.crossover, in_axes=(None, 0, 0, 0))(
|
||||
state, randkeys1, node_attrs1, node_attrs2
|
||||
), # homologous or both nan
|
||||
)
|
||||
new_nodes = vmap(set_node_attrs)(nodes1, new_node_attrs)
|
||||
|
||||
# crossover connections
|
||||
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
|
||||
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
|
||||
|
||||
conns_attrs1 = vmap(extract_conn_attrs)(conns1)
|
||||
conns_attrs2 = vmap(extract_conn_attrs)(conns2)
|
||||
|
||||
new_conn_attrs = jnp.where(
|
||||
jnp.isnan(conns_attrs1) | jnp.isnan(conns_attrs2),
|
||||
conns_attrs1, # not homologous genes or both nan, use the value of conns1(winner)
|
||||
vmap(self.genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))(
|
||||
state, randkeys2, conns_attrs1, conns_attrs2
|
||||
), # homologous or both nan
|
||||
)
|
||||
new_conns = vmap(set_conn_attrs)(conns1, new_conn_attrs)
|
||||
|
||||
return new_nodes, new_conns
|
||||
|
||||
def align_array(self, seq1, seq2, ar2, is_conn: bool):
|
||||
"""
|
||||
After I review this code, I found that it is the most difficult part of the code.
|
||||
Please consider carefully before change it!
|
||||
make ar2 align with ar1.
|
||||
:param seq1:
|
||||
:param seq2:
|
||||
:param ar2:
|
||||
:param is_conn:
|
||||
:return:
|
||||
align means to intersect part of ar2 will be at the same position as ar1,
|
||||
non-intersect part of ar2 will be set to Nan
|
||||
"""
|
||||
seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
|
||||
mask = (seq1 == seq2) & (~jnp.isnan(seq1))
|
||||
|
||||
if is_conn:
|
||||
mask = jnp.all(mask, axis=2)
|
||||
|
||||
intersect_mask = mask.any(axis=1)
|
||||
idx = jnp.arange(0, len(seq1))
|
||||
idx_fixed = jnp.dot(mask, idx)
|
||||
|
||||
refactor_ar2 = jnp.where(
|
||||
intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan
|
||||
)
|
||||
|
||||
return refactor_ar2
|
||||
@@ -0,0 +1,2 @@
|
||||
from .base import BaseDistance
|
||||
from .default import DefaultDistance
|
||||
15
tensorneat/algorithm/neat/genome/operations/distance/base.py
Normal file
15
tensorneat/algorithm/neat/genome/operations/distance/base.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from tensorneat.common import StatefulBaseClass, State
|
||||
|
||||
|
||||
class BaseDistance(StatefulBaseClass):
|
||||
|
||||
def setup(self, state=State(), genome = None):
|
||||
assert genome is not None, "genome should not be None"
|
||||
self.genome = genome
|
||||
return state
|
||||
|
||||
def __call__(self, state, nodes1, nodes2, conns1, conns2):
|
||||
"""
|
||||
The distance between two genomes
|
||||
"""
|
||||
raise NotImplementedError
|
||||
105
tensorneat/algorithm/neat/genome/operations/distance/default.py
Normal file
105
tensorneat/algorithm/neat/genome/operations/distance/default.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from jax import vmap, numpy as jnp
|
||||
|
||||
from .base import BaseDistance
|
||||
from ...utils import extract_node_attrs, extract_conn_attrs
|
||||
|
||||
|
||||
class DefaultDistance(BaseDistance):
|
||||
def __init__(
|
||||
self,
|
||||
compatibility_disjoint: float = 1.0,
|
||||
compatibility_weight: float = 0.4,
|
||||
):
|
||||
self.compatibility_disjoint = compatibility_disjoint
|
||||
self.compatibility_weight = compatibility_weight
|
||||
|
||||
def __call__(self, state, nodes1, nodes2, conns1, conns2):
|
||||
"""
|
||||
The distance between two genomes
|
||||
"""
|
||||
d = self.node_distance(state, nodes1, nodes2) + self.conn_distance(
|
||||
state, conns1, conns2
|
||||
)
|
||||
return d
|
||||
|
||||
def node_distance(self, state, nodes1, nodes2):
|
||||
"""
|
||||
The distance of the nodes part for two genomes
|
||||
"""
|
||||
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
|
||||
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
|
||||
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
|
||||
|
||||
# align homologous nodes
|
||||
# this process is similar to np.intersect1d.
|
||||
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
|
||||
keys = nodes[:, 0]
|
||||
sorted_indices = jnp.argsort(keys, axis=0)
|
||||
nodes = nodes[sorted_indices]
|
||||
nodes = jnp.concatenate(
|
||||
[nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0
|
||||
) # add a nan row to the end
|
||||
fr, sr = nodes[:-1], nodes[1:] # first row, second row
|
||||
|
||||
# flag location of homologous nodes
|
||||
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
|
||||
|
||||
# calculate the count of non_homologous of two genomes
|
||||
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
|
||||
# calculate the distance of homologous nodes
|
||||
fr_attrs = vmap(extract_node_attrs)(fr)
|
||||
sr_attrs = vmap(extract_node_attrs)(sr)
|
||||
hnd = vmap(self.genome.node_gene.distance, in_axes=(None, 0, 0))(
|
||||
state, fr_attrs, sr_attrs
|
||||
) # homologous node distance
|
||||
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
|
||||
homologous_distance = jnp.sum(hnd * intersect_mask)
|
||||
|
||||
val = (
|
||||
non_homologous_cnt * self.compatibility_disjoint
|
||||
+ homologous_distance * self.compatibility_weight
|
||||
)
|
||||
|
||||
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
|
||||
|
||||
return val
|
||||
|
||||
def conn_distance(self, state, conns1, conns2):
|
||||
"""
|
||||
The distance of the conns part for two genomes
|
||||
"""
|
||||
con_cnt1 = jnp.sum(~jnp.isnan(conns1[:, 0]))
|
||||
con_cnt2 = jnp.sum(~jnp.isnan(conns2[:, 0]))
|
||||
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
|
||||
|
||||
cons = jnp.concatenate((conns1, conns2), axis=0)
|
||||
keys = cons[:, :2]
|
||||
sorted_indices = jnp.lexsort(keys.T[::-1])
|
||||
cons = cons[sorted_indices]
|
||||
cons = jnp.concatenate(
|
||||
[cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0
|
||||
) # add a nan row to the end
|
||||
fr, sr = cons[:-1], cons[1:] # first row, second row
|
||||
|
||||
# both genome has such connection
|
||||
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
|
||||
|
||||
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
|
||||
fr_attrs = vmap(extract_conn_attrs)(fr)
|
||||
sr_attrs = vmap(extract_conn_attrs)(sr)
|
||||
hcd = vmap(self.genome.conn_gene.distance, in_axes=(None, 0, 0))(
|
||||
state, fr_attrs, sr_attrs
|
||||
) # homologous connection distance
|
||||
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
|
||||
homologous_distance = jnp.sum(hcd * intersect_mask)
|
||||
|
||||
val = (
|
||||
non_homologous_cnt * self.compatibility_disjoint
|
||||
+ homologous_distance * self.compatibility_weight
|
||||
)
|
||||
|
||||
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
|
||||
|
||||
return val
|
||||
@@ -0,0 +1,2 @@
|
||||
from .base import BaseMutation
|
||||
from .default import DefaultMutation
|
||||
12
tensorneat/algorithm/neat/genome/operations/mutation/base.py
Normal file
12
tensorneat/algorithm/neat/genome/operations/mutation/base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from tensorneat.common import StatefulBaseClass, State
|
||||
|
||||
|
||||
class BaseMutation(StatefulBaseClass):
|
||||
|
||||
def setup(self, state=State(), genome = None):
|
||||
assert genome is not None, "genome should not be None"
|
||||
self.genome = genome
|
||||
return state
|
||||
|
||||
def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
|
||||
raise NotImplementedError
|
||||
292
tensorneat/algorithm/neat/genome/operations/mutation/default.py
Normal file
292
tensorneat/algorithm/neat/genome/operations/mutation/default.py
Normal file
@@ -0,0 +1,292 @@
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
from . import BaseMutation
|
||||
from tensorneat.common import (
|
||||
fetch_first,
|
||||
fetch_random,
|
||||
I_INF,
|
||||
check_cycles,
|
||||
)
|
||||
from ...utils import (
|
||||
unflatten_conns,
|
||||
add_node,
|
||||
add_conn,
|
||||
delete_node_by_pos,
|
||||
delete_conn_by_pos,
|
||||
extract_node_attrs,
|
||||
extract_conn_attrs,
|
||||
set_node_attrs,
|
||||
set_conn_attrs,
|
||||
)
|
||||
|
||||
|
||||
class DefaultMutation(BaseMutation):
|
||||
def __init__(
|
||||
self,
|
||||
conn_add: float = 0.2,
|
||||
conn_delete: float = 0,
|
||||
node_add: float = 0.2,
|
||||
node_delete: float = 0,
|
||||
):
|
||||
self.conn_add = conn_add
|
||||
self.conn_delete = conn_delete
|
||||
self.node_add = node_add
|
||||
self.node_delete = node_delete
|
||||
|
||||
def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
|
||||
k1, k2 = jax.random.split(randkey)
|
||||
|
||||
nodes, conns = self.mutate_structure(
|
||||
state, k1, genome, nodes, conns, new_node_key
|
||||
)
|
||||
nodes, conns = self.mutate_values(state, k2, genome, nodes, conns)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key):
|
||||
def mutate_add_node(key_, nodes_, conns_):
|
||||
"""
|
||||
add a node while do not influence the output of the network
|
||||
"""
|
||||
|
||||
remain_node_space = jnp.isnan(nodes_[:, 0]).sum()
|
||||
remain_conn_space = jnp.isnan(conns_[:, 0]).sum()
|
||||
i_key, o_key, idx = self.choose_connection_key(
|
||||
key_, conns_
|
||||
) # choose a connection
|
||||
|
||||
def successful_add_node():
|
||||
# remove the original connection and record its attrs
|
||||
original_attrs = extract_conn_attrs(conns_[idx])
|
||||
new_conns = delete_conn_by_pos(conns_, idx)
|
||||
|
||||
# add a new node with identity attrs
|
||||
new_nodes = add_node(
|
||||
nodes_, new_node_key, genome.node_gene.new_identity_attrs(state)
|
||||
)
|
||||
|
||||
# add two new connections
|
||||
# first is with identity attrs
|
||||
new_conns = add_conn(
|
||||
new_conns,
|
||||
i_key,
|
||||
new_node_key,
|
||||
genome.conn_gene.new_identity_attrs(state),
|
||||
)
|
||||
# second is with the origin attrs
|
||||
new_conns = add_conn(
|
||||
new_conns,
|
||||
new_node_key,
|
||||
o_key,
|
||||
original_attrs,
|
||||
)
|
||||
|
||||
return new_nodes, new_conns
|
||||
|
||||
return jax.lax.cond(
|
||||
(idx == I_INF) | (remain_node_space < 1) | (remain_conn_space < 2),
|
||||
lambda: (nodes_, conns_), # do nothing
|
||||
successful_add_node,
|
||||
)
|
||||
|
||||
def mutate_delete_node(key_, nodes_, conns_):
|
||||
"""
|
||||
delete a node
|
||||
"""
|
||||
# randomly choose a node
|
||||
key, idx = self.choose_node_key(
|
||||
key_,
|
||||
nodes_,
|
||||
genome.input_idx,
|
||||
genome.output_idx,
|
||||
allow_input_keys=False,
|
||||
allow_output_keys=False,
|
||||
)
|
||||
|
||||
def successful_delete_node():
|
||||
# delete the node
|
||||
new_nodes = delete_node_by_pos(nodes_, idx)
|
||||
|
||||
# delete all connections
|
||||
new_conns = jnp.where(
|
||||
((conns_[:, 0] == key) | (conns_[:, 1] == key))[:, None],
|
||||
jnp.nan,
|
||||
conns_,
|
||||
)
|
||||
|
||||
return new_nodes, new_conns
|
||||
|
||||
return jax.lax.cond(
|
||||
idx == I_INF, # no available node to delete
|
||||
lambda: (nodes_, conns_), # do nothing
|
||||
successful_delete_node,
|
||||
)
|
||||
|
||||
def mutate_add_conn(key_, nodes_, conns_):
|
||||
"""
|
||||
add a connection while do not influence the output of the network
|
||||
"""
|
||||
|
||||
remain_conn_space = jnp.isnan(conns_[:, 0]).sum()
|
||||
|
||||
# randomly choose two nodes
|
||||
k1_, k2_ = jax.random.split(key_, num=2)
|
||||
|
||||
# input node of the connection can be any node
|
||||
i_key, from_idx = self.choose_node_key(
|
||||
k1_,
|
||||
nodes_,
|
||||
genome.input_idx,
|
||||
genome.output_idx,
|
||||
allow_input_keys=True,
|
||||
allow_output_keys=True,
|
||||
)
|
||||
|
||||
# output node of the connection can be any node except input node
|
||||
o_key, to_idx = self.choose_node_key(
|
||||
k2_,
|
||||
nodes_,
|
||||
genome.input_idx,
|
||||
genome.output_idx,
|
||||
allow_input_keys=False,
|
||||
allow_output_keys=True,
|
||||
)
|
||||
|
||||
conn_pos = fetch_first((conns_[:, 0] == i_key) & (conns_[:, 1] == o_key))
|
||||
is_already_exist = conn_pos != I_INF
|
||||
|
||||
def nothing():
|
||||
return nodes_, conns_
|
||||
|
||||
def successful():
|
||||
# add a connection with zero attrs
|
||||
return nodes_, add_conn(
|
||||
conns_, i_key, o_key, genome.conn_gene.new_zero_attrs(state)
|
||||
)
|
||||
|
||||
if genome.network_type == "feedforward":
|
||||
u_conns = unflatten_conns(nodes_, conns_)
|
||||
conns_exist = u_conns != I_INF
|
||||
is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx)
|
||||
|
||||
return jax.lax.cond(
|
||||
is_already_exist | is_cycle | (remain_conn_space < 1),
|
||||
nothing,
|
||||
successful,
|
||||
)
|
||||
|
||||
elif genome.network_type == "recurrent":
|
||||
return jax.lax.cond(
|
||||
is_already_exist | (remain_conn_space < 1),
|
||||
nothing,
|
||||
successful,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid network type: {genome.network_type}")
|
||||
|
||||
def mutate_delete_conn(key_, nodes_, conns_):
|
||||
# randomly choose a connection
|
||||
i_key, o_key, idx = self.choose_connection_key(key_, conns_)
|
||||
|
||||
return jax.lax.cond(
|
||||
idx == I_INF,
|
||||
lambda: (nodes_, conns_), # nothing
|
||||
lambda: (nodes_, delete_conn_by_pos(conns_, idx)), # success
|
||||
)
|
||||
|
||||
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
|
||||
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
|
||||
|
||||
def nothing(_, nodes_, conns_):
|
||||
return nodes_, conns_
|
||||
|
||||
if self.node_add > 0:
|
||||
nodes, conns = jax.lax.cond(
|
||||
r1 < self.node_add, mutate_add_node, nothing, k1, nodes, conns
|
||||
)
|
||||
|
||||
if self.node_delete > 0:
|
||||
nodes, conns = jax.lax.cond(
|
||||
r2 < self.node_delete, mutate_delete_node, nothing, k2, nodes, conns
|
||||
)
|
||||
|
||||
if self.conn_add > 0:
|
||||
nodes, conns = jax.lax.cond(
|
||||
r3 < self.conn_add, mutate_add_conn, nothing, k3, nodes, conns
|
||||
)
|
||||
|
||||
if self.conn_delete > 0:
|
||||
nodes, conns = jax.lax.cond(
|
||||
r4 < self.conn_delete, mutate_delete_conn, nothing, k4, nodes, conns
|
||||
)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
def mutate_values(self, state, randkey, genome, nodes, conns):
|
||||
k1, k2 = jax.random.split(randkey)
|
||||
nodes_randkeys = jax.random.split(k1, num=genome.max_nodes)
|
||||
conns_randkeys = jax.random.split(k2, num=genome.max_conns)
|
||||
|
||||
node_attrs = vmap(extract_node_attrs)(nodes)
|
||||
new_node_attrs = vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))(
|
||||
state, nodes_randkeys, node_attrs
|
||||
)
|
||||
new_nodes = vmap(set_node_attrs)(nodes, new_node_attrs)
|
||||
|
||||
conn_attrs = vmap(extract_conn_attrs)(conns)
|
||||
new_conn_attrs = vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))(
|
||||
state, conns_randkeys, conn_attrs
|
||||
)
|
||||
new_conns = vmap(set_conn_attrs)(conns, new_conn_attrs)
|
||||
|
||||
# nan nodes not changed
|
||||
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
|
||||
new_conns = jnp.where(jnp.isnan(conns), jnp.nan, new_conns)
|
||||
|
||||
return new_nodes, new_conns
|
||||
|
||||
def choose_node_key(
|
||||
self,
|
||||
key,
|
||||
nodes,
|
||||
input_idx,
|
||||
output_idx,
|
||||
allow_input_keys: bool = False,
|
||||
allow_output_keys: bool = False,
|
||||
):
|
||||
"""
|
||||
Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node.
|
||||
:param key:
|
||||
:param nodes:
|
||||
:param input_idx:
|
||||
:param output_idx:
|
||||
:param allow_input_keys:
|
||||
:param allow_output_keys:
|
||||
:return: return its key and position(idx)
|
||||
"""
|
||||
|
||||
node_keys = nodes[:, 0]
|
||||
mask = ~jnp.isnan(node_keys)
|
||||
|
||||
if not allow_input_keys:
|
||||
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_idx))
|
||||
|
||||
if not allow_output_keys:
|
||||
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_idx))
|
||||
|
||||
idx = fetch_random(key, mask)
|
||||
key = jnp.where(idx != I_INF, nodes[idx, 0], jnp.nan)
|
||||
return key, idx
|
||||
|
||||
def choose_connection_key(self, key, conns):
|
||||
"""
|
||||
Randomly choose a connection key from the given connections.
|
||||
:return: i_key, o_key, idx
|
||||
"""
|
||||
|
||||
idx = fetch_random(key, ~jnp.isnan(conns[:, 0]))
|
||||
i_key = jnp.where(idx != I_INF, conns[idx, 0], jnp.nan)
|
||||
o_key = jnp.where(idx != I_INF, conns[idx, 1], jnp.nan)
|
||||
|
||||
return i_key, o_key, idx
|
||||
Reference in New Issue
Block a user