1. Add origin_node and origin_conn.
2. Change the behavior of crossover and mutation. TensorNEAT now uses all fixed_attrs (including the historical marker, if the gene has one) as the identifier of a gene in crossover and distance calculation.
3. Other slight changes.
4. Add two related examples: xor_origin and hopper_origin
5. Add related test file.
This commit is contained in:
wls2002
2024-12-18 16:20:34 +08:00
parent e9a8110af5
commit ee1a2a8271
18 changed files with 667 additions and 204 deletions

View File

@@ -0,0 +1,49 @@
from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import DefaultGenome, OriginNode, OriginConn
from tensorneat.problem.rl import BraxEnv
from tensorneat.common import ACT, AGG
"""
Solving Hopper with OriginGene
See https://github.com/EMI-Group/tensorneat/issues/11
"""
if __name__ == "__main__":
    # Node gene whose crossover follows the original NEAT paper.
    # Lower and upper response bounds are both 1, which fixes response to 1.
    hopper_node_gene = OriginNode(
        activation_options=ACT.tanh,
        aggregation_options=AGG.sum,
        response_lower_bound=1,
        response_upper_bound=1,
    )

    # Connection gene that carries a historical marker.
    hopper_conn_gene = OriginConn()

    # Genome: 11 inputs, 3 outputs, no initial hidden layers.
    hopper_genome = DefaultGenome(
        num_inputs=11,
        num_outputs=3,
        init_hidden_layers=(),
        node_gene=hopper_node_gene,
        conn_gene=hopper_conn_gene,
        output_transform=ACT.tanh,
    )

    neat = NEAT(
        pop_size=1000,
        species_size=20,
        survival_threshold=0.1,
        compatibility_threshold=1.0,
        genome=hopper_genome,
    )

    hopper_env = BraxEnv(
        env_name="hopper",
        max_step=1000,
    )

    pipeline = Pipeline(
        algorithm=neat,
        problem=hopper_env,
        seed=42,
        generation_limit=100,
        fitness_target=5000,
    )

    # Build the initial state, then evolve until the pipeline terminates.
    state = pipeline.setup()
    state, best = pipeline.auto_run(state)

View File

@@ -30,7 +30,8 @@ pipeline.show(state, best)
# visualize the best individual # visualize the best individual
network = algorithm.genome.network_dict(state, *best) network = algorithm.genome.network_dict(state, *best)
algorithm.genome.visualize(network, save_path="./imgs/xor_network.svg") print(algorithm.genome.repr(state, *best))
# algorithm.genome.visualize(network, save_path="./imgs/xor_network.svg")
# transform the best individual to latex formula # transform the best individual to latex formula
from tensorneat.common.sympy_tools import to_latex_code, to_python_code from tensorneat.common.sympy_tools import to_latex_code, to_python_code

View File

@@ -0,0 +1,55 @@
from tensorneat.pipeline import Pipeline
from tensorneat import algorithm, genome, problem
from tensorneat.genome import OriginNode, OriginConn
from tensorneat.common import ACT
"""
Solving XOR-3d problem with OriginGene
See https://github.com/EMI-Group/tensorneat/issues/11
"""
# Genome built from the "origin" node/connection genes: 3 inputs, 1 output.
xor_genome = genome.DefaultGenome(
    node_gene=OriginNode(),
    conn_gene=OriginConn(),
    num_inputs=3,
    num_outputs=1,
    max_nodes=7,
    output_transform=ACT.sigmoid,
)

# NOTE: from here on the names `algorithm` and `problem` refer to the created
# objects and no longer to the imported tensorneat modules.
algorithm = algorithm.NEAT(
    pop_size=10000,
    species_size=20,
    survival_threshold=0.01,
    genome=xor_genome,
)
problem = problem.XOR3d()

pipeline = Pipeline(
    algorithm,
    problem,
    generation_limit=200,
    fitness_target=-1e-6,
    seed=42,
)

# Set up the pipeline and run it until it terminates.
state = pipeline.setup()
state, best = pipeline.auto_run(state)

# Report the result.
pipeline.show(state, best)

# Inspect the best individual: build its network dict and print its
# textual representation.
evolved_genome = algorithm.genome
network = evolved_genome.network_dict(state, *best)
print(evolved_genome.repr(state, *best))
# Optional SVG rendering of the network (disabled):
# algorithm.genome.visualize(network, save_path="./imgs/xor_network.svg")

# Convert the best individual into a LaTeX formula.
from tensorneat.common.sympy_tools import to_latex_code, to_python_code

sigmoid_as_sympy = ACT.obtain_sympy(ACT.sigmoid)
sympy_res = evolved_genome.sympy_func(
    state, network, sympy_output_transform=sigmoid_as_sympy
)
latex_code = to_latex_code(*sympy_res)
print(latex_code)

# Convert the best individual into Python code.
python_code = to_python_code(*sympy_res)
print(python_code)

View File

@@ -2,7 +2,7 @@ from jax import vmap
import numpy as np import numpy as np
from .base import BaseSubstrate from .base import BaseSubstrate
from tensorneat.genome.utils import set_conn_attrs from tensorneat.genome.utils import set_gene_attrs
class DefaultSubstrate(BaseSubstrate): class DefaultSubstrate(BaseSubstrate):
@@ -21,7 +21,8 @@ class DefaultSubstrate(BaseSubstrate):
def make_conns(self, query_res): def make_conns(self, query_res):
# change weight of conns # change weight of conns
return vmap(set_conn_attrs)(self.conns, query_res) # the last column is the weight
return self.conns.at[:, -1].set(query_res)
@property @property
def query_coors(self): def query_coors(self):

View File

