This commit is related to issue: https://github.com/EMI-Group/tensorneat/issues/11
1. Add origin_node and origin_conn. 2. Change the behavior of crossover and mutation. Now, TensorNEAT will use all fix_attrs(include historical marker if it has one) as identifier for gene in crossover and distance calculation. 3. Other slightly change. 4. Add two related examples: xor_origin and hopper_origin 5. Add related test file.
This commit is contained in:
49
examples/brax/hopper_origin.py
Normal file
49
examples/brax/hopper_origin.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from tensorneat.pipeline import Pipeline
|
||||
from tensorneat.algorithm.neat import NEAT
|
||||
from tensorneat.genome import DefaultGenome, OriginNode, OriginConn
|
||||
|
||||
from tensorneat.problem.rl import BraxEnv
|
||||
from tensorneat.common import ACT, AGG
|
||||
|
||||
"""
|
||||
Solving Hopper with OriginGene
|
||||
See https://github.com/EMI-Group/tensorneat/issues/11
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
pop_size=1000,
|
||||
species_size=20,
|
||||
survival_threshold=0.1,
|
||||
compatibility_threshold=1.0,
|
||||
genome=DefaultGenome(
|
||||
num_inputs=11,
|
||||
num_outputs=3,
|
||||
init_hidden_layers=(),
|
||||
# origin node gene, which use the same crossover behavior to the origin NEAT paper
|
||||
node_gene=OriginNode(
|
||||
activation_options=ACT.tanh,
|
||||
aggregation_options=AGG.sum,
|
||||
response_lower_bound = 1,
|
||||
response_upper_bound= 1, # fix response to 1
|
||||
),
|
||||
# use origin connection, which using historical marker
|
||||
conn_gene=OriginConn(),
|
||||
output_transform=ACT.tanh,
|
||||
),
|
||||
),
|
||||
problem=BraxEnv(
|
||||
env_name="hopper",
|
||||
max_step=1000,
|
||||
),
|
||||
seed=42,
|
||||
generation_limit=100,
|
||||
fitness_target=5000,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
# print(state)
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
@@ -30,7 +30,8 @@ pipeline.show(state, best)
|
||||
|
||||
# visualize the best individual
|
||||
network = algorithm.genome.network_dict(state, *best)
|
||||
algorithm.genome.visualize(network, save_path="./imgs/xor_network.svg")
|
||||
print(algorithm.genome.repr(state, *best))
|
||||
# algorithm.genome.visualize(network, save_path="./imgs/xor_network.svg")
|
||||
|
||||
# transform the best individual to latex formula
|
||||
from tensorneat.common.sympy_tools import to_latex_code, to_python_code
|
||||
|
||||
55
examples/func_fit/xor_origin.py
Normal file
55
examples/func_fit/xor_origin.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from tensorneat.pipeline import Pipeline
|
||||
from tensorneat import algorithm, genome, problem
|
||||
from tensorneat.genome import OriginNode, OriginConn
|
||||
from tensorneat.common import ACT
|
||||
|
||||
"""
|
||||
Solving XOR-3d problem with OriginGene
|
||||
See https://github.com/EMI-Group/tensorneat/issues/11
|
||||
"""
|
||||
|
||||
algorithm = algorithm.NEAT(
|
||||
pop_size=10000,
|
||||
species_size=20,
|
||||
survival_threshold=0.01,
|
||||
genome=genome.DefaultGenome(
|
||||
node_gene=OriginNode(),
|
||||
conn_gene=OriginConn(),
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
max_nodes=7,
|
||||
output_transform=ACT.sigmoid,
|
||||
),
|
||||
)
|
||||
problem = problem.XOR3d()
|
||||
|
||||
pipeline = Pipeline(
|
||||
algorithm,
|
||||
problem,
|
||||
generation_limit=200,
|
||||
fitness_target=-1e-6,
|
||||
seed=42,
|
||||
)
|
||||
state = pipeline.setup()
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
# show result
|
||||
pipeline.show(state, best)
|
||||
|
||||
# visualize the best individual
|
||||
network = algorithm.genome.network_dict(state, *best)
|
||||
print(algorithm.genome.repr(state, *best))
|
||||
# algorithm.genome.visualize(network, save_path="./imgs/xor_network.svg")
|
||||
|
||||
# transform the best individual to latex formula
|
||||
from tensorneat.common.sympy_tools import to_latex_code, to_python_code
|
||||
|
||||
sympy_res = algorithm.genome.sympy_func(
|
||||
state, network, sympy_output_transform=ACT.obtain_sympy(ACT.sigmoid)
|
||||
)
|
||||
latex_code = to_latex_code(*sympy_res)
|
||||
print(latex_code)
|
||||
|
||||
# transform the best individual to python code
|
||||
python_code = to_python_code(*sympy_res)
|
||||
print(python_code)
|
||||
@@ -2,7 +2,7 @@ from jax import vmap
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseSubstrate
|
||||
from tensorneat.genome.utils import set_conn_attrs
|
||||
from tensorneat.genome.utils import set_gene_attrs
|
||||
|
||||
|
||||
class DefaultSubstrate(BaseSubstrate):
|
||||
@@ -21,7 +21,8 @@ class DefaultSubstrate(BaseSubstrate):
|
||||
|
||||
def make_conns(self, query_res):
|
||||
# change weight of conns
|
||||
return vmap(set_conn_attrs)(self.conns, query_res)
|
||||
# the last column is the weight
|
||||
return self.conns.at[:, -1].set(query_res)
|
||||
|
||||
@property
|
||||
def query_coors(self):
|
||||
|
||||
@@ -116,6 +116,25 @@ class NEAT(BaseAlgorithm):
|
||||
next_node_key = max_node_key + 1
|
||||
new_node_keys = jnp.arange(self.pop_size) + next_node_key
|
||||
|
||||
# find next conn historical markers for mutation if needed
|
||||
if "historical_marker" in self.genome.conn_gene.fixed_attrs:
|
||||
all_conns_markers = vmap(
|
||||
self.genome.conn_gene.get_historical_marker, in_axes=(None, 0)
|
||||
)(state, state.pop_conns)
|
||||
|
||||
max_conn_markers = jnp.max(
|
||||
all_conns_markers, where=~jnp.isnan(all_conns_markers), initial=0
|
||||
)
|
||||
next_conn_markers = max_conn_markers + 1
|
||||
new_conn_markers = (
|
||||
jnp.arange(self.pop_size * 3).reshape(self.pop_size, 3)
|
||||
+ next_conn_markers
|
||||
)
|
||||
else:
|
||||
# no need to generate new conn historical markers
|
||||
# use 0
|
||||
new_conn_markers = jnp.full((self.pop_size, 3), 0)
|
||||
|
||||
# prepare random keys
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
crossover_randkeys = jax.random.split(k1, self.pop_size)
|
||||
@@ -133,9 +152,9 @@ class NEAT(BaseAlgorithm):
|
||||
|
||||
# batch mutation
|
||||
m_n_nodes, m_n_conns = vmap(
|
||||
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
|
||||
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0, 0)
|
||||
)(
|
||||
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
|
||||
state, mutate_randkeys, n_nodes, n_conns, new_node_keys, new_conn_markers
|
||||
) # mutated_new_nodes, mutated_new_conns
|
||||
|
||||
# elitism don't mutate
|
||||
|
||||
@@ -113,8 +113,12 @@ class BaseGenome(StatefulBaseClass):
|
||||
def visualize(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def execute_mutation(self, state, randkey, nodes, conns, new_node_key):
|
||||
return self.mutation(state, self, randkey, nodes, conns, new_node_key)
|
||||
def execute_mutation(
|
||||
self, state, randkey, nodes, conns, new_node_key, new_conn_keys
|
||||
):
|
||||
return self.mutation(
|
||||
state, self, randkey, nodes, conns, new_node_key, new_conn_keys
|
||||
)
|
||||
|
||||
def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2):
|
||||
return self.crossover(state, self, randkey, nodes1, conns1, nodes2, conns2)
|
||||
@@ -144,19 +148,31 @@ class BaseGenome(StatefulBaseClass):
|
||||
conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan)
|
||||
# create input and output indices
|
||||
conn_indices = self.all_init_conns
|
||||
|
||||
# create connection initial history markers
|
||||
conn_markers = jnp.arange(all_conns_cnt)
|
||||
|
||||
# create conn attrs
|
||||
rand_keys_c = jax.random.split(k2, num=all_conns_cnt)
|
||||
conns_attr_func = jax.vmap(
|
||||
conns_attrs = jax.vmap(
|
||||
self.conn_gene.new_random_attrs,
|
||||
in_axes=(
|
||||
None,
|
||||
0,
|
||||
),
|
||||
)
|
||||
conns_attrs = conns_attr_func(state, rand_keys_c)
|
||||
)(state, rand_keys_c)
|
||||
|
||||
conns = conns.at[:all_conns_cnt, :2].set(conn_indices) # set conn indices
|
||||
conns = conns.at[:all_conns_cnt, 2:].set(conns_attrs) # set conn attrs
|
||||
# set conn indices
|
||||
conns = conns.at[:all_conns_cnt, :2].set(conn_indices)
|
||||
|
||||
# set conn history markers if needed
|
||||
if "historical_marker" in self.conn_gene.fixed_attrs:
|
||||
conns = conns.at[:all_conns_cnt, 2].set(conn_markers)
|
||||
|
||||
# set conn attrs
|
||||
conns = conns.at[:all_conns_cnt, len(self.conn_gene.fixed_attrs) :].set(
|
||||
conns_attrs
|
||||
)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import sympy as sp
|
||||
from .base import BaseGenome
|
||||
from .gene import DefaultNode, DefaultConn
|
||||
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
|
||||
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
|
||||
from .utils import unflatten_conns, extract_gene_attrs, extract_gene_attrs
|
||||
|
||||
from tensorneat.common import (
|
||||
topological_sort,
|
||||
@@ -16,7 +16,7 @@ from tensorneat.common import (
|
||||
I_INF,
|
||||
attach_with_inf,
|
||||
ACT,
|
||||
AGG
|
||||
AGG,
|
||||
)
|
||||
|
||||
|
||||
@@ -73,8 +73,8 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
ini_vals = jnp.full((self.max_nodes,), jnp.nan)
|
||||
ini_vals = ini_vals.at[self.input_idx].set(inputs)
|
||||
nodes_attrs = vmap(extract_node_attrs)(nodes)
|
||||
conns_attrs = vmap(extract_conn_attrs)(conns)
|
||||
nodes_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(self.node_gene, nodes)
|
||||
conns_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(self.conn_gene, conns)
|
||||
|
||||
def cond_fun(carry):
|
||||
values, idx = carry
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from .base import BaseConn
|
||||
from .default import DefaultConn
|
||||
from .origin import OriginConn
|
||||
60
src/tensorneat/genome/gene/conn/origin.py
Normal file
60
src/tensorneat/genome/gene/conn/origin.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from .default import DefaultConn
|
||||
|
||||
|
||||
class OriginConn(DefaultConn):
|
||||
"""
|
||||
Implementation of connections in origin NEAT Paper.
|
||||
Details at https://github.com/EMI-Group/tensorneat/issues/11.
|
||||
"""
|
||||
|
||||
# add historical_marker into fixed_attrs
|
||||
fixed_attrs = ["input_index", "output_index", "historical_marker"]
|
||||
custom_attrs = ["weight"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def crossover(self, state, randkey, attrs1, attrs2):
|
||||
# random pick one of attrs, without attrs exchange
|
||||
return jnp.where(
|
||||
# origin code, generate multiple random numbers, without attrs exchange
|
||||
# jax.random.normal(randkey, attrs1.shape) > 0,
|
||||
jax.random.normal(randkey)
|
||||
> 0, # generate one random number, without attrs exchange
|
||||
attrs1,
|
||||
attrs2,
|
||||
)
|
||||
|
||||
def get_historical_marker(self, state, gene_array):
|
||||
return gene_array[2]
|
||||
|
||||
def repr(self, state, conn, precision=2, idx_width=3, func_width=8):
|
||||
in_idx, out_idx, historical_marker, weight = conn
|
||||
|
||||
in_idx = int(in_idx)
|
||||
out_idx = int(out_idx)
|
||||
historical_marker = int(historical_marker)
|
||||
weight = round(float(weight), precision)
|
||||
|
||||
return "{}(in: {:<{idx_width}}, out: {:<{idx_width}}, historical_marker: {:<{idx_width}}, weight: {:<{float_width}})".format(
|
||||
self.__class__.__name__,
|
||||
in_idx,
|
||||
out_idx,
|
||||
historical_marker,
|
||||
weight,
|
||||
idx_width=idx_width,
|
||||
float_width=precision + 3,
|
||||
)
|
||||
|
||||
def to_dict(self, state, conn):
|
||||
return {
|
||||
"in": int(conn[0]),
|
||||
"out": int(conn[1]),
|
||||
"historical_marker": int(conn[2]),
|
||||
"weight": jnp.float32(conn[3]),
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
from .base import BaseNode
|
||||
from .default import DefaultNode
|
||||
from .bias import BiasNode
|
||||
from .origin import OriginNode
|
||||
|
||||
27
src/tensorneat/genome/gene/node/origin.py
Normal file
27
src/tensorneat/genome/gene/node/origin.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from .default import DefaultNode
|
||||
|
||||
|
||||
class OriginNode(DefaultNode):
|
||||
"""
|
||||
Implementation of nodes in origin NEAT Paper.
|
||||
Details at https://github.com/EMI-Group/tensorneat/issues/11.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def crossover(self, state, randkey, attrs1, attrs2):
|
||||
# random pick one of attrs, without attrs exchange
|
||||
return jnp.where(
|
||||
# origin code, generate multiple random numbers, without attrs exchange
|
||||
# jax.random.normal(randkey, attrs1.shape) > 0,
|
||||
jax.random.normal(randkey)
|
||||
> 0, # generate one random number, without attrs exchange
|
||||
attrs1,
|
||||
attrs2,
|
||||
)
|
||||
@@ -2,12 +2,10 @@ import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
|
||||
from .base import BaseCrossover
|
||||
from ...utils import (
|
||||
extract_node_attrs,
|
||||
extract_conn_attrs,
|
||||
set_node_attrs,
|
||||
set_conn_attrs,
|
||||
)
|
||||
from ...utils import extract_gene_attrs, set_gene_attrs
|
||||
|
||||
from tensorneat.common import fetch_first, I_INF
|
||||
from tensorneat.genome.gene import BaseGene
|
||||
|
||||
|
||||
class DefaultCrossover(BaseCrossover):
|
||||
@@ -17,71 +15,90 @@ class DefaultCrossover(BaseCrossover):
|
||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||
"""
|
||||
randkey1, randkey2 = jax.random.split(randkey, 2)
|
||||
randkeys1 = jax.random.split(randkey1, genome.max_nodes)
|
||||
randkeys2 = jax.random.split(randkey2, genome.max_conns)
|
||||
node_randkeys = jax.random.split(randkey1, genome.max_nodes)
|
||||
conn_randkeys = jax.random.split(randkey2, genome.max_conns)
|
||||
batch_create_new_gene = jax.vmap(
|
||||
create_new_gene, in_axes=(None, 0, None, 0, 0, None, None)
|
||||
)
|
||||
|
||||
# crossover nodes
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
# make homologous genes align in nodes2 align with nodes1
|
||||
nodes2 = self.align_array(keys1, keys2, nodes2, is_conn=False)
|
||||
|
||||
# For not homologous genes, use the value of nodes1(winner)
|
||||
# For homologous genes, use the crossover result between nodes1 and nodes2
|
||||
node_attrs1 = vmap(extract_node_attrs)(nodes1)
|
||||
node_attrs2 = vmap(extract_node_attrs)(nodes2)
|
||||
|
||||
new_node_attrs = jnp.where(
|
||||
jnp.isnan(node_attrs1) | jnp.isnan(node_attrs2), # one of them is nan
|
||||
node_attrs1, # not homologous genes or both nan, use the value of nodes1(winner)
|
||||
vmap(genome.node_gene.crossover, in_axes=(None, 0, 0, 0))(
|
||||
state, randkeys1, node_attrs1, node_attrs2
|
||||
), # homologous or both nan
|
||||
node_keys1, node_keys2 = (
|
||||
nodes1[:, 0 : len(genome.node_gene.fixed_attrs)],
|
||||
nodes2[:, 0 : len(genome.node_gene.fixed_attrs)],
|
||||
)
|
||||
node_attrs1 = vmap(extract_gene_attrs, in_axes=(None, 0))(
|
||||
genome.node_gene, nodes1
|
||||
)
|
||||
node_attrs2 = vmap(extract_gene_attrs, in_axes=(None, 0))(
|
||||
genome.node_gene, nodes2
|
||||
)
|
||||
|
||||
new_node_attrs = batch_create_new_gene(
|
||||
state,
|
||||
node_randkeys,
|
||||
genome.node_gene,
|
||||
node_keys1,
|
||||
node_attrs1,
|
||||
node_keys2,
|
||||
node_attrs2,
|
||||
)
|
||||
new_nodes = vmap(set_gene_attrs, in_axes=(None, 0, 0))(
|
||||
genome.node_gene, nodes1, new_node_attrs
|
||||
)
|
||||
new_nodes = vmap(set_node_attrs)(nodes1, new_node_attrs)
|
||||
|
||||
# crossover connections
|
||||
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
|
||||
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
|
||||
|
||||
conns_attrs1 = vmap(extract_conn_attrs)(conns1)
|
||||
conns_attrs2 = vmap(extract_conn_attrs)(conns2)
|
||||
|
||||
new_conn_attrs = jnp.where(
|
||||
jnp.isnan(conns_attrs1) | jnp.isnan(conns_attrs2),
|
||||
conns_attrs1, # not homologous genes or both nan, use the value of conns1(winner)
|
||||
vmap(genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))(
|
||||
state, randkeys2, conns_attrs1, conns_attrs2
|
||||
), # homologous or both nan
|
||||
# all fixed_attrs together will use to identify a connection
|
||||
# if using historical marker, use it
|
||||
# related to issue: https://github.com/EMI-Group/tensorneat/issues/11
|
||||
conn_keys1, conn_keys2 = (
|
||||
conns1[:, 0 : len(genome.conn_gene.fixed_attrs)],
|
||||
conns2[:, 0 : len(genome.conn_gene.fixed_attrs)],
|
||||
)
|
||||
conn_attrs1 = vmap(extract_gene_attrs, in_axes=(None, 0))(
|
||||
genome.conn_gene, conns1
|
||||
)
|
||||
conn_attrs2 = vmap(extract_gene_attrs, in_axes=(None, 0))(
|
||||
genome.conn_gene, conns2
|
||||
)
|
||||
|
||||
new_conn_attrs = batch_create_new_gene(
|
||||
state,
|
||||
conn_randkeys,
|
||||
genome.conn_gene,
|
||||
conn_keys1,
|
||||
conn_attrs1,
|
||||
conn_keys2,
|
||||
conn_attrs2,
|
||||
)
|
||||
new_conns = vmap(set_gene_attrs, in_axes=(None, 0, 0))(
|
||||
genome.conn_gene, conns1, new_conn_attrs
|
||||
)
|
||||
new_conns = vmap(set_conn_attrs)(conns1, new_conn_attrs)
|
||||
|
||||
return new_nodes, new_conns
|
||||
|
||||
def align_array(self, seq1, seq2, ar2, is_conn: bool):
|
||||
"""
|
||||
After I review this code, I found that it is the most difficult part of the code.
|
||||
Please consider carefully before change it!
|
||||
make ar2 align with ar1.
|
||||
:param seq1:
|
||||
:param seq2:
|
||||
:param ar2:
|
||||
:param is_conn:
|
||||
:return:
|
||||
align means to intersect part of ar2 will be at the same position as ar1,
|
||||
non-intersect part of ar2 will be set to Nan
|
||||
"""
|
||||
seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
|
||||
mask = (seq1 == seq2) & (~jnp.isnan(seq1))
|
||||
|
||||
if is_conn:
|
||||
mask = jnp.all(mask, axis=2)
|
||||
def create_new_gene(
|
||||
state,
|
||||
randkey,
|
||||
gene: BaseGene,
|
||||
gene_key,
|
||||
gene_attrs,
|
||||
genes_keys,
|
||||
genes_attrs,
|
||||
):
|
||||
# find homologous genes
|
||||
homologous_idx = fetch_first(jnp.all(gene_key == genes_keys, axis=1))
|
||||
|
||||
intersect_mask = mask.any(axis=1)
|
||||
idx = jnp.arange(0, len(seq1))
|
||||
idx_fixed = jnp.dot(mask, idx)
|
||||
def none(): # no homologous, use winner's gene
|
||||
return gene_attrs
|
||||
|
||||
refactor_ar2 = jnp.where(
|
||||
intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan
|
||||
)
|
||||
def crossover(): # when homologous gene is found, execute crossover
|
||||
return gene.crossover(state, randkey, gene_attrs, genes_attrs[homologous_idx])
|
||||
|
||||
return refactor_ar2
|
||||
new_attrs = jax.lax.cond(
|
||||
homologous_idx == I_INF, # homologous gene is not found or current gene is nan
|
||||
none,
|
||||
crossover,
|
||||
)
|
||||
|
||||
return new_attrs
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from jax import vmap, numpy as jnp
|
||||
|
||||
from .base import BaseDistance
|
||||
from ...utils import extract_node_attrs, extract_conn_attrs
|
||||
from ...gene import BaseGene
|
||||
from ...utils import extract_gene_attrs
|
||||
|
||||
|
||||
class DefaultDistance(BaseDistance):
|
||||
@@ -17,83 +18,47 @@ class DefaultDistance(BaseDistance):
|
||||
"""
|
||||
The distance between two genomes
|
||||
"""
|
||||
d = self.node_distance(state, genome, nodes1, nodes2) + self.conn_distance(
|
||||
state, genome, conns1, conns2
|
||||
)
|
||||
return d
|
||||
node_distance = self.gene_distance(state, genome.node_gene, nodes1, nodes2)
|
||||
conn_distance = self.gene_distance(state, genome.conn_gene, conns1, conns2)
|
||||
return node_distance + conn_distance
|
||||
|
||||
def node_distance(self, state, genome, nodes1, nodes2):
|
||||
|
||||
def gene_distance(self, state, gene: BaseGene, genes1, genes2):
|
||||
"""
|
||||
The distance of the nodes part for two genomes
|
||||
The distance between to genes
|
||||
genes1: 2-D jax array with shape
|
||||
genes2: 2-D jax array with shape
|
||||
gene1.shape == gene2.shape
|
||||
"""
|
||||
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
|
||||
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
|
||||
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
|
||||
cnt1 = jnp.sum(~jnp.isnan(genes1[:, 0]))
|
||||
cnt2 = jnp.sum(~jnp.isnan(genes2[:, 0]))
|
||||
max_cnt = jnp.maximum(cnt1, cnt2)
|
||||
|
||||
# align homologous nodes
|
||||
# this process is similar to np.intersect1d.
|
||||
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
|
||||
keys = nodes[:, 0]
|
||||
sorted_indices = jnp.argsort(keys, axis=0)
|
||||
nodes = nodes[sorted_indices]
|
||||
nodes = jnp.concatenate(
|
||||
[nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0
|
||||
# this process is similar to np.intersect1d in higher dimension
|
||||
total_genes = jnp.concatenate((genes1, genes2), axis=0)
|
||||
identifiers = total_genes[:, : len(gene.fixed_attrs)]
|
||||
sorted_identifiers = jnp.lexsort(identifiers.T[::-1])
|
||||
total_genes = total_genes[sorted_identifiers]
|
||||
total_genes = jnp.concatenate(
|
||||
[total_genes, jnp.full((1, total_genes.shape[1]), jnp.nan)], axis=0
|
||||
) # add a nan row to the end
|
||||
fr, sr = nodes[:-1], nodes[1:] # first row, second row
|
||||
fr, sr = total_genes[:-1], total_genes[1:] # first row, second row
|
||||
|
||||
# flag location of homologous nodes
|
||||
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
|
||||
# intersect part of two genes
|
||||
intersect_mask = jnp.all(
|
||||
fr[:, : len(gene.fixed_attrs)] == sr[:, : len(gene.fixed_attrs)], axis=1
|
||||
) & ~jnp.isnan(fr[:, 0])
|
||||
|
||||
# calculate the count of non_homologous of two genomes
|
||||
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
non_homologous_cnt = cnt1 + cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
|
||||
# calculate the distance of homologous nodes
|
||||
fr_attrs = vmap(extract_node_attrs)(fr)
|
||||
sr_attrs = vmap(extract_node_attrs)(sr)
|
||||
hnd = vmap(genome.node_gene.distance, in_axes=(None, 0, 0))(
|
||||
state, fr_attrs, sr_attrs
|
||||
) # homologous node distance
|
||||
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
|
||||
homologous_distance = jnp.sum(hnd * intersect_mask)
|
||||
|
||||
val = (
|
||||
non_homologous_cnt * self.compatibility_disjoint
|
||||
+ homologous_distance * self.compatibility_weight
|
||||
)
|
||||
|
||||
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
|
||||
|
||||
return val
|
||||
|
||||
def conn_distance(self, state, genome, conns1, conns2):
|
||||
"""
|
||||
The distance of the conns part for two genomes
|
||||
"""
|
||||
con_cnt1 = jnp.sum(~jnp.isnan(conns1[:, 0]))
|
||||
con_cnt2 = jnp.sum(~jnp.isnan(conns2[:, 0]))
|
||||
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
|
||||
|
||||
cons = jnp.concatenate((conns1, conns2), axis=0)
|
||||
keys = cons[:, :2]
|
||||
sorted_indices = jnp.lexsort(keys.T[::-1])
|
||||
cons = cons[sorted_indices]
|
||||
cons = jnp.concatenate(
|
||||
[cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0
|
||||
) # add a nan row to the end
|
||||
fr, sr = cons[:-1], cons[1:] # first row, second row
|
||||
|
||||
# both genome has such connection
|
||||
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
|
||||
|
||||
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
|
||||
fr_attrs = vmap(extract_conn_attrs)(fr)
|
||||
sr_attrs = vmap(extract_conn_attrs)(sr)
|
||||
hcd = vmap(genome.conn_gene.distance, in_axes=(None, 0, 0))(
|
||||
state, fr_attrs, sr_attrs
|
||||
) # homologous connection distance
|
||||
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
|
||||
homologous_distance = jnp.sum(hcd * intersect_mask)
|
||||
fr_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(gene, fr)
|
||||
sr_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(gene, sr)
|
||||
|
||||
# homologous gene distance
|
||||
hgd = vmap(gene.distance, in_axes=(None, 0, 0))(state, fr_attrs, sr_attrs)
|
||||
hgd = jnp.where(jnp.isnan(hgd), 0, hgd)
|
||||
homologous_distance = jnp.sum(hgd * intersect_mask)
|
||||
|
||||
val = (
|
||||
non_homologous_cnt * self.compatibility_disjoint
|
||||
|
||||
@@ -3,5 +3,5 @@ from tensorneat.common import StatefulBaseClass, State
|
||||
|
||||
class BaseMutation(StatefulBaseClass):
|
||||
|
||||
def __call__(self, state, genome, randkey, nodes, conns, new_node_key):
|
||||
def __call__(self, state, genome, randkey, nodes, conns, new_node_key, new_conn_key):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -13,10 +13,8 @@ from ...utils import (
|
||||
add_conn,
|
||||
delete_node_by_pos,
|
||||
delete_conn_by_pos,
|
||||
extract_node_attrs,
|
||||
extract_conn_attrs,
|
||||
set_node_attrs,
|
||||
set_conn_attrs,
|
||||
extract_gene_attrs,
|
||||
set_gene_attrs
|
||||
)
|
||||
|
||||
|
||||
@@ -33,17 +31,28 @@ class DefaultMutation(BaseMutation):
|
||||
self.node_add = node_add
|
||||
self.node_delete = node_delete
|
||||
|
||||
def __call__(self, state, genome, randkey, nodes, conns, new_node_key):
|
||||
def __call__(
|
||||
self, state, genome, randkey, nodes, conns, new_node_key, new_conn_key
|
||||
):
|
||||
assert (
|
||||
new_node_key.shape == ()
|
||||
) # scalar, as there is max one new node in each mutation
|
||||
assert new_conn_key.shape == (
|
||||
3,
|
||||
) # there are max 3 new connections (mutate add node + mutate add conn)
|
||||
|
||||
k1, k2 = jax.random.split(randkey)
|
||||
|
||||
nodes, conns = self.mutate_structure(
|
||||
state, genome, k1, nodes, conns, new_node_key
|
||||
state, genome, k1, nodes, conns, new_node_key, new_conn_key
|
||||
)
|
||||
nodes, conns = self.mutate_values(state, genome, k2, nodes, conns)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
def mutate_structure(self, state, genome, randkey, nodes, conns, new_node_key):
|
||||
def mutate_structure(
|
||||
self, state, genome, randkey, nodes, conns, new_node_key, new_conn_key
|
||||
):
|
||||
def mutate_add_node(key_, nodes_, conns_):
|
||||
"""
|
||||
add a node while do not influence the output of the network
|
||||
@@ -57,27 +66,33 @@ class DefaultMutation(BaseMutation):
|
||||
|
||||
def successful_add_node():
|
||||
# remove the original connection and record its attrs
|
||||
original_attrs = extract_conn_attrs(conns_[idx])
|
||||
original_attrs = extract_gene_attrs(genome.conn_gene, conns_[idx])
|
||||
new_conns = delete_conn_by_pos(conns_, idx)
|
||||
|
||||
# add a new node with identity attrs
|
||||
new_nodes = add_node(
|
||||
nodes_, new_node_key, genome.node_gene.new_identity_attrs(state)
|
||||
nodes_, jnp.array([new_node_key]), genome.node_gene.new_identity_attrs(state)
|
||||
)
|
||||
|
||||
# whether to use historical marker in connection
|
||||
if "historical_marker" in genome.conn_gene.fixed_attrs:
|
||||
fix_attrs1 = jnp.array([i_key, new_node_key, new_conn_key[0]])
|
||||
fix_attrs2 = jnp.array([new_node_key, o_key, new_conn_key[1]])
|
||||
else:
|
||||
fix_attrs1 = jnp.array([i_key, new_node_key])
|
||||
fix_attrs2 = jnp.array([new_node_key, o_key])
|
||||
|
||||
# add two new connections
|
||||
# first is with identity attrs
|
||||
new_conns = add_conn(
|
||||
new_conns,
|
||||
i_key,
|
||||
new_node_key,
|
||||
fix_attrs1,
|
||||
genome.conn_gene.new_identity_attrs(state),
|
||||
)
|
||||
# second is with the origin attrs
|
||||
new_conns = add_conn(
|
||||
new_conns,
|
||||
new_node_key,
|
||||
o_key,
|
||||
fix_attrs2,
|
||||
original_attrs,
|
||||
)
|
||||
|
||||
@@ -160,8 +175,12 @@ class DefaultMutation(BaseMutation):
|
||||
|
||||
def successful():
|
||||
# add a connection with zero attrs
|
||||
if "historical_marker" in genome.conn_gene.fixed_attrs:
|
||||
new_fix_attrs = jnp.array([i_key, o_key, new_conn_key[2]])
|
||||
else:
|
||||
new_fix_attrs = jnp.array([i_key, o_key])
|
||||
return nodes_, add_conn(
|
||||
conns_, i_key, o_key, genome.conn_gene.new_zero_attrs(state)
|
||||
conns_, new_fix_attrs, genome.conn_gene.new_zero_attrs(state)
|
||||
)
|
||||
|
||||
if genome.network_type == "feedforward":
|
||||
@@ -228,17 +247,25 @@ class DefaultMutation(BaseMutation):
|
||||
nodes_randkeys = jax.random.split(k1, num=genome.max_nodes)
|
||||
conns_randkeys = jax.random.split(k2, num=genome.max_conns)
|
||||
|
||||
node_attrs = vmap(extract_node_attrs)(nodes)
|
||||
node_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(
|
||||
genome.node_gene, nodes
|
||||
)
|
||||
new_node_attrs = vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))(
|
||||
state, nodes_randkeys, node_attrs
|
||||
)
|
||||
new_nodes = vmap(set_node_attrs)(nodes, new_node_attrs)
|
||||
new_nodes = vmap(set_gene_attrs, in_axes=(None, 0, 0))(
|
||||
genome.node_gene, nodes, new_node_attrs
|
||||
)
|
||||
|
||||
conn_attrs = vmap(extract_conn_attrs)(conns)
|
||||
conn_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(
|
||||
genome.conn_gene, conns
|
||||
)
|
||||
new_conn_attrs = vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))(
|
||||
state, conns_randkeys, conn_attrs
|
||||
)
|
||||
new_conns = vmap(set_conn_attrs)(conns, new_conn_attrs)
|
||||
new_conns = vmap(set_gene_attrs, in_axes=(None, 0, 0))(
|
||||
genome.conn_gene, conns, new_conn_attrs
|
||||
)
|
||||
|
||||
# nan nodes not changed
|
||||
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
|
||||
|
||||
@@ -5,7 +5,7 @@ from .utils import unflatten_conns
|
||||
from .base import BaseGenome
|
||||
from .gene import DefaultNode, DefaultConn
|
||||
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
|
||||
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
|
||||
from .utils import unflatten_conns, extract_gene_attrs, extract_gene_attrs
|
||||
|
||||
from tensorneat.common import attach_with_inf
|
||||
|
||||
@@ -55,8 +55,8 @@ class RecurrentGenome(BaseGenome):
|
||||
|
||||
vals = jnp.full((self.max_nodes,), jnp.nan)
|
||||
|
||||
nodes_attrs = vmap(extract_node_attrs)(nodes)
|
||||
conns_attrs = vmap(extract_conn_attrs)(conns)
|
||||
nodes_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(self.node_gene, nodes)
|
||||
conns_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(self.conn_gene, conns)
|
||||
expand_conns_attrs = attach_with_inf(conns_attrs, u_conns)
|
||||
|
||||
def body_func(_, values):
|
||||
|
||||
@@ -2,6 +2,7 @@ import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from .gene import BaseGene
|
||||
from tensorneat.common import fetch_first, I_INF
|
||||
|
||||
|
||||
@@ -38,49 +39,27 @@ def valid_cnt(nodes_or_conns):
|
||||
return jnp.sum(~jnp.isnan(nodes_or_conns[:, 0]))
|
||||
|
||||
|
||||
def extract_node_attrs(node):
|
||||
def extract_gene_attrs(gene: BaseGene, gene_array):
|
||||
"""
|
||||
node: Array(NL, )
|
||||
extract the attributes of a node
|
||||
extract the custom attributes of the gene
|
||||
"""
|
||||
return node[1:] # 0 is for idx
|
||||
return gene_array[len(gene.fixed_attrs) :]
|
||||
|
||||
|
||||
def set_node_attrs(node, attrs):
|
||||
def set_gene_attrs(gene: BaseGene, gene_array, attrs):
|
||||
"""
|
||||
node: Array(NL, )
|
||||
attrs: Array(NL-1, )
|
||||
set the attributes of a node
|
||||
set the custom attributes of the gene
|
||||
"""
|
||||
return node.at[1:].set(attrs) # 0 is for idx
|
||||
return gene_array.at[len(gene.fixed_attrs) :].set(attrs)
|
||||
|
||||
|
||||
def extract_conn_attrs(conn):
|
||||
"""
|
||||
conn: Array(CL, )
|
||||
extract the attributes of a connection
|
||||
"""
|
||||
return conn[2:] # 0, 1 is for in-idx and out-idx
|
||||
|
||||
|
||||
def set_conn_attrs(conn, attrs):
|
||||
"""
|
||||
conn: Array(CL, )
|
||||
attrs: Array(CL-2, )
|
||||
set the attributes of a connection
|
||||
"""
|
||||
return conn.at[2:].set(attrs) # 0, 1 is for in-idx and out-idx
|
||||
|
||||
|
||||
def add_node(nodes, new_key: int, attrs):
|
||||
def add_node(nodes, fix_attrs, custom_attrs):
|
||||
"""
|
||||
Add a new node to the genome.
|
||||
The new node will place at the first NaN row.
|
||||
"""
|
||||
exist_keys = nodes[:, 0]
|
||||
pos = fetch_first(jnp.isnan(exist_keys))
|
||||
new_nodes = nodes.at[pos, 0].set(new_key)
|
||||
return new_nodes.at[pos, 1:].set(attrs)
|
||||
pos = fetch_first(jnp.isnan(nodes[:, 0]))
|
||||
return nodes.at[pos].set(jnp.concatenate((fix_attrs, custom_attrs)))
|
||||
|
||||
|
||||
def delete_node_by_pos(nodes, pos):
|
||||
@@ -91,15 +70,13 @@ def delete_node_by_pos(nodes, pos):
|
||||
return nodes.at[pos].set(jnp.nan)
|
||||
|
||||
|
||||
def add_conn(conns, i_key, o_key, attrs):
|
||||
def add_conn(conns, fix_attrs, custom_attrs):
|
||||
"""
|
||||
Add a new connection to the genome.
|
||||
The new connection will place at the first NaN row.
|
||||
"""
|
||||
con_keys = conns[:, 0]
|
||||
pos = fetch_first(jnp.isnan(con_keys))
|
||||
new_conns = conns.at[pos, 0:2].set(jnp.array([i_key, o_key]))
|
||||
return new_conns.at[pos, 2:].set(attrs)
|
||||
pos = fetch_first(jnp.isnan(conns[:, 0]))
|
||||
return conns.at[pos].set(jnp.concatenate((fix_attrs, custom_attrs)))
|
||||
|
||||
|
||||
def delete_conn_by_pos(conns, pos):
|
||||
|
||||
247
test/origin_operations_test.py
Normal file
247
test/origin_operations_test.py
Normal file
@@ -0,0 +1,247 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from tensorneat.genome.operations import (
|
||||
DefaultMutation,
|
||||
DefaultDistance,
|
||||
DefaultCrossover,
|
||||
)
|
||||
from tensorneat.genome import (
|
||||
DefaultGenome,
|
||||
DefaultNode,
|
||||
DefaultConn,
|
||||
OriginNode,
|
||||
OriginConn,
|
||||
)
|
||||
from tensorneat.genome.utils import add_node, add_conn
|
||||
|
||||
origin_genome = DefaultGenome(
|
||||
node_gene=OriginNode(response_init_std=1),
|
||||
conn_gene=OriginConn(),
|
||||
mutation=DefaultMutation(conn_add=1, node_add=1, conn_delete=0, node_delete=0),
|
||||
crossover=DefaultCrossover(),
|
||||
distance=DefaultDistance(),
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
max_nodes=6,
|
||||
max_conns=6,
|
||||
)
|
||||
|
||||
default_genome = DefaultGenome(
|
||||
node_gene=DefaultNode(response_init_std=1),
|
||||
conn_gene=DefaultConn(),
|
||||
mutation=DefaultMutation(conn_add=1, node_add=1, conn_delete=0, node_delete=0),
|
||||
crossover=DefaultCrossover(),
|
||||
distance=DefaultDistance(),
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
max_nodes=6,
|
||||
max_conns=6,
|
||||
)
|
||||
|
||||
state = default_genome.setup()
|
||||
state = origin_genome.setup(state)
|
||||
|
||||
randkey = jax.random.PRNGKey(42)
|
||||
|
||||
|
||||
def mutation_default():
|
||||
nodes, conns = default_genome.initialize(state, randkey)
|
||||
print("old genome:\n", default_genome.repr(state, nodes, conns))
|
||||
|
||||
nodes, conns = default_genome.execute_mutation(
|
||||
state,
|
||||
randkey,
|
||||
nodes,
|
||||
conns,
|
||||
new_node_key=jnp.asarray(10),
|
||||
new_conn_keys=jnp.array([20, 21, 22]),
|
||||
)
|
||||
|
||||
# new_conn_keys is not used in default genome
|
||||
print("new genome:\n", default_genome.repr(state, nodes, conns))
|
||||
|
||||
|
||||
def mutation_origin():
|
||||
nodes, conns = origin_genome.initialize(state, randkey)
|
||||
print(conns)
|
||||
print("old genome:\n", origin_genome.repr(state, nodes, conns))
|
||||
|
||||
nodes, conns = origin_genome.execute_mutation(
|
||||
state,
|
||||
randkey,
|
||||
nodes,
|
||||
conns,
|
||||
new_node_key=jnp.asarray(10),
|
||||
new_conn_keys=jnp.array([20, 21, 22]),
|
||||
)
|
||||
print(conns)
|
||||
# new_conn_keys is used in origin genome
|
||||
print("new genome:\n", origin_genome.repr(state, nodes, conns))
|
||||
|
||||
def distance_default():
|
||||
nodes, conns = default_genome.initialize(state, randkey)
|
||||
nodes = add_node(
|
||||
nodes,
|
||||
fix_attrs=jnp.asarray([10]),
|
||||
custom_attrs=default_genome.node_gene.new_identity_attrs(state)
|
||||
)
|
||||
conns1 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10]), # in-idx, out-idx
|
||||
custom_attrs=default_genome.conn_gene.new_zero_attrs(state)
|
||||
)
|
||||
conns2 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10]), # in-idx, out-idx
|
||||
custom_attrs=default_genome.conn_gene.new_random_attrs(state, randkey)
|
||||
)
|
||||
print("genome1:\n", default_genome.repr(state, nodes, conns1))
|
||||
print("genome2:\n", default_genome.repr(state, nodes, conns2))
|
||||
|
||||
distance = default_genome.execute_distance(state, nodes, conns1, nodes, conns2)
|
||||
print("distance: ", distance)
|
||||
|
||||
def distance_origin_case1():
|
||||
"""
|
||||
distance with different historical marker
|
||||
"""
|
||||
nodes, conns = origin_genome.initialize(state, randkey)
|
||||
nodes = add_node(
|
||||
nodes,
|
||||
fix_attrs=jnp.asarray([10]),
|
||||
custom_attrs=origin_genome.node_gene.new_identity_attrs(state)
|
||||
)
|
||||
conns1 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10, 99]), # in-idx, out-idx, historical mark
|
||||
custom_attrs=origin_genome.conn_gene.new_zero_attrs(state)
|
||||
)
|
||||
conns2 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10, 88]), # in-idx, out-idx, historical mark
|
||||
custom_attrs=origin_genome.conn_gene.new_random_attrs(state, randkey)
|
||||
)
|
||||
print("genome1:\n", origin_genome.repr(state, nodes, conns1))
|
||||
print("genome2:\n", origin_genome.repr(state, nodes, conns2))
|
||||
|
||||
distance = origin_genome.execute_distance(state, nodes, conns1, nodes, conns2)
|
||||
print("distance: ", distance)
|
||||
|
||||
def distance_origin_case2():
|
||||
"""
|
||||
distance with same historical marker
|
||||
"""
|
||||
nodes, conns = origin_genome.initialize(state, randkey)
|
||||
nodes = add_node(
|
||||
nodes,
|
||||
fix_attrs=jnp.asarray([10]),
|
||||
custom_attrs=origin_genome.node_gene.new_identity_attrs(state)
|
||||
)
|
||||
conns1 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10, 99]), # in-idx, out-idx, historical mark
|
||||
custom_attrs=origin_genome.conn_gene.new_zero_attrs(state)
|
||||
)
|
||||
conns2 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10, 99]), # in-idx, out-idx, historical mark
|
||||
custom_attrs=origin_genome.conn_gene.new_random_attrs(state, randkey)
|
||||
)
|
||||
print("genome1:\n", origin_genome.repr(state, nodes, conns1))
|
||||
print("genome2:\n", origin_genome.repr(state, nodes, conns2))
|
||||
|
||||
distance = origin_genome.execute_distance(state, nodes, conns1, nodes, conns2)
|
||||
print("distance: ", distance)
|
||||
|
||||
def crossover_origin_case1():
|
||||
"""
|
||||
crossover with different historical marker
|
||||
"""
|
||||
nodes, conns = origin_genome.initialize(state, randkey)
|
||||
nodes = add_node(
|
||||
nodes,
|
||||
fix_attrs=jnp.asarray([10]),
|
||||
custom_attrs=origin_genome.node_gene.new_identity_attrs(state)
|
||||
)
|
||||
conns1 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10, 99]), # in-idx, out-idx, historical mark
|
||||
custom_attrs=origin_genome.conn_gene.new_zero_attrs(state)
|
||||
)
|
||||
conns2 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10, 88]), # in-idx, out-idx, historical mark
|
||||
custom_attrs=origin_genome.conn_gene.new_random_attrs(state, randkey)
|
||||
)
|
||||
print("genome1:\n", origin_genome.repr(state, nodes, conns1))
|
||||
print("genome2:\n", origin_genome.repr(state, nodes, conns2))
|
||||
|
||||
# (0, 10)'s weight must be 0 (disjoint gene, use fitter)
|
||||
child_nodes, child_conns = origin_genome.execute_crossover(state, randkey, nodes, conns1, nodes, conns2)
|
||||
print("child:\n", origin_genome.repr(state, child_nodes, child_conns))
|
||||
|
||||
def crossover_origin_case2():
|
||||
"""
|
||||
crossover with same historical marker
|
||||
"""
|
||||
nodes, conns = origin_genome.initialize(state, randkey)
|
||||
nodes = add_node(
|
||||
nodes,
|
||||
fix_attrs=jnp.asarray([10]),
|
||||
custom_attrs=origin_genome.node_gene.new_identity_attrs(state)
|
||||
)
|
||||
conns1 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10, 99]), # in-idx, out-idx, historical mark
|
||||
custom_attrs=origin_genome.conn_gene.new_zero_attrs(state)
|
||||
)
|
||||
conns2 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10, 99]), # in-idx, out-idx, historical mark
|
||||
custom_attrs=origin_genome.conn_gene.new_random_attrs(state, randkey)
|
||||
)
|
||||
print("genome1:\n", origin_genome.repr(state, nodes, conns1))
|
||||
print("genome2:\n", origin_genome.repr(state, nodes, conns2))
|
||||
|
||||
# (0, 10)'s weight might be random or zero (homologous gene)
|
||||
|
||||
# zero case:
|
||||
child_nodes, child_conns = origin_genome.execute_crossover(state, jax.random.key(99), nodes, conns1, nodes, conns2)
|
||||
print("child_zero:\n", origin_genome.repr(state, child_nodes, child_conns))
|
||||
|
||||
# random case:
|
||||
child_nodes, child_conns = origin_genome.execute_crossover(state, jax.random.key(0), nodes, conns1, nodes, conns2)
|
||||
print("child_random:\n", origin_genome.repr(state, child_nodes, child_conns))
|
||||
|
||||
def crossover_origin_case3():
|
||||
"""
|
||||
test examine it use random gene rather than attribute exchange
|
||||
"""
|
||||
nodes, conns = origin_genome.initialize(state, randkey)
|
||||
nodes1 = add_node(
|
||||
nodes,
|
||||
fix_attrs=jnp.asarray([10]),
|
||||
custom_attrs=jnp.array([1, 2, 0, 0])
|
||||
)
|
||||
nodes2 = add_node(
|
||||
nodes,
|
||||
fix_attrs=jnp.asarray([10]),
|
||||
custom_attrs=jnp.array([100, 200, 0, 0])
|
||||
)
|
||||
|
||||
# [1, 2] case
|
||||
child_nodes, child_conns = origin_genome.execute_crossover(state, jax.random.key(99), nodes1, conns, nodes2, conns)
|
||||
print("child1:\n", origin_genome.repr(state, child_nodes, child_conns))
|
||||
|
||||
# [100, 200] case
|
||||
child_nodes, child_conns = origin_genome.execute_crossover(state, jax.random.key(1), nodes1, conns, nodes2, conns)
|
||||
print("child2:\n", origin_genome.repr(state, child_nodes, child_conns))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# mutation_origin()
|
||||
# distance_default()
|
||||
# distance_origin_case1()
|
||||
# distance_origin_case2()
|
||||
# crossover_origin_case1()
|
||||
# crossover_origin_case2()
|
||||
crossover_origin_case3()
|
||||
Reference in New Issue
Block a user