From ec48447670a7f7a356ce768db079a404aea8568e Mon Sep 17 00:00:00 2001 From: anthonyrawlins Date: Sat, 28 Feb 2026 10:40:54 +1100 Subject: [PATCH] Add MEND module duplication operator Implement the core MEND mechanism: a module duplication mutation that copies a hidden node with all its incoming and outgoing connections. This is the single addition MEND makes to standard NEAT. - ModuleDuplication: JAX-compatible operator using jax.lax.scan - CombinedMutation: composes DefaultMutation + ModuleDuplication - DefaultGenome: accepts duplication_rate parameter - Tests for standalone duplication, combined mutation, and rate=0 no-op Co-Authored-By: Claude Opus 4.6 --- src/tensorneat/genome/default.py | 26 +-- src/tensorneat/genome/operations/__init__.py | 2 +- .../genome/operations/mutation/__init__.py | 1 + .../operations/mutation/module_duplication.py | 141 +++++++++++++++ test/test_module_duplication.py | 164 ++++++++++++++++++ 5 files changed, 322 insertions(+), 12 deletions(-) create mode 100644 src/tensorneat/genome/operations/mutation/module_duplication.py create mode 100644 test/test_module_duplication.py diff --git a/src/tensorneat/genome/default.py b/src/tensorneat/genome/default.py index 948ce32..d78c7c6 100644 --- a/src/tensorneat/genome/default.py +++ b/src/tensorneat/genome/default.py @@ -7,7 +7,12 @@ import sympy as sp from .base import BaseGenome from .gene import DefaultNode, DefaultConn -from .operations import DefaultMutation, DefaultCrossover, DefaultDistance +from .operations import ( + DefaultMutation, + DefaultCrossover, + DefaultDistance, + CombinedMutation, +) from .utils import unflatten_conns, extract_gene_attrs, extract_gene_attrs from tensorneat.common import ( @@ -40,8 +45,12 @@ class DefaultGenome(BaseGenome): output_transform=None, input_transform=None, init_hidden_layers=(), + duplication_rate: float = 0.0, ): - + # If a duplication_rate is specified, wrap the default mutation with the CombinedMutation + if duplication_rate > 0.0: + # Use 
CombinedMutation which includes standard mutation plus module duplication + mutation = CombinedMutation(duplication_rate=duplication_rate) super().__init__( num_inputs, num_outputs, @@ -66,7 +75,6 @@ class DefaultGenome(BaseGenome): return seqs, nodes, conns, u_conns def forward(self, state, transformed, inputs): - if self.input_transform is not None: inputs = self.input_transform(inputs) @@ -133,9 +141,7 @@ class DefaultGenome(BaseGenome): network["topo_order"] = topo_order network["topo_layers"] = topo_layers network["useful_nodes"] = find_useful_nodes( - set(network["nodes"]), - set(network["conns"]), - set(self.output_idx) + set(network["nodes"]), set(network["conns"]), set(self.output_idx) ) return network @@ -147,7 +153,6 @@ class DefaultGenome(BaseGenome): sympy_output_transform=None, backend="jax", ): - assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'" if sympy_input_transform is None and self.input_transform is not None: @@ -187,7 +192,6 @@ class DefaultGenome(BaseGenome): nodes_exprs = {} args_symbols = {} for i in order: - if i in input_idx: nodes_exprs[symbols[-i - 1]] = symbols[ -i - 1 @@ -258,7 +262,7 @@ class DefaultGenome(BaseGenome): output_exprs, forward_func, network["topo_order"], - network["useful_nodes"] + network["useful_nodes"], ) def visualize( @@ -298,7 +302,7 @@ class DefaultGenome(BaseGenome): nodes_y = [] for node in nodes: node_y = 0 - for y, last_node in enumerate(topo_layers[layer-1]): + for y, last_node in enumerate(topo_layers[layer - 1]): if (last_node, node) in conns_list: node_y += y nodes_y.append(node_y) @@ -395,7 +399,7 @@ class DefaultGenome(BaseGenome): edgecolors=edgecolors, arrowstyle=arrowstyle, arrowsize=arrowsize, - edge_color=edge_color + edge_color=edge_color, ) plt.savefig(save_path, dpi=save_dpi) plt.close() diff --git a/src/tensorneat/genome/operations/__init__.py b/src/tensorneat/genome/operations/__init__.py index a7abd68..be4d59d 100644 --- a/src/tensorneat/genome/operations/__init__.py 
+++ b/src/tensorneat/genome/operations/__init__.py
@@ -1,3 +1,3 @@
 from .crossover import BaseCrossover, DefaultCrossover
-from .mutation import BaseMutation, DefaultMutation
+from .mutation import BaseMutation, DefaultMutation, CombinedMutation, ModuleDuplication
 from .distance import BaseDistance, DefaultDistance