@@ -116,6 +116,25 @@ class NEAT(BaseAlgorithm):
next_node_key = max_node_key + 1 next_node_key = max_node_key + 1
new_node_keys = jnp.arange(self.pop_size) + next_node_key new_node_keys = jnp.arange(self.pop_size) + next_node_key
# find next conn historical markers for mutation if needed
if "historical_marker" in self.genome.conn_gene.fixed_attrs:
all_conns_markers = vmap(
self.genome.conn_gene.get_historical_marker, in_axes=(None, 0)
)(state, state.pop_conns)
max_conn_markers = jnp.max(
all_conns_markers, where=~jnp.isnan(all_conns_markers), initial=0
)
next_conn_markers = max_conn_markers + 1
new_conn_markers = (
jnp.arange(self.pop_size * 3).reshape(self.pop_size, 3)
+ next_conn_markers
)
else:
# no need to generate new conn historical markers
# use 0
new_conn_markers = jnp.full((self.pop_size, 3), 0)
# prepare random keys # prepare random keys
k1, k2, randkey = jax.random.split(state.randkey, 3) k1, k2, randkey = jax.random.split(state.randkey, 3)
crossover_randkeys = jax.random.split(k1, self.pop_size) crossover_randkeys = jax.random.split(k1, self.pop_size)
@@ -133,9 +152,9 @@ class NEAT(BaseAlgorithm):
# batch mutation # batch mutation
m_n_nodes, m_n_conns = vmap( m_n_nodes, m_n_conns = vmap(
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0) self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0, 0)
)( )(
state, mutate_randkeys, n_nodes, n_conns, new_node_keys state, mutate_randkeys, n_nodes, n_conns, new_node_keys, new_conn_markers
) # mutated_new_nodes, mutated_new_conns ) # mutated_new_nodes, mutated_new_conns
# elitism don't mutate # elitism don't mutate

View File

@@ -113,8 +113,12 @@ class BaseGenome(StatefulBaseClass):
def visualize(self): def visualize(self):
raise NotImplementedError raise NotImplementedError
def execute_mutation(self, state, randkey, nodes, conns, new_node_key): def execute_mutation(
return self.mutation(state, self, randkey, nodes, conns, new_node_key) self, state, randkey, nodes, conns, new_node_key, new_conn_keys
):
return self.mutation(
state, self, randkey, nodes, conns, new_node_key, new_conn_keys
)
def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2): def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2):
return self.crossover(state, self, randkey, nodes1, conns1, nodes2, conns2) return self.crossover(state, self, randkey, nodes1, conns1, nodes2, conns2)
@@ -144,19 +148,31 @@ class BaseGenome(StatefulBaseClass):
conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan) conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan)
# create input and output indices # create input and output indices
conn_indices = self.all_init_conns conn_indices = self.all_init_conns
# create connection initial history markers
conn_markers = jnp.arange(all_conns_cnt)
# create conn attrs # create conn attrs
rand_keys_c = jax.random.split(k2, num=all_conns_cnt) rand_keys_c = jax.random.split(k2, num=all_conns_cnt)
conns_attr_func = jax.vmap( conns_attrs = jax.vmap(
self.conn_gene.new_random_attrs, self.conn_gene.new_random_attrs,
in_axes=( in_axes=(
None, None,
0, 0,
), ),
) )(state, rand_keys_c)
conns_attrs = conns_attr_func(state, rand_keys_c)
conns = conns.at[:all_conns_cnt, :2].set(conn_indices) # set conn indices # set conn indices
conns = conns.at[:all_conns_cnt, 2:].set(conns_attrs) # set conn attrs conns = conns.at[:all_conns_cnt, :2].set(conn_indices)
# set conn history markers if needed
if "historical_marker" in self.conn_gene.fixed_attrs:
conns = conns.at[:all_conns_cnt, 2].set(conn_markers)
# set conn attrs
conns = conns.at[:all_conns_cnt, len(self.conn_gene.fixed_attrs) :].set(
conns_attrs
)
return nodes, conns return nodes, conns

View File

