diff --git a/examples/brax/hopper_origin.py b/examples/brax/hopper_origin.py new file mode 100644 index 0000000..475bb8f --- /dev/null +++ b/examples/brax/hopper_origin.py @@ -0,0 +1,49 @@ +from tensorneat.pipeline import Pipeline +from tensorneat.algorithm.neat import NEAT +from tensorneat.genome import DefaultGenome, OriginNode, OriginConn + +from tensorneat.problem.rl import BraxEnv +from tensorneat.common import ACT, AGG + +""" +Solving Hopper with OriginGene +See https://github.com/EMI-Group/tensorneat/issues/11 +""" + +if __name__ == "__main__": + pipeline = Pipeline( + algorithm=NEAT( + pop_size=1000, + species_size=20, + survival_threshold=0.1, + compatibility_threshold=1.0, + genome=DefaultGenome( + num_inputs=11, + num_outputs=3, + init_hidden_layers=(), + # origin node gene, which use the same crossover behavior to the origin NEAT paper + node_gene=OriginNode( + activation_options=ACT.tanh, + aggregation_options=AGG.sum, + response_lower_bound = 1, + response_upper_bound= 1, # fix response to 1 + ), + # use origin connection, which using historical marker + conn_gene=OriginConn(), + output_transform=ACT.tanh, + ), + ), + problem=BraxEnv( + env_name="hopper", + max_step=1000, + ), + seed=42, + generation_limit=100, + fitness_target=5000, + ) + + # initialize state + state = pipeline.setup() + # print(state) + # run until terminate + state, best = pipeline.auto_run(state) diff --git a/examples/func_fit/xor.py b/examples/func_fit/xor.py index 9c85a9a..9537e49 100644 --- a/examples/func_fit/xor.py +++ b/examples/func_fit/xor.py @@ -30,7 +30,8 @@ pipeline.show(state, best) # visualize the best individual network = algorithm.genome.network_dict(state, *best) -algorithm.genome.visualize(network, save_path="./imgs/xor_network.svg") +print(algorithm.genome.repr(state, *best)) +# algorithm.genome.visualize(network, save_path="./imgs/xor_network.svg") # transform the best individual to latex formula from tensorneat.common.sympy_tools import to_latex_code, to_python_code diff --git a/examples/func_fit/xor_origin.py b/examples/func_fit/xor_origin.py new file mode 100644 index 0000000..c0aa3bc --- /dev/null +++ b/examples/func_fit/xor_origin.py @@ -0,0 +1,55 @@ +from tensorneat.pipeline import Pipeline +from tensorneat import algorithm, genome, problem +from tensorneat.genome import OriginNode, OriginConn +from tensorneat.common import ACT + +""" +Solving XOR-3d problem with OriginGene +See https://github.com/EMI-Group/tensorneat/issues/11 +""" + +algorithm = algorithm.NEAT( + pop_size=10000, + species_size=20, + survival_threshold=0.01, + genome=genome.DefaultGenome( + node_gene=OriginNode(), + conn_gene=OriginConn(), + num_inputs=3, + num_outputs=1, + max_nodes=7, + output_transform=ACT.sigmoid, + ), +) +problem = problem.XOR3d() + +pipeline = Pipeline( + algorithm, + problem, + generation_limit=200, + fitness_target=-1e-6, + seed=42, +) +state = pipeline.setup() +# run until terminate +state, best = pipeline.auto_run(state) +# show result +pipeline.show(state, best) + +# visualize the best individual +network = algorithm.genome.network_dict(state, *best) +print(algorithm.genome.repr(state, *best)) +# algorithm.genome.visualize(network, save_path="./imgs/xor_network.svg") + +# transform the best individual to latex formula +from tensorneat.common.sympy_tools import to_latex_code, to_python_code + +sympy_res = algorithm.genome.sympy_func( + state, network, sympy_output_transform=ACT.obtain_sympy(ACT.sigmoid) +) +latex_code = to_latex_code(*sympy_res) +print(latex_code) + +# transform the best individual to python code +python_code = to_python_code(*sympy_res) +print(python_code) diff --git a/src/tensorneat/algorithm/hyperneat/substrate/default.py b/src/tensorneat/algorithm/hyperneat/substrate/default.py index f3e02b3..96d797e 100644 --- a/src/tensorneat/algorithm/hyperneat/substrate/default.py +++ b/src/tensorneat/algorithm/hyperneat/substrate/default.py @@ -2,7 +2,7 @@ from jax import vmap import numpy as np from .base import BaseSubstrate -from tensorneat.genome.utils import set_conn_attrs +from tensorneat.genome.utils import set_gene_attrs class DefaultSubstrate(BaseSubstrate): @@ -21,7 +21,8 @@ class DefaultSubstrate(BaseSubstrate): def make_conns(self, query_res): # change weight of conns - return vmap(set_conn_attrs)(self.conns, query_res) + # the last column is the weight + return self.conns.at[:, -1].set(query_res) @property def query_coors(self): diff --git a/src/tensorneat/algorithm/neat/neat.py b/src/tensorneat/algorithm/neat/neat.py index 0487e29..fdb926f 100644 --- a/src/tensorneat/algorithm/neat/neat.py +++ b/src/tensorneat/algorithm/neat/neat.py @@ -116,6 +116,25 @@ class NEAT(BaseAlgorithm): next_node_key = max_node_key + 1 new_node_keys = jnp.arange(self.pop_size) + next_node_key + # find next conn historical markers for mutation if needed + if "historical_marker" in self.genome.conn_gene.fixed_attrs: + all_conns_markers = vmap( + self.genome.conn_gene.get_historical_marker, in_axes=(None, 0) + )(state, state.pop_conns) + + max_conn_markers = jnp.max( + all_conns_markers, where=~jnp.isnan(all_conns_markers), initial=0 + ) + next_conn_markers = max_conn_markers + 1 + new_conn_markers = ( + jnp.arange(self.pop_size * 3).reshape(self.pop_size, 3) + + next_conn_markers + ) + else: + # no need to generate new conn historical markers + # use 0 + new_conn_markers = jnp.full((self.pop_size, 3), 0) + # prepare random keys k1, k2, randkey = jax.random.split(state.randkey, 3) crossover_randkeys = jax.random.split(k1, self.pop_size) @@ -133,9 +152,9 @@ class NEAT(BaseAlgorithm): # batch mutation m_n_nodes, m_n_conns = vmap( - self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0) + self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0, 0) )( - state, mutate_randkeys, n_nodes, n_conns, new_node_keys + state, mutate_randkeys, n_nodes, n_conns, new_node_keys, new_conn_markers ) # mutated_new_nodes, mutated_new_conns # elitism don't mutate diff --git a/src/tensorneat/genome/base.py b/src/tensorneat/genome/base.py index d5b0be2..b6f1a42 100644 --- a/src/tensorneat/genome/base.py +++ b/src/tensorneat/genome/base.py @@ -113,8 +113,12 @@ class BaseGenome(StatefulBaseClass): def visualize(self): raise NotImplementedError - def execute_mutation(self, state, randkey, nodes, conns, new_node_key): - return self.mutation(state, self, randkey, nodes, conns, new_node_key) + def execute_mutation( + self, state, randkey, nodes, conns, new_node_key, new_conn_keys + ): + return self.mutation( + state, self, randkey, nodes, conns, new_node_key, new_conn_keys + ) def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2): return self.crossover(state, self, randkey, nodes1, conns1, nodes2, conns2) @@ -144,19 +148,31 @@ class BaseGenome(StatefulBaseClass): conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan) # create input and output indices conn_indices = self.all_init_conns + + # create connection initial history markers + conn_markers = jnp.arange(all_conns_cnt) + # create conn attrs rand_keys_c = jax.random.split(k2, num=all_conns_cnt) - conns_attr_func = jax.vmap( + conns_attrs = jax.vmap( self.conn_gene.new_random_attrs, in_axes=( None, 0, ), - ) - conns_attrs = conns_attr_func(state, rand_keys_c) + )(state, rand_keys_c) - conns = conns.at[:all_conns_cnt, :2].set(conn_indices) # set conn indices - conns = conns.at[:all_conns_cnt, 2:].set(conns_attrs) # set conn attrs + # set conn indices + conns = conns.at[:all_conns_cnt, :2].set(conn_indices) + + # set conn history markers if needed + if "historical_marker" in self.conn_gene.fixed_attrs: + conns = conns.at[:all_conns_cnt, 2].set(conn_markers) + + # set conn attrs + conns = conns.at[:all_conns_cnt, len(self.conn_gene.fixed_attrs) :].set( + conns_attrs + ) return nodes, conns diff --git a/src/tensorneat/genome/default.py b/src/tensorneat/genome/default.py index eb9d1f3..66a9e64 100644 --- a/src/tensorneat/genome/default.py +++ b/src/tensorneat/genome/default.py @@ -8,7 +8,7 @@ import sympy as sp from .base import BaseGenome from .gene import DefaultNode, DefaultConn from .operations import DefaultMutation, DefaultCrossover, DefaultDistance -from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs +from .utils import unflatten_conns, extract_gene_attrs, extract_gene_attrs from tensorneat.common import ( topological_sort, @@ -16,7 +16,7 @@ from tensorneat.common import ( I_INF, attach_with_inf, ACT, - AGG + AGG, ) @@ -73,8 +73,8 @@ class DefaultGenome(BaseGenome): ini_vals = jnp.full((self.max_nodes,), jnp.nan) ini_vals = ini_vals.at[self.input_idx].set(inputs) - nodes_attrs = vmap(extract_node_attrs)(nodes) - conns_attrs = vmap(extract_conn_attrs)(conns) + nodes_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(self.node_gene, nodes) + conns_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(self.conn_gene, conns) def cond_fun(carry): values, idx = carry diff --git a/src/tensorneat/genome/gene/conn/__init__.py b/src/tensorneat/genome/gene/conn/__init__.py index 222174a..948eb3f 100644 --- a/src/tensorneat/genome/gene/conn/__init__.py +++ b/src/tensorneat/genome/gene/conn/__init__.py @@ -1,2 +1,3 @@ from .base import BaseConn from .default import DefaultConn +from .origin import OriginConn \ No newline at end of file diff --git a/src/tensorneat/genome/gene/conn/origin.py b/src/tensorneat/genome/gene/conn/origin.py new file mode 100644 index 0000000..0137c10 --- /dev/null +++ b/src/tensorneat/genome/gene/conn/origin.py @@ -0,0 +1,60 @@ +import jax, jax.numpy as jnp +from .default import DefaultConn + + +class OriginConn(DefaultConn): + """ + Implementation of connections in origin NEAT Paper. + Details at https://github.com/EMI-Group/tensorneat/issues/11. + """ + + # add historical_marker into fixed_attrs + fixed_attrs = ["input_index", "output_index", "historical_marker"] + custom_attrs = ["weight"] + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + def crossover(self, state, randkey, attrs1, attrs2): + # random pick one of attrs, without attrs exchange + return jnp.where( + # origin code, generate multiple random numbers, without attrs exchange + # jax.random.normal(randkey, attrs1.shape) > 0, + jax.random.normal(randkey) + > 0, # generate one random number, without attrs exchange + attrs1, + attrs2, + ) + + def get_historical_marker(self, state, gene_array): + return gene_array[2] + + def repr(self, state, conn, precision=2, idx_width=3, func_width=8): + in_idx, out_idx, historical_marker, weight = conn + + in_idx = int(in_idx) + out_idx = int(out_idx) + historical_marker = int(historical_marker) + weight = round(float(weight), precision) + + return "{}(in: {:<{idx_width}}, out: {:<{idx_width}}, historical_marker: {:<{idx_width}}, weight: {:<{float_width}})".format( + self.__class__.__name__, + in_idx, + out_idx, + historical_marker, + weight, + idx_width=idx_width, + float_width=precision + 3, + ) + + def to_dict(self, state, conn): + return { + "in": int(conn[0]), + "out": int(conn[1]), + "historical_marker": int(conn[2]), + "weight": jnp.float32(conn[3]), + } \ No newline at end of file diff --git a/src/tensorneat/genome/gene/node/__init__.py b/src/tensorneat/genome/gene/node/__init__.py index 2ae1afa..6b22d0f 100644 --- a/src/tensorneat/genome/gene/node/__init__.py +++ b/src/tensorneat/genome/gene/node/__init__.py @@ -1,3 +1,4 @@ from .base import BaseNode from .default import DefaultNode from .bias import BiasNode +from .origin import OriginNode diff --git a/src/tensorneat/genome/gene/node/origin.py b/src/tensorneat/genome/gene/node/origin.py new file mode 100644 index 0000000..41d614f --- /dev/null +++ b/src/tensorneat/genome/gene/node/origin.py @@ -0,0 +1,27 @@ +import jax, jax.numpy as jnp +from .default import DefaultNode + + +class OriginNode(DefaultNode): + """ + Implementation of nodes in origin NEAT Paper. + Details at https://github.com/EMI-Group/tensorneat/issues/11. + """ + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + def crossover(self, state, randkey, attrs1, attrs2): + # random pick one of attrs, without attrs exchange + return jnp.where( + # origin code, generate multiple random numbers, without attrs exchange + # jax.random.normal(randkey, attrs1.shape) > 0, + jax.random.normal(randkey) + > 0, # generate one random number, without attrs exchange + attrs1, + attrs2, + ) diff --git a/src/tensorneat/genome/operations/crossover/default.py b/src/tensorneat/genome/operations/crossover/default.py index d17a0f4..34ab9d1 100644 --- a/src/tensorneat/genome/operations/crossover/default.py +++ b/src/tensorneat/genome/operations/crossover/default.py @@ -2,12 +2,10 @@ import jax from jax import vmap, numpy as jnp from .base import BaseCrossover -from ...utils import ( - extract_node_attrs, - extract_conn_attrs, - set_node_attrs, - set_conn_attrs, -) +from ...utils import extract_gene_attrs, set_gene_attrs + +from tensorneat.common import fetch_first, I_INF +from tensorneat.genome.gene import BaseGene class DefaultCrossover(BaseCrossover): @@ -17,71 +15,90 @@ class DefaultCrossover(BaseCrossover): notice that genome1 should have higher fitness than genome2 (genome1 is winner!) """ randkey1, randkey2 = jax.random.split(randkey, 2) - randkeys1 = jax.random.split(randkey1, genome.max_nodes) - randkeys2 = jax.random.split(randkey2, genome.max_conns) + node_randkeys = jax.random.split(randkey1, genome.max_nodes) + conn_randkeys = jax.random.split(randkey2, genome.max_conns) + batch_create_new_gene = jax.vmap( + create_new_gene, in_axes=(None, 0, None, 0, 0, None, None) + ) # crossover nodes - keys1, keys2 = nodes1[:, 0], nodes2[:, 0] - # make homologous genes align in nodes2 align with nodes1 - nodes2 = self.align_array(keys1, keys2, nodes2, is_conn=False) - - # For not homologous genes, use the value of nodes1(winner) - # For homologous genes, use the crossover result between nodes1 and nodes2 - node_attrs1 = vmap(extract_node_attrs)(nodes1) - node_attrs2 = vmap(extract_node_attrs)(nodes2) - - new_node_attrs = jnp.where( - jnp.isnan(node_attrs1) | jnp.isnan(node_attrs2), # one of them is nan - node_attrs1, # not homologous genes or both nan, use the value of nodes1(winner) - vmap(genome.node_gene.crossover, in_axes=(None, 0, 0, 0))( - state, randkeys1, node_attrs1, node_attrs2 - ), # homologous or both nan + node_keys1, node_keys2 = ( + nodes1[:, 0 : len(genome.node_gene.fixed_attrs)], + nodes2[:, 0 : len(genome.node_gene.fixed_attrs)], + ) + node_attrs1 = vmap(extract_gene_attrs, in_axes=(None, 0))( + genome.node_gene, nodes1 + ) + node_attrs2 = vmap(extract_gene_attrs, in_axes=(None, 0))( + genome.node_gene, nodes2 + ) + + new_node_attrs = batch_create_new_gene( + state, + node_randkeys, + genome.node_gene, + node_keys1, + node_attrs1, + node_keys2, + node_attrs2, + ) + new_nodes = vmap(set_gene_attrs, in_axes=(None, 0, 0))( + genome.node_gene, nodes1, new_node_attrs ) - new_nodes = vmap(set_node_attrs)(nodes1, new_node_attrs) # crossover connections - con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2] - conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True) - - conns_attrs1 = vmap(extract_conn_attrs)(conns1) - conns_attrs2 = vmap(extract_conn_attrs)(conns2) - - new_conn_attrs = jnp.where( - jnp.isnan(conns_attrs1) | jnp.isnan(conns_attrs2), - conns_attrs1, # not homologous genes or both nan, use the value of conns1(winner) - vmap(genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))( - state, randkeys2, conns_attrs1, conns_attrs2 - ), # homologous or both nan + # all fixed_attrs together will use to identify a connection + # if using historical marker, use it + # related to issue: https://github.com/EMI-Group/tensorneat/issues/11 + conn_keys1, conn_keys2 = ( + conns1[:, 0 : len(genome.conn_gene.fixed_attrs)], + conns2[:, 0 : len(genome.conn_gene.fixed_attrs)], + ) + conn_attrs1 = vmap(extract_gene_attrs, in_axes=(None, 0))( + genome.conn_gene, conns1 + ) + conn_attrs2 = vmap(extract_gene_attrs, in_axes=(None, 0))( + genome.conn_gene, conns2 + ) + + new_conn_attrs = batch_create_new_gene( + state, + conn_randkeys, + genome.conn_gene, + conn_keys1, + conn_attrs1, + conn_keys2, + conn_attrs2, + ) + new_conns = vmap(set_gene_attrs, in_axes=(None, 0, 0))( + genome.conn_gene, conns1, new_conn_attrs ) - new_conns = vmap(set_conn_attrs)(conns1, new_conn_attrs) return new_nodes, new_conns - def align_array(self, seq1, seq2, ar2, is_conn: bool): - """ - After I review this code, I found that it is the most difficult part of the code. - Please consider carefully before change it! - make ar2 align with ar1. - :param seq1: - :param seq2: - :param ar2: - :param is_conn: - :return: - align means to intersect part of ar2 will be at the same position as ar1, - non-intersect part of ar2 will be set to Nan - """ - seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :] - mask = (seq1 == seq2) & (~jnp.isnan(seq1)) - if is_conn: - mask = jnp.all(mask, axis=2) +def create_new_gene( + state, + randkey, + gene: BaseGene, + gene_key, + gene_attrs, + genes_keys, + genes_attrs, +): + # find homologous genes + homologous_idx = fetch_first(jnp.all(gene_key == genes_keys, axis=1)) - intersect_mask = mask.any(axis=1) - idx = jnp.arange(0, len(seq1)) - idx_fixed = jnp.dot(mask, idx) + def none(): # no homologous, use winner's gene + return gene_attrs - refactor_ar2 = jnp.where( - intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan - ) + def crossover(): # when homologous gene is found, execute crossover + return gene.crossover(state, randkey, gene_attrs, genes_attrs[homologous_idx]) - return refactor_ar2 + new_attrs = jax.lax.cond( + homologous_idx == I_INF, # homologous gene is not found or current gene is nan + none, + crossover, + ) + + return new_attrs diff --git a/src/tensorneat/genome/operations/distance/default.py b/src/tensorneat/genome/operations/distance/default.py index 00d7284..c18acc2 100644 --- a/src/tensorneat/genome/operations/distance/default.py +++ b/src/tensorneat/genome/operations/distance/default.py @@ -1,7 +1,8 @@ from jax import vmap, numpy as jnp from .base import BaseDistance -from ...utils import extract_node_attrs, extract_conn_attrs +from ...gene import BaseGene +from ...utils import extract_gene_attrs class DefaultDistance(BaseDistance): @@ -17,83 +18,47 @@ class DefaultDistance(BaseDistance): """ The distance between two genomes """ - d = self.node_distance(state, genome, nodes1, nodes2) + self.conn_distance( - state, genome, conns1, conns2 - ) - return d + node_distance = self.gene_distance(state, genome.node_gene, nodes1, nodes2) + conn_distance = self.gene_distance(state, genome.conn_gene, conns1, conns2) + return node_distance + conn_distance - def node_distance(self, state, genome, nodes1, nodes2): + + def gene_distance(self, state, gene: BaseGene, genes1, genes2): """ - The distance of the nodes part for two genomes + The distance between to genes + genes1: 2-D jax array with shape + genes2: 2-D jax array with shape + gene1.shape == gene2.shape """ - node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0])) - node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0])) - max_cnt = jnp.maximum(node_cnt1, node_cnt2) + cnt1 = jnp.sum(~jnp.isnan(genes1[:, 0])) + cnt2 = jnp.sum(~jnp.isnan(genes2[:, 0])) + max_cnt = jnp.maximum(cnt1, cnt2) # align homologous nodes - # this process is similar to np.intersect1d. - nodes = jnp.concatenate((nodes1, nodes2), axis=0) - keys = nodes[:, 0] - sorted_indices = jnp.argsort(keys, axis=0) - nodes = nodes[sorted_indices] - nodes = jnp.concatenate( - [nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0 + # this process is similar to np.intersect1d in higher dimension + total_genes = jnp.concatenate((genes1, genes2), axis=0) + identifiers = total_genes[:, : len(gene.fixed_attrs)] + sorted_identifiers = jnp.lexsort(identifiers.T[::-1]) + total_genes = total_genes[sorted_identifiers] + total_genes = jnp.concatenate( + [total_genes, jnp.full((1, total_genes.shape[1]), jnp.nan)], axis=0 ) # add a nan row to the end - fr, sr = nodes[:-1], nodes[1:] # first row, second row + fr, sr = total_genes[:-1], total_genes[1:] # first row, second row - # flag location of homologous nodes - intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0]) + # intersect part of two genes + intersect_mask = jnp.all( + fr[:, : len(gene.fixed_attrs)] == sr[:, : len(gene.fixed_attrs)], axis=1 + ) & ~jnp.isnan(fr[:, 0]) - # calculate the count of non_homologous of two genomes - non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask) + non_homologous_cnt = cnt1 + cnt2 - 2 * jnp.sum(intersect_mask) - # calculate the distance of homologous nodes - fr_attrs = vmap(extract_node_attrs)(fr) - sr_attrs = vmap(extract_node_attrs)(sr) - hnd = vmap(genome.node_gene.distance, in_axes=(None, 0, 0))( - state, fr_attrs, sr_attrs - ) # homologous node distance - hnd = jnp.where(jnp.isnan(hnd), 0, hnd) - homologous_distance = jnp.sum(hnd * intersect_mask) - - val = ( - non_homologous_cnt * self.compatibility_disjoint - + homologous_distance * self.compatibility_weight - ) - - val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize - - return val - - def conn_distance(self, state, genome, conns1, conns2): - """ - The distance of the conns part for two genomes - """ - con_cnt1 = jnp.sum(~jnp.isnan(conns1[:, 0])) - con_cnt2 = jnp.sum(~jnp.isnan(conns2[:, 0])) - max_cnt = jnp.maximum(con_cnt1, con_cnt2) - - cons = jnp.concatenate((conns1, conns2), axis=0) - keys = cons[:, :2] - sorted_indices = jnp.lexsort(keys.T[::-1]) - cons = cons[sorted_indices] - cons = jnp.concatenate( - [cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0 - ) # add a nan row to the end - fr, sr = cons[:-1], cons[1:] # first row, second row - - # both genome has such connection - intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0]) - - non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask) - - fr_attrs = vmap(extract_conn_attrs)(fr) - sr_attrs = vmap(extract_conn_attrs)(sr) - hcd = vmap(genome.conn_gene.distance, in_axes=(None, 0, 0))( - state, fr_attrs, sr_attrs - ) # homologous connection distance - hcd = jnp.where(jnp.isnan(hcd), 0, hcd) - homologous_distance = jnp.sum(hcd * intersect_mask) + fr_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(gene, fr) + sr_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(gene, sr) + + # homologous gene distance + hgd = vmap(gene.distance, in_axes=(None, 0, 0))(state, fr_attrs, sr_attrs) + hgd = jnp.where(jnp.isnan(hgd), 0, hgd) + homologous_distance = jnp.sum(hgd * intersect_mask) val = ( non_homologous_cnt * self.compatibility_disjoint diff --git a/src/tensorneat/genome/operations/mutation/base.py b/src/tensorneat/genome/operations/mutation/base.py index a1b08cb..94e6edf 100644 --- a/src/tensorneat/genome/operations/mutation/base.py +++ b/src/tensorneat/genome/operations/mutation/base.py @@ -3,5 +3,5 @@ from tensorneat.common import StatefulBaseClass, State class BaseMutation(StatefulBaseClass): - def __call__(self, state, genome, randkey, nodes, conns, new_node_key): + def __call__(self, state, genome, randkey, nodes, conns, new_node_key, new_conn_key): raise NotImplementedError diff --git a/src/tensorneat/genome/operations/mutation/default.py b/src/tensorneat/genome/operations/mutation/default.py index d761e1a..06ef5c6 100644 --- a/src/tensorneat/genome/operations/mutation/default.py +++ b/src/tensorneat/genome/operations/mutation/default.py @@ -13,10 +13,8 @@ from ...utils import ( add_conn, delete_node_by_pos, delete_conn_by_pos, - extract_node_attrs, - extract_conn_attrs, - set_node_attrs, - set_conn_attrs, + extract_gene_attrs, + set_gene_attrs ) @@ -33,17 +31,28 @@ class DefaultMutation(BaseMutation): self.node_add = node_add self.node_delete = node_delete - def __call__(self, state, genome, randkey, nodes, conns, new_node_key): + def __call__( + self, state, genome, randkey, nodes, conns, new_node_key, new_conn_key + ): + assert ( + new_node_key.shape == () + ) # scalar, as there is max one new node in each mutation + assert new_conn_key.shape == ( + 3, + ) # there are max 3 new connections (mutate add node + mutate add conn) + k1, k2 = jax.random.split(randkey) nodes, conns = self.mutate_structure( - state, genome, k1, nodes, conns, new_node_key + state, genome, k1, nodes, conns, new_node_key, new_conn_key ) nodes, conns = self.mutate_values(state, genome, k2, nodes, conns) return nodes, conns - def mutate_structure(self, state, genome, randkey, nodes, conns, new_node_key): + def mutate_structure( + self, state, genome, randkey, nodes, conns, new_node_key, new_conn_key + ): def mutate_add_node(key_, nodes_, conns_): """ add a node while do not influence the output of the network @@ -57,27 +66,33 @@ class DefaultMutation(BaseMutation): def successful_add_node(): # remove the original connection and record its attrs - original_attrs = extract_conn_attrs(conns_[idx]) + original_attrs = extract_gene_attrs(genome.conn_gene, conns_[idx]) new_conns = delete_conn_by_pos(conns_, idx) # add a new node with identity attrs new_nodes = add_node( - nodes_, new_node_key, genome.node_gene.new_identity_attrs(state) + nodes_, jnp.array([new_node_key]), genome.node_gene.new_identity_attrs(state) ) + # whether to use historical marker in connection + if "historical_marker" in genome.conn_gene.fixed_attrs: + fix_attrs1 = jnp.array([i_key, new_node_key, new_conn_key[0]]) + fix_attrs2 = jnp.array([new_node_key, o_key, new_conn_key[1]]) + else: + fix_attrs1 = jnp.array([i_key, new_node_key]) + fix_attrs2 = jnp.array([new_node_key, o_key]) + # add two new connections # first is with identity attrs new_conns = add_conn( new_conns, - i_key, - new_node_key, + fix_attrs1, genome.conn_gene.new_identity_attrs(state), ) # second is with the origin attrs new_conns = add_conn( new_conns, - new_node_key, - o_key, + fix_attrs2, original_attrs, ) @@ -160,8 +175,12 @@ class DefaultMutation(BaseMutation): def successful(): # add a connection with zero attrs + if "historical_marker" in genome.conn_gene.fixed_attrs: + new_fix_attrs = jnp.array([i_key, o_key, new_conn_key[2]]) + else: + new_fix_attrs = jnp.array([i_key, o_key]) return nodes_, add_conn( - conns_, i_key, o_key, genome.conn_gene.new_zero_attrs(state) + conns_, new_fix_attrs, genome.conn_gene.new_zero_attrs(state) ) if genome.network_type == "feedforward": @@ -228,17 +247,25 @@ class DefaultMutation(BaseMutation): nodes_randkeys = jax.random.split(k1, num=genome.max_nodes) conns_randkeys = jax.random.split(k2, num=genome.max_conns) - node_attrs = vmap(extract_node_attrs)(nodes) + node_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))( + genome.node_gene, nodes + ) new_node_attrs = vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))( state, nodes_randkeys, node_attrs ) - new_nodes = vmap(set_node_attrs)(nodes, new_node_attrs) + new_nodes = vmap(set_gene_attrs, in_axes=(None, 0, 0))( + genome.node_gene, nodes, new_node_attrs + ) - conn_attrs = vmap(extract_conn_attrs)(conns) + conn_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))( + genome.conn_gene, conns + ) new_conn_attrs = vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))( state, conns_randkeys, conn_attrs ) - new_conns = vmap(set_conn_attrs)(conns, new_conn_attrs) + new_conns = vmap(set_gene_attrs, in_axes=(None, 0, 0))( + genome.conn_gene, conns, new_conn_attrs + ) # nan nodes not changed new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes) diff --git a/src/tensorneat/genome/recurrent.py b/src/tensorneat/genome/recurrent.py index 0317369..cc16b74 100644 --- a/src/tensorneat/genome/recurrent.py +++ b/src/tensorneat/genome/recurrent.py @@ -5,7 +5,7 @@ from .utils import unflatten_conns from .base import BaseGenome from .gene import DefaultNode, DefaultConn from .operations import DefaultMutation, DefaultCrossover, DefaultDistance -from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs +from .utils import unflatten_conns, extract_gene_attrs, extract_gene_attrs from tensorneat.common import attach_with_inf @@ -55,8 +55,8 @@ class RecurrentGenome(BaseGenome): vals = jnp.full((self.max_nodes,), jnp.nan) - nodes_attrs = vmap(extract_node_attrs)(nodes) - conns_attrs = vmap(extract_conn_attrs)(conns) + nodes_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(self.node_gene, nodes) + conns_attrs = vmap(extract_gene_attrs, in_axes=(None, 0))(self.conn_gene, conns) expand_conns_attrs = attach_with_inf(conns_attrs, u_conns) def body_func(_, values): diff --git a/src/tensorneat/genome/utils.py b/src/tensorneat/genome/utils.py index 903bf85..dc67c3d 100644 --- a/src/tensorneat/genome/utils.py +++ b/src/tensorneat/genome/utils.py @@ -2,6 +2,7 @@ import jax from jax import vmap, numpy as jnp import numpy as np +from .gene import BaseGene from tensorneat.common import fetch_first, I_INF @@ -38,49 +39,27 @@ def valid_cnt(nodes_or_conns): return jnp.sum(~jnp.isnan(nodes_or_conns[:, 0])) -def extract_node_attrs(node): +def extract_gene_attrs(gene: BaseGene, gene_array): """ - node: Array(NL, ) - extract the attributes of a node + extract the custom attributes of the gene """ - return node[1:] # 0 is for idx + return gene_array[len(gene.fixed_attrs) :] -def set_node_attrs(node, attrs): +def set_gene_attrs(gene: BaseGene, gene_array, attrs): """ - node: Array(NL, ) - attrs: Array(NL-1, ) - set the attributes of a node + set the custom attributes of the gene """ - return node.at[1:].set(attrs) # 0 is for idx + return gene_array.at[len(gene.fixed_attrs) :].set(attrs) -def extract_conn_attrs(conn): - """ - conn: Array(CL, ) - extract the attributes of a connection - """ - return conn[2:] # 0, 1 is for in-idx and out-idx - - -def set_conn_attrs(conn, attrs): - """ - conn: Array(CL, ) - attrs: Array(CL-2, ) - set the attributes of a connection - """ - return conn.at[2:].set(attrs) # 0, 1 is for in-idx and out-idx - - -def add_node(nodes, new_key: int, attrs): +def add_node(nodes, fix_attrs, custom_attrs): """ Add a new node to the genome. The new node will place at the first NaN row. """ - exist_keys = nodes[:, 0] - pos = fetch_first(jnp.isnan(exist_keys)) - new_nodes = nodes.at[pos, 0].set(new_key) - return new_nodes.at[pos, 1:].set(attrs) + pos = fetch_first(jnp.isnan(nodes[:, 0])) + return nodes.at[pos].set(jnp.concatenate((fix_attrs, custom_attrs))) def delete_node_by_pos(nodes, pos): @@ -91,15 +70,13 @@ def delete_node_by_pos(nodes, pos): return nodes.at[pos].set(jnp.nan) -def add_conn(conns, i_key, o_key, attrs): +def add_conn(conns, fix_attrs, custom_attrs): """ Add a new connection to the genome. The new connection will place at the first NaN row. """ - con_keys = conns[:, 0] - pos = fetch_first(jnp.isnan(con_keys)) - new_conns = conns.at[pos, 0:2].set(jnp.array([i_key, o_key])) - return new_conns.at[pos, 2:].set(attrs) + pos = fetch_first(jnp.isnan(conns[:, 0])) + return conns.at[pos].set(jnp.concatenate((fix_attrs, custom_attrs))) def delete_conn_by_pos(conns, pos): diff --git a/test/origin_operations_test.py b/test/origin_operations_test.py new file mode 100644 index 0000000..31a0dda --- /dev/null +++ b/test/origin_operations_test.py @@ -0,0 +1,247 @@ +import jax, jax.numpy as jnp +from tensorneat.genome.operations import ( + DefaultMutation, + DefaultDistance, + DefaultCrossover, +) +from tensorneat.genome import ( + DefaultGenome, + DefaultNode, + DefaultConn, + OriginNode, + OriginConn, +) +from tensorneat.genome.utils import add_node, add_conn + +origin_genome = DefaultGenome( + node_gene=OriginNode(response_init_std=1), + conn_gene=OriginConn(), + mutation=DefaultMutation(conn_add=1, node_add=1, conn_delete=0, node_delete=0), + crossover=DefaultCrossover(), + distance=DefaultDistance(), + num_inputs=3, + num_outputs=1, + max_nodes=6, + max_conns=6, +) + +default_genome = DefaultGenome( + node_gene=DefaultNode(response_init_std=1), + conn_gene=DefaultConn(), + mutation=DefaultMutation(conn_add=1, node_add=1, conn_delete=0, node_delete=0), + crossover=DefaultCrossover(), + distance=DefaultDistance(), + num_inputs=3, + num_outputs=1, + max_nodes=6, + max_conns=6, +) + +state = default_genome.setup() +state = origin_genome.setup(state) + +randkey = jax.random.PRNGKey(42) + + +def mutation_default(): + nodes, conns = default_genome.initialize(state, randkey) + print("old genome:\n", default_genome.repr(state, nodes, conns)) + + nodes, conns = default_genome.execute_mutation( + state, + randkey, + nodes, + conns, + new_node_key=jnp.asarray(10), + new_conn_keys=jnp.array([20, 21, 22]), + ) + + # new_conn_keys is not used in default genome + print("new genome:\n", default_genome.repr(state, nodes, conns)) + + +def mutation_origin(): + nodes, conns = origin_genome.initialize(state, randkey) + print(conns) + print("old genome:\n", origin_genome.repr(state, nodes, conns)) + + nodes, conns = origin_genome.execute_mutation( + state, + randkey, + nodes, + conns, + new_node_key=jnp.asarray(10), + new_conn_keys=jnp.array([20, 21, 22]), + ) + print(conns) + # new_conn_keys is used in origin genome + print("new genome:\n", origin_genome.repr(state, nodes, conns)) + +def distance_default(): + nodes, conns = default_genome.initialize(state, randkey) + nodes = add_node( + nodes, + fix_attrs=jnp.asarray([10]), + custom_attrs=default_genome.node_gene.new_identity_attrs(state) + ) + conns1 = add_conn( + conns, + fix_attrs=jnp.array([0, 10]), # in-idx, out-idx + custom_attrs=default_genome.conn_gene.new_zero_attrs(state) + ) + conns2 = add_conn( + conns, + fix_attrs=jnp.array([0, 10]), # in-idx, out-idx + custom_attrs=default_genome.conn_gene.new_random_attrs(state, randkey) + ) + print("genome1:\n", default_genome.repr(state, nodes, conns1)) + print("genome2:\n", default_genome.repr(state, nodes, conns2)) + + distance = default_genome.execute_distance(state, nodes, conns1, nodes, conns2) + print("distance: ", distance) + +def distance_origin_case1(): + """ + distance with different historical marker + """ + nodes, conns = origin_genome.initialize(state, randkey) + nodes = add_node( + nodes, + fix_attrs=jnp.asarray([10]), + custom_attrs=origin_genome.node_gene.new_identity_attrs(state) + ) + conns1 = add_conn( + conns, + fix_attrs=jnp.array([0, 10, 99]), # in-idx, out-idx, historical mark + custom_attrs=origin_genome.conn_gene.new_zero_attrs(state) + ) + conns2 = add_conn( + conns, + fix_attrs=jnp.array([0, 10, 88]), # in-idx, out-idx, historical mark + custom_attrs=origin_genome.conn_gene.new_random_attrs(state, randkey) + ) + print("genome1:\n", origin_genome.repr(state, nodes, conns1)) + print("genome2:\n", origin_genome.repr(state, nodes, conns2)) + + distance = origin_genome.execute_distance(state, nodes, conns1, nodes, conns2) + print("distance: ", distance) + +def distance_origin_case2(): + """ + distance with same historical marker + """ + nodes, conns = origin_genome.initialize(state, randkey) + nodes = add_node( + nodes, + fix_attrs=jnp.asarray([10]), + custom_attrs=origin_genome.node_gene.new_identity_attrs(state) + ) + conns1 = add_conn( + conns, + fix_attrs=jnp.array([0, 10, 99]), # in-idx, out-idx, historical mark + custom_attrs=origin_genome.conn_gene.new_zero_attrs(state) + ) + conns2 = add_conn( + conns, + fix_attrs=jnp.array([0, 10, 99]), # in-idx, out-idx, historical mark + custom_attrs=origin_genome.conn_gene.new_random_attrs(state, randkey) + ) + print("genome1:\n", origin_genome.repr(state, nodes, conns1)) + print("genome2:\n", origin_genome.repr(state, nodes, conns2)) + + distance = origin_genome.execute_distance(state, nodes, conns1, nodes, conns2) + print("distance: ", distance) + +def crossover_origin_case1(): + """ + crossover with different historical marker + """ + nodes, conns = origin_genome.initialize(state, randkey) + nodes = add_node( + nodes, + fix_attrs=jnp.asarray([10]), + custom_attrs=origin_genome.node_gene.new_identity_attrs(state) + ) + conns1 = add_conn( + conns, + fix_attrs=jnp.array([0, 10, 99]), # in-idx, out-idx, historical mark + custom_attrs=origin_genome.conn_gene.new_zero_attrs(state) + ) + conns2 = add_conn( + conns, + fix_attrs=jnp.array([0, 10, 88]), # in-idx, out-idx, historical mark + custom_attrs=origin_genome.conn_gene.new_random_attrs(state, randkey) + ) + print("genome1:\n", origin_genome.repr(state, nodes, conns1)) + print("genome2:\n", origin_genome.repr(state, nodes, conns2)) + + # (0, 10)'s weight must be 0 (disjoint gene, use fitter) + child_nodes, child_conns = origin_genome.execute_crossover(state, randkey, nodes, conns1, nodes, conns2) + print("child:\n", origin_genome.repr(state, child_nodes, child_conns)) + +def crossover_origin_case2(): + """ + crossover with same historical marker + """ + nodes, conns = origin_genome.initialize(state, randkey) + nodes = add_node( + nodes, + fix_attrs=jnp.asarray([10]), + custom_attrs=origin_genome.node_gene.new_identity_attrs(state) + ) + conns1 = add_conn( + conns, + fix_attrs=jnp.array([0, 10, 99]), # in-idx, out-idx, historical mark + custom_attrs=origin_genome.conn_gene.new_zero_attrs(state) + ) + conns2 = add_conn( + conns, + fix_attrs=jnp.array([0, 10, 99]), # in-idx, out-idx, historical mark + custom_attrs=origin_genome.conn_gene.new_random_attrs(state, randkey) + ) + print("genome1:\n", origin_genome.repr(state, nodes, conns1)) + print("genome2:\n", origin_genome.repr(state, nodes, conns2)) + + # (0, 10)'s weight might be random or zero (homologous gene) + + # zero case: + child_nodes, child_conns = origin_genome.execute_crossover(state, jax.random.key(99), nodes, conns1, nodes, conns2) + print("child_zero:\n", origin_genome.repr(state, child_nodes, child_conns)) + + # random case: + child_nodes, child_conns = origin_genome.execute_crossover(state, jax.random.key(0), nodes, conns1, nodes, conns2) + print("child_random:\n", origin_genome.repr(state, child_nodes, child_conns)) + +def crossover_origin_case3(): + """ + test examine it use random gene rather than attribute exchange + """ + nodes, conns = origin_genome.initialize(state, randkey) + nodes1 = add_node( + nodes, + fix_attrs=jnp.asarray([10]), + custom_attrs=jnp.array([1, 2, 0, 0]) + ) + nodes2 = add_node( + nodes, + fix_attrs=jnp.asarray([10]), + custom_attrs=jnp.array([100, 200, 0, 0]) + ) + + # [1, 2] case + child_nodes, child_conns = origin_genome.execute_crossover(state, jax.random.key(99), nodes1, conns, nodes2, conns) + print("child1:\n", origin_genome.repr(state, child_nodes, child_conns)) + + # [100, 200] case + child_nodes, child_conns = origin_genome.execute_crossover(state, jax.random.key(1), nodes1, conns, nodes2, conns) + print("child2:\n", origin_genome.repr(state, child_nodes, child_conns)) + + +if __name__ == "__main__": + # mutation_origin() + # distance_default() + # distance_origin_case1() + # distance_origin_case2() + # crossover_origin_case1() + # crossover_origin_case2() + crossover_origin_case3()