diff --git a/src/tensorneat/genome/operations/mutation/__init__.py b/src/tensorneat/genome/operations/mutation/__init__.py
index 2c12bea..1ca35a8 100644
--- a/src/tensorneat/genome/operations/mutation/__init__.py
+++ b/src/tensorneat/genome/operations/mutation/__init__.py
@@ -1,2 +1,3 @@
 from .base import BaseMutation
 from .default import DefaultMutation
+from .module_duplication import CombinedMutation, ModuleDuplication
diff --git a/src/tensorneat/genome/operations/mutation/module_duplication.py b/src/tensorneat/genome/operations/mutation/module_duplication.py
new file mode 100644
index 0000000..48f6e77
--- /dev/null
+++ b/src/tensorneat/genome/operations/mutation/module_duplication.py
@@ -0,0 +1,148 @@
+"""Module Duplication Mutation
+
+Implements the MEND *module duplication* operator. It duplicates a hidden
+module (a hidden node together with all its incoming and outgoing connections).
+The operator fires with probability ``duplication_rate``.
+"""
+
+import jax
+import jax.numpy as jnp
+from .base import BaseMutation
+from .default import DefaultMutation
+from ...utils import add_node, add_conn, extract_gene_attrs
+from tensorneat.common import I_INF
+
+
+def _append_rewired(conns_acc, conns, mask, col, new_node_key):
+    """Append a copy of every masked row of ``conns`` to ``conns_acc``.
+
+    Each copy keeps the whole original row (all remaining fixed and custom
+    attributes); only column ``col`` (0 = source, 1 = target) is replaced by
+    ``new_node_key``. A copy is written to the first row of ``conns_acc`` whose
+    first column is NaN and is dropped when no such row is left.
+    """
+
+    def step(acc, xs):
+        conn, wanted = xs
+        free = jnp.isnan(acc[:, 0])
+        pos = jnp.argmax(free)
+        # argmax yields 0 when nothing is free, so re-check the slot itself.
+        should_add = wanted & free[pos]
+        rewired = conn.at[col].set(new_node_key)
+        return jnp.where(should_add, acc.at[pos].set(rewired), acc), None
+
+    conns_acc, _ = jax.lax.scan(step, conns_acc, (conns, mask))
+    return conns_acc
+
+
+class ModuleDuplication(BaseMutation):
+    """Duplicate a hidden module (node + its connections).
+
+    Parameters
+    ----------
+    duplication_rate: float, optional
+        Probability that the duplication operator fires during a mutation
+        step. Must be in ``[0, 1]``. Default ``0.0`` (disabled).
+    """
+
+    def __init__(self, duplication_rate: float = 0.0):
+        self.duplication_rate = duplication_rate
+        self._helper = DefaultMutation()
+
+    def _choose_node_key(self, key, nodes, input_idx, output_idx):
+        """Call ``DefaultMutation.choose_node_key`` with input/output keys disallowed."""
+        return self._helper.choose_node_key(
+            key,
+            nodes,
+            input_idx,
+            output_idx,
+            allow_input_keys=False,
+            allow_output_keys=False,
+        )
+
+    def __call__(
+        self, state, genome, randkey, nodes, conns, new_node_key, new_conn_key
+    ):
+        """Apply module duplication.
+
+        If the random draw is lower than ``duplication_rate`` a hidden node is
+        selected and duplicated (its connections are copied). Otherwise the
+        genome is returned unchanged. It is also returned unchanged when no
+        hidden node can be selected, when ``new_node_key`` is already used by
+        a node, or when the node/connection arrays lack the free (NaN) rows
+        needed to hold the complete copy.
+
+        ``new_conn_key`` is accepted for interface compatibility and is not
+        used.
+        """
+        # Independent keys for the fire draw and the node choice: reusing one
+        # key for both would correlate the two random decisions.
+        k_fire, k_choose = jax.random.split(randkey)
+        fire = jax.random.uniform(k_fire) < self.duplication_rate
+
+        def duplicate(_):
+            node_key, node_idx = self._choose_node_key(
+                k_choose, nodes, genome.input_idx, genome.output_idx
+            )
+
+            valid = ~jnp.isnan(conns[:, 0])
+            incoming = valid & (conns[:, 1] == node_key)
+            outgoing = valid & (conns[:, 0] == node_key)
+
+            # Duplicate only if a node was found, the new key is unused (an
+            # earlier operator may already have consumed it) and the whole
+            # module fits, so a partial copy is never produced.
+            has_node = node_idx != I_INF
+            key_free = ~jnp.any(nodes[:, 0] == new_node_key)
+            node_room = jnp.any(jnp.isnan(nodes[:, 0]))
+            conn_room = jnp.sum(~valid) >= jnp.sum(incoming) + jnp.sum(outgoing)
+            can_dup = has_node & key_free & node_room & conn_room
+
+            def no_dup():
+                return nodes, conns
+
+            def do_dup():
+                # Add a new hidden node with identity attributes.
+                new_nodes = add_node(
+                    nodes,
+                    jnp.array([new_node_key]),
+                    genome.node_gene.new_identity_attrs(state),
+                )
+                # Both masks come from the ORIGINAL conns, so rows appended
+                # for the incoming side are never re-copied as outgoing.
+                new_conns = _append_rewired(conns, conns, incoming, 1, new_node_key)
+                new_conns = _append_rewired(
+                    new_conns, conns, outgoing, 0, new_node_key
+                )
+                return new_nodes, new_conns
+
+            return jax.lax.cond(can_dup, do_dup, no_dup)
+
+        return jax.lax.cond(fire, duplicate, lambda _: (nodes, conns), operand=None)
+
+
+class CombinedMutation(BaseMutation):
+    """Combine ``DefaultMutation`` with optional ``ModuleDuplication``.
+
+    The combined mutation first runs the standard NEAT structural/value
+    mutations and then (with the configured ``duplication_rate``) may duplicate a
+    module. Both operators receive the same ``new_node_key``;
+    ``ModuleDuplication`` leaves the genome unchanged when that key is already
+    present in ``nodes``, so the two steps cannot create two nodes with one key.
+    """
+
+    def __init__(self, duplication_rate: float = 0.0, **default_kwargs):
+        self.default_mut = DefaultMutation(**default_kwargs)
+        self.dup_mut = ModuleDuplication(duplication_rate)
+
+    def __call__(
+        self, state, genome, randkey, nodes, conns, new_node_key, new_conn_key
+    ):
+        k1, k2 = jax.random.split(randkey)
+        nodes, conns = self.default_mut(
+            state, genome, k1, nodes, conns, new_node_key, new_conn_key
+        )
+        nodes, conns = self.dup_mut(
+            state, genome, k2, nodes, conns, new_node_key, new_conn_key
+        )
+        return nodes, conns
diff --git a/test/test_module_duplication.py b/test/test_module_duplication.py
new file mode 100644
index 0000000..40cd1eb
--- /dev/null
+++ b/test/test_module_duplication.py
@@ -0,0 +1,192 @@
+"""Tests for the MEND ModuleDuplication mutation.
+
+Tests verify that:
+1. The ModuleDuplication operator alone adds exactly one hidden node.
+2. The duplicated node has connections copied from the source.
+3. CombinedMutation integrates both default and duplication mutations.
+4. Duplication is skipped when the new node key is already in use.
+"""
+
+import jax
+import jax.numpy as jnp
+from tensorneat.genome import DefaultGenome
+from tensorneat.genome.operations.mutation import ModuleDuplication
+
+
+def _count_hidden(nodes, genome):
+    """Count non-NaN hidden nodes (excluding inputs/outputs)."""
+    return sum(
+        1
+        for n in nodes
+        if not jnp.isnan(n[0])
+        and int(n[0]) not in set(genome.input_idx.tolist())
+        and int(n[0]) not in set(genome.output_idx.tolist())
+    )
+
+
+def _count_valid_conns(conns):
+    """Count non-NaN connections."""
+    return sum(1 for c in conns if not jnp.isnan(c[0]))
+
+
+def test_module_duplication_standalone():
+    """Test ModuleDuplication in isolation (not via CombinedMutation).
+
+    Start with a genome that has one hidden node (via init_hidden_layers),
+    then apply only ModuleDuplication with rate=1.0.
+    """
+    genome = DefaultGenome(
+        num_inputs=2,
+        num_outputs=1,
+        max_nodes=10,
+        max_conns=30,
+        init_hidden_layers=(1,),  # start with 1 hidden node
+    )
+    state = genome.setup()
+    randkey = jax.random.PRNGKey(42)
+    nodes, conns = genome.initialize(state, randkey)
+
+    init_hidden = _count_hidden(nodes, genome)
+    init_conns = _count_valid_conns(conns)
+    assert init_hidden == 1, f"Expected 1 initial hidden node, got {init_hidden}"
+
+    # Apply only the duplication operator (not CombinedMutation).
+    dup = ModuleDuplication(duplication_rate=1.0)
+    k1, _ = jax.random.split(randkey)
+    new_node_key = jnp.array(genome.max_nodes + 1)
+    new_conn_key = jnp.array(
+        [genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
+    )
+    nodes_mut, conns_mut = dup(
+        state, genome, k1, nodes, conns, new_node_key, new_conn_key
+    )
+
+    post_hidden = _count_hidden(nodes_mut, genome)
+    post_conns = _count_valid_conns(conns_mut)
+
+    assert post_hidden == init_hidden + 1, (
+        f"Module duplication should add exactly one hidden node, "
+        f"got {post_hidden} (was {init_hidden})"
+    )
+    # The source node had 2 incoming (from inputs) + 1 outgoing (to output) = 3 conns.
+    # Duplication should copy all of them.
+    assert post_conns > init_conns, (
+        f"Duplication should add connections, got {post_conns} (was {init_conns})"
+    )
+
+
+def test_duplicated_node_has_connections():
+    """The duplicated node must have both incoming and outgoing connections."""
+    genome = DefaultGenome(
+        num_inputs=2,
+        num_outputs=1,
+        max_nodes=10,
+        max_conns=30,
+        init_hidden_layers=(1,),
+    )
+    state = genome.setup()
+    randkey = jax.random.PRNGKey(7)
+    nodes, conns = genome.initialize(state, randkey)
+
+    dup = ModuleDuplication(duplication_rate=1.0)
+    k1, _ = jax.random.split(randkey)
+    new_node_key = jnp.array(genome.max_nodes + 1)
+    new_conn_key = jnp.array(
+        [genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
+    )
+    nodes_mut, conns_mut = dup(
+        state, genome, k1, nodes, conns, new_node_key, new_conn_key
+    )
+
+    nk = float(new_node_key)
+    incoming = sum(
+        1 for c in conns_mut if not jnp.isnan(c[0]) and float(c[1]) == nk
+    )
+    outgoing = sum(
+        1 for c in conns_mut if not jnp.isnan(c[0]) and float(c[0]) == nk
+    )
+    assert incoming > 0, "Duplicated node must have incoming connections"
+    assert outgoing > 0, "Duplicated node must have outgoing connections"
+
+
+def test_combined_mutation_with_duplication():
+    """CombinedMutation should produce a valid genome (no crashes)."""
+    genome = DefaultGenome(
+        num_inputs=2,
+        num_outputs=1,
+        max_nodes=15,
+        max_conns=40,
+        duplication_rate=1.0,
+    )
+    state = genome.setup()
+    randkey = jax.random.PRNGKey(99)
+    nodes, conns = genome.initialize(state, randkey)
+
+    k1, _ = jax.random.split(randkey)
+    new_node_key = jnp.array(genome.max_nodes + 1)
+    new_conn_key = jnp.array(
+        [genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
+    )
+    nodes_mut, conns_mut = genome.execute_mutation(
+        state, k1, nodes, conns, new_node_key, new_conn_key
+    )
+
+    # Should not crash and should have at least as many nodes/conns as before.
+    init_valid = sum(1 for n in nodes if not jnp.isnan(n[0]))
+    post_valid = sum(1 for n in nodes_mut if not jnp.isnan(n[0]))
+    assert post_valid >= init_valid, "Mutation should not lose nodes"
+
+
+def test_duplication_disabled_when_rate_zero():
+    """With duplication_rate=0.0, no duplication should occur."""
+    genome = DefaultGenome(
+        num_inputs=2,
+        num_outputs=1,
+        max_nodes=10,
+        max_conns=30,
+        init_hidden_layers=(1,),
+    )
+    state = genome.setup()
+    randkey = jax.random.PRNGKey(0)
+    nodes, conns = genome.initialize(state, randkey)
+
+    dup = ModuleDuplication(duplication_rate=0.0)
+    k1, _ = jax.random.split(randkey)
+    new_node_key = jnp.array(genome.max_nodes + 1)
+    new_conn_key = jnp.array(
+        [genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
+    )
+    nodes_mut, conns_mut = dup(
+        state, genome, k1, nodes, conns, new_node_key, new_conn_key
+    )
+
+    assert _count_hidden(nodes_mut, genome) == _count_hidden(nodes, genome), (
+        "No duplication should occur with rate=0.0"
+    )
+
+
+def test_duplication_skipped_when_node_key_in_use():
+    """Re-using an existing node key must leave the genome untouched."""
+    genome = DefaultGenome(
+        num_inputs=2,
+        num_outputs=1,
+        max_nodes=10,
+        max_conns=30,
+        init_hidden_layers=(1,),
+    )
+    state = genome.setup()
+    randkey = jax.random.PRNGKey(3)
+    nodes, conns = genome.initialize(state, randkey)
+
+    dup = ModuleDuplication(duplication_rate=1.0)
+    k1, _ = jax.random.split(randkey)
+    taken_key = next(n[0] for n in nodes if not jnp.isnan(n[0]))
+    new_conn_key = jnp.array(
+        [genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
+    )
+    nodes_mut, conns_mut = dup(
+        state, genome, k1, nodes, conns, taken_key, new_conn_key
+    )
+
+    assert _count_hidden(nodes_mut, genome) == _count_hidden(nodes, genome)
+    assert _count_valid_conns(conns_mut) == _count_valid_conns(conns)