@@ -8,7 +8,7 @@ import sympy as sp
from .base import BaseGenome from .base import BaseGenome
from .gene import DefaultNode, DefaultConn from .gene import DefaultNode, DefaultConn
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs from .utils import unflatten_conns, extract_gene_attrs, extract_gene_attrs
from tensorneat.common import ( from tensorneat.common import (
topological_sort, topological_sort,
@@ -16,7 +16,7 @@ from tensorneat.common import (
I_INF, I_INF,
attach_with_inf, attach_with_inf,
ACT, ACT,
AGG AGG,
) )
@@ -73,8 +73,8 @@ class DefaultGenome(BaseGenome):
ini_vals = jnp.full((self.max_nodes,), jnp.nan) ini_vals = jnp.full((self.max_nodes,), jnp.nan)
ini_vals = ini_vals.at[self.input_idx].set(inputs) ini_vals = ini_vals.at[self.input_idx].set(inputs)
nodes_attrs = vmap(extract_node_attrs)(nodes) nodes_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(self.node_gene, nodes)
conns_attrs = vmap(extract_conn_attrs)(conns) conns_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(self.conn_gene, conns)
def cond_fun(carry): def cond_fun(carry):
values, idx = carry values, idx = carry

View File

@@ -1,2 +1,3 @@
from .base import BaseConn from .base import BaseConn
from .default import DefaultConn from .default import DefaultConn
from .origin import OriginConn

View File

@@ -0,0 +1,60 @@
import jax, jax.numpy as jnp
from .default import DefaultConn
class OriginConn(DefaultConn):
    """
    Connection gene modelled on the original NEAT paper.

    Besides the input and output index, every connection carries a
    historical marker as a fixed attribute.
    Details at https://github.com/EMI-Group/tensorneat/issues/11.
    """

    # the historical marker is a fixed attribute, stored right after the
    # input/output indices
    fixed_attrs = ["input_index", "output_index", "historical_marker"]
    custom_attrs = ["weight"]

    def __init__(self, *args, **kwargs):
        # every argument is forwarded unchanged to DefaultConn
        super().__init__(*args, **kwargs)

    def crossover(self, state, randkey, attrs1, attrs2):
        """
        Return either attrs1 or attrs2 as a whole.

        One random number decides for all attributes at once, so attributes
        of the two parents are never mixed. (Drawing one number per attribute,
        i.e. jax.random.normal(randkey, attrs1.shape) > 0, would mix them.)
        """
        take_first = jax.random.normal(randkey) > 0
        return jnp.where(take_first, attrs1, attrs2)

    def get_historical_marker(self, state, gene_array):
        """Read the historical marker, which is element 2 of the gene array."""
        return gene_array[2]

    def repr(self, state, conn, precision=2, idx_width=3, func_width=8):
        """
        Format a connection as a one-line human readable string.

        conn is unpacked into (in index, out index, historical marker, weight).
        func_width is accepted but not used by this method.
        """
        src, dst, marker, raw_weight = conn
        template = "{}(in: {:<{idx_width}}, out: {:<{idx_width}}, historical_marker: {:<{idx_width}}, weight: {:<{float_width}})"
        return template.format(
            self.__class__.__name__,
            int(src),
            int(dst),
            int(marker),
            round(float(raw_weight), precision),
            idx_width=idx_width,
            float_width=precision + 3,
        )

    def to_dict(self, state, conn):
        """Convert a connection into a dict with in/out/historical_marker/weight."""
        int_fields = ("in", "out", "historical_marker")
        as_dict = {name: int(conn[pos]) for pos, name in enumerate(int_fields)}
        as_dict["weight"] = jnp.float32(conn[3])
        return as_dict

View File

@@ -1,3 +1,4 @@
from .base import BaseNode from .base import BaseNode
from .default import DefaultNode from .default import DefaultNode
from .bias import BiasNode from .bias import BiasNode
from .origin import OriginNode

View File

@@ -0,0 +1,27 @@
import jax, jax.numpy as jnp
from .default import DefaultNode
class OriginNode(DefaultNode):
    """
    Node gene modelled on the original NEAT paper.

    It differs from DefaultNode only in how crossover is performed.
    Details at https://github.com/EMI-Group/tensorneat/issues/11.
    """

    def __init__(self, *args, **kwargs):
        # every argument is forwarded unchanged to DefaultNode
        super().__init__(*args, **kwargs)

    def crossover(self, state, randkey, attrs1, attrs2):
        """
        Return either attrs1 or attrs2 as a whole.

        One random number decides for all attributes at once, so attributes
        of the two parents are never mixed. (Drawing one number per attribute,
        i.e. jax.random.normal(randkey, attrs1.shape) > 0, would mix them.)
        """
        take_first = jax.random.normal(randkey) > 0
        return jnp.where(take_first, attrs1, attrs2)

View File

@@ -2,12 +2,10 @@ import jax
from jax import vmap, numpy as jnp from jax import vmap, numpy as jnp
from .base import BaseCrossover from .base import BaseCrossover
from ...utils import ( from ...utils import extract_gene_attrs, set_gene_attrs
extract_node_attrs,
extract_conn_attrs, from tensorneat.common import fetch_first, I_INF
set_node_attrs, from tensorneat.genome.gene import BaseGene
set_conn_attrs,
)
class DefaultCrossover(BaseCrossover): class DefaultCrossover(BaseCrossover):
@@ -17,71 +15,90 @@ class DefaultCrossover(BaseCrossover):
notice that genome1 should have higher fitness than genome2 (genome1 is winner!) notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
""" """
randkey1, randkey2 = jax.random.split(randkey, 2) randkey1, randkey2 = jax.random.split(randkey, 2)
randkeys1 = jax.random.split(randkey1, genome.max_nodes) node_randkeys = jax.random.split(randkey1, genome.max_nodes)
randkeys2 = jax.random.split(randkey2, genome.max_conns) conn_randkeys = jax.random.split(randkey2, genome.max_conns)
batch_create_new_gene = jax.vmap(
create_new_gene, in_axes=(None, 0, None, 0, 0, None, None)
)
# crossover nodes # crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0] node_keys1, node_keys2 = (
# make homologous genes align in nodes2 align with nodes1 nodes1[:, 0 : len(genome.node_gene.fixed_attrs)],
nodes2 = self.align_array(keys1, keys2, nodes2, is_conn=False) nodes2[:, 0 : len(genome.node_gene.fixed_attrs)],
)
# For not homologous genes, use the value of nodes1(winner) node_attrs1 = vmap(extract_gene_attrs, in_axes=(None, 0))(
# For homologous genes, use the crossover result between nodes1 and nodes2 genome.node_gene, nodes1
node_attrs1 = vmap(extract_node_attrs)(nodes1) )
node_attrs2 = vmap(extract_node_attrs)(nodes2) node_attrs2 = vmap(extract_gene_attrs, in_axes=(None, 0))(
genome.node_gene, nodes2
new_node_attrs = jnp.where( )
jnp.isnan(node_attrs1) | jnp.isnan(node_attrs2), # one of them is nan
node_attrs1, # not homologous genes or both nan, use the value of nodes1(winner) new_node_attrs = batch_create_new_gene(
vmap(genome.node_gene.crossover, in_axes=(None, 0, 0, 0))( state,
state, randkeys1, node_attrs1, node_attrs2 node_randkeys,
), # homologous or both nan genome.node_gene,
node_keys1,
node_attrs1,
node_keys2,
node_attrs2,
)
new_nodes = vmap(set_gene_attrs, in_axes=(None, 0, 0))(
genome.node_gene, nodes1, new_node_attrs
) )
new_nodes = vmap(set_node_attrs)(nodes1, new_node_attrs)
# crossover connections # crossover connections
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2] # all fixed_attrs together will use to identify a connection
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True) # if using historical marker, use it
# related to issue: https://github.com/EMI-Group/tensorneat/issues/11
conns_attrs1 = vmap(extract_conn_attrs)(conns1) conn_keys1, conn_keys2 = (
conns_attrs2 = vmap(extract_conn_attrs)(conns2) conns1[:, 0 : len(genome.conn_gene.fixed_attrs)],
conns2[:, 0 : len(genome.conn_gene.fixed_attrs)],
new_conn_attrs = jnp.where( )
jnp.isnan(conns_attrs1) | jnp.isnan(conns_attrs2), conn_attrs1 = vmap(extract_gene_attrs, in_axes=(None, 0))(
conns_attrs1, # not homologous genes or both nan, use the value of conns1(winner) genome.conn_gene, conns1
vmap(genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))( )
state, randkeys2, conns_attrs1, conns_attrs2 conn_attrs2 = vmap(extract_gene_attrs, in_axes=(None, 0))(
), # homologous or both nan genome.conn_gene, conns2
)
new_conn_attrs = batch_create_new_gene(
state,
conn_randkeys,
genome.conn_gene,
conn_keys1,
conn_attrs1,
conn_keys2,
conn_attrs2,
)
new_conns = vmap(set_gene_attrs, in_axes=(None, 0, 0))(
genome.conn_gene, conns1, new_conn_attrs
) )
new_conns = vmap(set_conn_attrs)(conns1, new_conn_attrs)
return new_nodes, new_conns return new_nodes, new_conns
def align_array(self, seq1, seq2, ar2, is_conn: bool):
"""
After I review this code, I found that it is the most difficult part of the code.
Please consider carefully before change it!
make ar2 align with ar1.
:param seq1:
:param seq2:
:param ar2:
:param is_conn:
:return:
align means to intersect part of ar2 will be at the same position as ar1,
non-intersect part of ar2 will be set to Nan
"""
seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
mask = (seq1 == seq2) & (~jnp.isnan(seq1))
if is_conn: def create_new_gene(
mask = jnp.all(mask, axis=2) state,
randkey,
gene: BaseGene,
gene_key,
gene_attrs,
genes_keys,
genes_attrs,
):
# find homologous genes
homologous_idx = fetch_first(jnp.all(gene_key == genes_keys, axis=1))
intersect_mask = mask.any(axis=1) def none(): # no homologous, use winner's gene
idx = jnp.arange(0, len(seq1)) return gene_attrs
idx_fixed = jnp.dot(mask, idx)
refactor_ar2 = jnp.where( def crossover(): # when homologous gene is found, execute crossover
intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan return gene.crossover(state, randkey, gene_attrs, genes_attrs[homologous_idx])
new_attrs = jax.lax.cond(
homologous_idx == I_INF, # homologous gene is not found or current gene is nan
none,
crossover,
) )
return refactor_ar2 return new_attrs

View File

@@ -1,7 +1,8 @@
from jax import vmap, numpy as jnp from jax import vmap, numpy as jnp
from .base import BaseDistance from .base import BaseDistance
from ...utils import extract_node_attrs, extract_conn_attrs from ...gene import BaseGene
from ...utils import extract_gene_attrs
class DefaultDistance(BaseDistance): class DefaultDistance(BaseDistance):
@@ -17,83 +18,47 @@ class DefaultDistance(BaseDistance):
""" """
The distance between two genomes The distance between two genomes
""" """
d = self.node_distance(state, genome, nodes1, nodes2) + self.conn_distance( node_distance = self.gene_distance(state, genome.node_gene, nodes1, nodes2)
state, genome, conns1, conns2 conn_distance = self.gene_distance(state, genome.conn_gene, conns1, conns2)
) return node_distance + conn_distance
return d
def node_distance(self, state, genome, nodes1, nodes2):
def gene_distance(self, state, gene: BaseGene, genes1, genes2):
""" """
The distance of the nodes part for two genomes The distance between to genes
genes1: 2-D jax array with shape
genes2: 2-D jax array with shape
gene1.shape == gene2.shape
""" """
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0])) cnt1 = jnp.sum(~jnp.isnan(genes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0])) cnt2 = jnp.sum(~jnp.isnan(genes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2) max_cnt = jnp.maximum(cnt1, cnt2)
# align homologous nodes # align homologous nodes
# this process is similar to np.intersect1d. # this process is similar to np.intersect1d in higher dimension
nodes = jnp.concatenate((nodes1, nodes2), axis=0) total_genes = jnp.concatenate((genes1, genes2), axis=0)
keys = nodes[:, 0] identifiers = total_genes[:, : len(gene.fixed_attrs)]
sorted_indices = jnp.argsort(keys, axis=0) sorted_identifiers = jnp.lexsort(identifiers.T[::-1])
nodes = nodes[sorted_indices] total_genes = total_genes[sorted_identifiers]
nodes = jnp.concatenate( total_genes = jnp.concatenate(
[nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0 [total_genes, jnp.full((1, total_genes.shape[1]), jnp.nan)], axis=0
) # add a nan row to the end ) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row fr, sr = total_genes[:-1], total_genes[1:] # first row, second row
# flag location of homologous nodes # intersect part of two genes
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0]) intersect_mask = jnp.all(
fr[:, : len(gene.fixed_attrs)] == sr[:, : len(gene.fixed_attrs)], axis=1
) & ~jnp.isnan(fr[:, 0])
# calculate the count of non_homologous of two genomes non_homologous_cnt = cnt1 + cnt2 - 2 * jnp.sum(intersect_mask)
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
# calculate the distance of homologous nodes fr_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(gene, fr)
fr_attrs = vmap(extract_node_attrs)(fr) sr_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(gene, sr)
sr_attrs = vmap(extract_node_attrs)(sr)
hnd = vmap(genome.node_gene.distance, in_axes=(None, 0, 0))( # homologous gene distance
state, fr_attrs, sr_attrs hgd = vmap(gene.distance, in_axes=(None, 0, 0))(state, fr_attrs, sr_attrs)
) # homologous node distance hgd = jnp.where(jnp.isnan(hgd), 0, hgd)
hnd = jnp.where(jnp.isnan(hnd), 0, hnd) homologous_distance = jnp.sum(hgd * intersect_mask)
homologous_distance = jnp.sum(hnd * intersect_mask)
val = (
non_homologous_cnt * self.compatibility_disjoint
+ homologous_distance * self.compatibility_weight
)
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
return val
def conn_distance(self, state, genome, conns1, conns2):
"""
The distance of the conns part for two genomes
"""
con_cnt1 = jnp.sum(~jnp.isnan(conns1[:, 0]))
con_cnt2 = jnp.sum(~jnp.isnan(conns2[:, 0]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
cons = jnp.concatenate((conns1, conns2), axis=0)
keys = cons[:, :2]
sorted_indices = jnp.lexsort(keys.T[::-1])
cons = cons[sorted_indices]
cons = jnp.concatenate(
[cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0
) # add a nan row to the end
fr, sr = cons[:-1], cons[1:] # first row, second row
# both genome has such connection
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
fr_attrs = vmap(extract_conn_attrs)(fr)
sr_attrs = vmap(extract_conn_attrs)(sr)
hcd = vmap(genome.conn_gene.distance, in_axes=(None, 0, 0))(
state, fr_attrs, sr_attrs
) # homologous connection distance
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
homologous_distance = jnp.sum(hcd * intersect_mask)
val = ( val = (
non_homologous_cnt * self.compatibility_disjoint non_homologous_cnt * self.compatibility_disjoint

View File

@@ -3,5 +3,5 @@ from tensorneat.common import StatefulBaseClass, State
class BaseMutation(StatefulBaseClass): class BaseMutation(StatefulBaseClass):
def __call__(self, state, genome, randkey, nodes, conns, new_node_key): def __call__(self, state, genome, randkey, nodes, conns, new_node_key, new_conn_key):
raise NotImplementedError raise NotImplementedError

View File

@@ -13,10 +13,8 @@ from ...utils import (
add_conn, add_conn,
delete_node_by_pos, delete_node_by_pos,
delete_conn_by_pos, delete_conn_by_pos,
extract_node_attrs, extract_gene_attrs,
extract_conn_attrs, set_gene_attrs
set_node_attrs,
set_conn_attrs,
) )
@@ -33,17 +31,28 @@ class DefaultMutation(BaseMutation):
self.node_add = node_add self.node_add = node_add
self.node_delete = node_delete self.node_delete = node_delete
def __call__(self, state, genome, randkey, nodes, conns, new_node_key): def __call__(
self, state, genome, randkey, nodes, conns, new_node_key, new_conn_key
):
assert (
new_node_key.shape == ()
) # scalar, as there is max one new node in each mutation
assert new_conn_key.shape == (
3,
) # there are max 3 new connections (mutate add node + mutate add conn)
k1, k2 = jax.random.split(randkey) k1, k2 = jax.random.split(randkey)
nodes, conns = self.mutate_structure( nodes, conns = self.mutate_structure(
state, genome, k1, nodes, conns, new_node_key state, genome, k1, nodes, conns, new_node_key, new_conn_key
) )
nodes, conns = self.mutate_values(state, genome, k2, nodes, conns) nodes, conns = self.mutate_values(state, genome, k2, nodes, conns)
return nodes, conns return nodes, conns
def mutate_structure(self, state, genome, randkey, nodes, conns, new_node_key): def mutate_structure(
self, state, genome, randkey, nodes, conns, new_node_key, new_conn_key
):
def mutate_add_node(key_, nodes_, conns_): def mutate_add_node(key_, nodes_, conns_):
""" """
add a node while do not influence the output of the network add a node while do not influence the output of the network
@@ -57,27 +66,33 @@ class DefaultMutation(BaseMutation):
def successful_add_node(): def successful_add_node():
# remove the original connection and record its attrs # remove the original connection and record its attrs
original_attrs = extract_conn_attrs(conns_[idx]) original_attrs = extract_gene_attrs(genome.conn_gene, conns_[idx])
new_conns = delete_conn_by_pos(conns_, idx) new_conns = delete_conn_by_pos(conns_, idx)
# add a new node with identity attrs # add a new node with identity attrs
new_nodes = add_node( new_nodes = add_node(
nodes_, new_node_key, genome.node_gene.new_identity_attrs(state) nodes_, jnp.array([new_node_key]), genome.node_gene.new_identity_attrs(state)
) )
# whether to use historical marker in connection
if "historical_marker" in genome.conn_gene.fixed_attrs:
fix_attrs1 = jnp.array([i_key, new_node_key, new_conn_key[0]])
fix_attrs2 = jnp.array([new_node_key, o_key, new_conn_key[1]])
else:
fix_attrs1 = jnp.array([i_key, new_node_key])
fix_attrs2 = jnp.array([new_node_key, o_key])
# add two new connections # add two new connections
# first is with identity attrs # first is with identity attrs
new_conns = add_conn( new_conns = add_conn(
new_conns, new_conns,
i_key, fix_attrs1,
new_node_key,
genome.conn_gene.new_identity_attrs(state), genome.conn_gene.new_identity_attrs(state),
) )
# second is with the origin attrs # second is with the origin attrs
new_conns = add_conn( new_conns = add_conn(
new_conns, new_conns,
new_node_key, fix_attrs2,
o_key,
original_attrs, original_attrs,
) )
@@ -160,8 +175,12 @@ class DefaultMutation(BaseMutation):
def successful(): def successful():
# add a connection with zero attrs # add a connection with zero attrs
if "historical_marker" in genome.conn_gene.fixed_attrs:
new_fix_attrs = jnp.array([i_key, o_key, new_conn_key[2]])
else:
new_fix_attrs = jnp.array([i_key, o_key])
return nodes_, add_conn( return nodes_, add_conn(
conns_, i_key, o_key, genome.conn_gene.new_zero_attrs(state) conns_, new_fix_attrs, genome.conn_gene.new_zero_attrs(state)
) )
if genome.network_type == "feedforward": if genome.network_type == "feedforward":
@@ -228,17 +247,25 @@ class DefaultMutation(BaseMutation):
nodes_randkeys = jax.random.split(k1, num=genome.max_nodes) nodes_randkeys = jax.random.split(k1, num=genome.max_nodes)
conns_randkeys = jax.random.split(k2, num=genome.max_conns) conns_randkeys = jax.random.split(k2, num=genome.max_conns)
node_attrs = vmap(extract_node_attrs)(nodes) node_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(
genome.node_gene, nodes
)
new_node_attrs = vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))( new_node_attrs = vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))(
state, nodes_randkeys, node_attrs state, nodes_randkeys, node_attrs
) )
new_nodes = vmap(set_node_attrs)(nodes, new_node_attrs) new_nodes = vmap(set_gene_attrs, in_axes=(None, 0, 0))(
genome.node_gene, nodes, new_node_attrs
)
conn_attrs = vmap(extract_conn_attrs)(conns) conn_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(
genome.conn_gene, conns
)
new_conn_attrs = vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))( new_conn_attrs = vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))(
state, conns_randkeys, conn_attrs state, conns_randkeys, conn_attrs
) )
new_conns = vmap(set_conn_attrs)(conns, new_conn_attrs) new_conns = vmap(set_gene_attrs, in_axes=(None, 0, 0))(
genome.conn_gene, conns, new_conn_attrs
)
# nan nodes not changed # nan nodes not changed
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes) new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)

View File

@@ -5,7 +5,7 @@ from .utils import unflatten_conns
from .base import BaseGenome from .base import BaseGenome
from .gene import DefaultNode, DefaultConn from .gene import DefaultNode, DefaultConn
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs from .utils import unflatten_conns, extract_gene_attrs, extract_gene_attrs
from tensorneat.common import attach_with_inf from tensorneat.common import attach_with_inf
@@ -55,8 +55,8 @@ class RecurrentGenome(BaseGenome):
vals = jnp.full((self.max_nodes,), jnp.nan) vals = jnp.full((self.max_nodes,), jnp.nan)
nodes_attrs = vmap(extract_node_attrs)(nodes) nodes_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(self.node_gene, nodes)
conns_attrs = vmap(extract_conn_attrs)(conns) conns_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(self.conn_gene, conns)
expand_conns_attrs = attach_with_inf(conns_attrs, u_conns) expand_conns_attrs = attach_with_inf(conns_attrs, u_conns)
def body_func(_, values): def body_func(_, values):

View File

@@ -2,6 +2,7 @@ import jax
from jax import vmap, numpy as jnp from jax import vmap, numpy as jnp
import numpy as np import numpy as np
from .gene import BaseGene
from tensorneat.common import fetch_first, I_INF from tensorneat.common import fetch_first, I_INF
@@ -38,49 +39,27 @@ def valid_cnt(nodes_or_conns):
return jnp.sum(~jnp.isnan(nodes_or_conns[:, 0])) return jnp.sum(~jnp.isnan(nodes_or_conns[:, 0]))
def extract_node_attrs(node): def extract_gene_attrs(gene: BaseGene, gene_array):
""" """
node: Array(NL, ) extract the custom attributes of the gene
extract the attributes of a node
""" """
return node[1:] # 0 is for idx return gene_array[len(gene.fixed_attrs) :]
def set_node_attrs(node, attrs): def set_gene_attrs(gene: BaseGene, gene_array, attrs):
""" """
node: Array(NL, ) set the custom attributes of the gene
attrs: Array(NL-1, )
set the attributes of a node
""" """
return node.at[1:].set(attrs) # 0 is for idx return gene_array.at[len(gene.fixed_attrs) :].set(attrs)
def extract_conn_attrs(conn): def add_node(nodes, fix_attrs, custom_attrs):
"""
conn: Array(CL, )
extract the attributes of a connection
"""
return conn[2:] # 0, 1 is for in-idx and out-idx
def set_conn_attrs(conn, attrs):
"""
conn: Array(CL, )
attrs: Array(CL-2, )
set the attributes of a connection
"""
return conn.at[2:].set(attrs) # 0, 1 is for in-idx and out-idx
def add_node(nodes, new_key: int, attrs):
""" """
Add a new node to the genome. Add a new node to the genome.
The new node will place at the first NaN row. The new node will place at the first NaN row.
""" """
exist_keys = nodes[:, 0] pos = fetch_first(jnp.isnan(nodes[:, 0]))
pos = fetch_first(jnp.isnan(exist_keys)) return nodes.at[pos].set(jnp.concatenate((fix_attrs, custom_attrs)))
new_nodes = nodes.at[pos, 0].set(new_key)
return new_nodes.at[pos, 1:].set(attrs)
def delete_node_by_pos(nodes, pos): def delete_node_by_pos(nodes, pos):
@@ -91,15 +70,13 @@ def delete_node_by_pos(nodes, pos):
return nodes.at[pos].set(jnp.nan) return nodes.at[pos].set(jnp.nan)
def add_conn(conns, i_key, o_key, attrs): def add_conn(conns, fix_attrs, custom_attrs):
""" """
Add a new connection to the genome. Add a new connection to the genome.
The new connection will place at the first NaN row. The new connection will place at the first NaN row.
""" """
con_keys = conns[:, 0] pos = fetch_first(jnp.isnan(conns[:, 0]))
pos = fetch_first(jnp.isnan(con_keys)) return conns.at[pos].set(jnp.concatenate((fix_attrs, custom_attrs)))
new_conns = conns.at[pos, 0:2].set(jnp.array([i_key, o_key]))
return new_conns.at[pos, 2:].set(attrs)
def delete_conn_by_pos(conns, pos): def delete_conn_by_pos(conns, pos):

View File

@@ -0,0 +1,247 @@
import jax, jax.numpy as jnp
from tensorneat.genome.operations import (
DefaultMutation,
DefaultDistance,
DefaultCrossover,
)
from tensorneat.genome import (
DefaultGenome,
DefaultNode,
DefaultConn,
OriginNode,
OriginConn,
)
from tensorneat.genome.utils import add_node, add_conn
def _build_test_genome(node_gene, conn_gene):
    """Create a small 3-input / 1-output genome around the given gene types.

    The mutation always adds a node and a connection (probability 1) and
    never deletes anything (probability 0).
    """
    return DefaultGenome(
        node_gene=node_gene,
        conn_gene=conn_gene,
        mutation=DefaultMutation(
            conn_add=1, node_add=1, conn_delete=0, node_delete=0
        ),
        crossover=DefaultCrossover(),
        distance=DefaultDistance(),
        num_inputs=3,
        num_outputs=1,
        max_nodes=6,
        max_conns=6,
    )


# genome using the origin genes (connections carry a historical marker)
origin_genome = _build_test_genome(OriginNode(response_init_std=1), OriginConn())

# genome using the default genes
default_genome = _build_test_genome(DefaultNode(response_init_std=1), DefaultConn())

# one state shared by both genomes: created by the default genome,
# then passed through the origin genome's setup
state = default_genome.setup()
state = origin_genome.setup(state)

randkey = jax.random.PRNGKey(42)
def mutation_default():
    """Print a default genome before and after one mutation step."""
    init_nodes, init_conns = default_genome.initialize(state, randkey)
    print("old genome:\n", default_genome.repr(state, init_nodes, init_conns))

    # new_conn_keys is still passed, but the default genome does not use it.
    mutated = default_genome.execute_mutation(
        state,
        randkey,
        init_nodes,
        init_conns,
        new_node_key=jnp.asarray(10),
        new_conn_keys=jnp.array([20, 21, 22]),
    )
    mutated_nodes, mutated_conns = mutated
    print("new genome:\n", default_genome.repr(state, mutated_nodes, mutated_conns))
def mutation_origin():
    """Print an origin genome (raw conns and repr) before and after one mutation step."""
    init_nodes, init_conns = origin_genome.initialize(state, randkey)
    print(init_conns)
    print("old genome:\n", origin_genome.repr(state, init_nodes, init_conns))

    # Unlike the default genome, the origin genome does use new_conn_keys.
    mutated = origin_genome.execute_mutation(
        state,
        randkey,
        init_nodes,
        init_conns,
        new_node_key=jnp.asarray(10),
        new_conn_keys=jnp.array([20, 21, 22]),
    )
    mutated_nodes, mutated_conns = mutated
    print(mutated_conns)
    print("new genome:\n", origin_genome.repr(state, mutated_nodes, mutated_conns))
def distance_default():
    """
    Print the distance between two default genomes that share their nodes
    and differ only in the attributes of the extra (0, 10) connection.
    """
    base_nodes, base_conns = default_genome.initialize(state, randkey)
    conn_gene = default_genome.conn_gene

    def with_extra_conn(custom_attrs):
        # fix_attrs layout: in-idx, out-idx
        return add_conn(
            base_conns,
            fix_attrs=jnp.array([0, 10]),
            custom_attrs=custom_attrs,
        )

    shared_nodes = add_node(
        base_nodes,
        fix_attrs=jnp.asarray([10]),
        custom_attrs=default_genome.node_gene.new_identity_attrs(state),
    )
    zero_conns = with_extra_conn(conn_gene.new_zero_attrs(state))
    random_conns = with_extra_conn(conn_gene.new_random_attrs(state, randkey))

    print("genome1:\n", default_genome.repr(state, shared_nodes, zero_conns))
    print("genome2:\n", default_genome.repr(state, shared_nodes, random_conns))

    result = default_genome.execute_distance(
        state, shared_nodes, zero_conns, shared_nodes, random_conns
    )
    print("distance: ", result)
def distance_origin_case1():
    """
    Print the distance between two origin genomes whose extra (0, 10)
    connections carry different historical markers (99 vs. 88).
    """
    base_nodes, base_conns = origin_genome.initialize(state, randkey)
    conn_gene = origin_genome.conn_gene

    def with_extra_conn(marker, custom_attrs):
        # fix_attrs layout: in-idx, out-idx, historical mark
        return add_conn(
            base_conns,
            fix_attrs=jnp.array([0, 10, marker]),
            custom_attrs=custom_attrs,
        )

    shared_nodes = add_node(
        base_nodes,
        fix_attrs=jnp.asarray([10]),
        custom_attrs=origin_genome.node_gene.new_identity_attrs(state),
    )
    first_conns = with_extra_conn(99, conn_gene.new_zero_attrs(state))
    second_conns = with_extra_conn(88, conn_gene.new_random_attrs(state, randkey))

    print("genome1:\n", origin_genome.repr(state, shared_nodes, first_conns))
    print("genome2:\n", origin_genome.repr(state, shared_nodes, second_conns))

    result = origin_genome.execute_distance(
        state, shared_nodes, first_conns, shared_nodes, second_conns
    )
    print("distance: ", result)
def distance_origin_case2():
    """
    Print the distance between two origin genomes whose extra (0, 10)
    connections carry the same historical marker (99).
    """
    base_nodes, base_conns = origin_genome.initialize(state, randkey)
    conn_gene = origin_genome.conn_gene

    def with_extra_conn(custom_attrs):
        # fix_attrs layout: in-idx, out-idx, historical mark
        return add_conn(
            base_conns,
            fix_attrs=jnp.array([0, 10, 99]),
            custom_attrs=custom_attrs,
        )

    shared_nodes = add_node(
        base_nodes,
        fix_attrs=jnp.asarray([10]),
        custom_attrs=origin_genome.node_gene.new_identity_attrs(state),
    )
    first_conns = with_extra_conn(conn_gene.new_zero_attrs(state))
    second_conns = with_extra_conn(conn_gene.new_random_attrs(state, randkey))

    print("genome1:\n", origin_genome.repr(state, shared_nodes, first_conns))
    print("genome2:\n", origin_genome.repr(state, shared_nodes, second_conns))

    result = origin_genome.execute_distance(
        state, shared_nodes, first_conns, shared_nodes, second_conns
    )
    print("distance: ", result)
def crossover_origin_case1():
    """
    Print the child of two origin genomes whose extra (0, 10) connections
    carry different historical markers (99 vs. 88).
    """
    base_nodes, base_conns = origin_genome.initialize(state, randkey)
    conn_gene = origin_genome.conn_gene

    def with_extra_conn(marker, custom_attrs):
        # fix_attrs layout: in-idx, out-idx, historical mark
        return add_conn(
            base_conns,
            fix_attrs=jnp.array([0, 10, marker]),
            custom_attrs=custom_attrs,
        )

    shared_nodes = add_node(
        base_nodes,
        fix_attrs=jnp.asarray([10]),
        custom_attrs=origin_genome.node_gene.new_identity_attrs(state),
    )
    first_conns = with_extra_conn(99, conn_gene.new_zero_attrs(state))
    second_conns = with_extra_conn(88, conn_gene.new_random_attrs(state, randkey))

    print("genome1:\n", origin_genome.repr(state, shared_nodes, first_conns))
    print("genome2:\n", origin_genome.repr(state, shared_nodes, second_conns))

    # Expected: the (0, 10) connection in the child has weight 0, because the
    # two genes are disjoint and the gene of the fitter parent is used.
    child = origin_genome.execute_crossover(
        state, randkey, shared_nodes, first_conns, shared_nodes, second_conns
    )
    child_nodes, child_conns = child
    print("child:\n", origin_genome.repr(state, child_nodes, child_conns))
def crossover_origin_case2():
    """
    Print two children of origin genomes whose extra (0, 10) connections
    carry the same historical marker (99), using two different random keys.
    """
    base_nodes, base_conns = origin_genome.initialize(state, randkey)
    conn_gene = origin_genome.conn_gene

    def with_extra_conn(custom_attrs):
        # fix_attrs layout: in-idx, out-idx, historical mark
        return add_conn(
            base_conns,
            fix_attrs=jnp.array([0, 10, 99]),
            custom_attrs=custom_attrs,
        )

    shared_nodes = add_node(
        base_nodes,
        fix_attrs=jnp.asarray([10]),
        custom_attrs=origin_genome.node_gene.new_identity_attrs(state),
    )
    first_conns = with_extra_conn(conn_gene.new_zero_attrs(state))
    second_conns = with_extra_conn(conn_gene.new_random_attrs(state, randkey))

    print("genome1:\n", origin_genome.repr(state, shared_nodes, first_conns))
    print("genome2:\n", origin_genome.repr(state, shared_nodes, second_conns))

    # The (0, 10) genes are homologous, so the child's weight may come from
    # either parent. Seed 99 is meant to show the zero-weight outcome and
    # seed 0 the random-weight outcome.
    # NOTE(review): which seed gives which outcome may depend on the JAX version.
    for label, seed in (("child_zero:\n", 99), ("child_random:\n", 0)):
        child_nodes, child_conns = origin_genome.execute_crossover(
            state,
            jax.random.key(seed),
            shared_nodes,
            first_conns,
            shared_nodes,
            second_conns,
        )
        print(label, origin_genome.repr(state, child_nodes, child_conns))
def crossover_origin_case3():
    """
    Check that crossover takes a whole gene from one randomly chosen parent
    instead of exchanging individual attributes between the parents.

    Both parents hold node 10 with clearly distinguishable attributes
    ([1, 2, ...] vs. [100, 200, ...]), so a mixed result would be visible
    in the printed children.
    """
    base_nodes, shared_conns = origin_genome.initialize(state, randkey)

    def with_extra_node(custom_attrs):
        return add_node(
            base_nodes,
            fix_attrs=jnp.asarray([10]),
            custom_attrs=custom_attrs,
        )

    small_nodes = with_extra_node(jnp.array([1, 2, 0, 0]))
    large_nodes = with_extra_node(jnp.array([100, 200, 0, 0]))

    # Seed 99 is meant to show the [1, 2] outcome and seed 1 the [100, 200] one.
    # NOTE(review): which seed gives which outcome may depend on the JAX version.
    for label, seed in (("child1:\n", 99), ("child2:\n", 1)):
        child_nodes, child_conns = origin_genome.execute_crossover(
            state,
            jax.random.key(seed),
            small_nodes,
            shared_conns,
            large_nodes,
            shared_conns,
        )
        print(label, origin_genome.repr(state, child_nodes, child_conns))
if __name__ == "__main__":
    # Manual driver: a single scenario is executed per run. To inspect a
    # different one, call it here instead. Available scenarios:
    #   mutation_default, mutation_origin,
    #   distance_default, distance_origin_case1, distance_origin_case2,
    #   crossover_origin_case1, crossover_origin_case2, crossover_origin_case3
    crossover_origin_case3()