Add MEND module duplication operator

Implement the core MEND mechanism: a module duplication mutation that
copies a hidden node with all its incoming and outgoing connections.
This is the single addition MEND makes to standard NEAT.

- ModuleDuplication: JAX-compatible operator using jax.lax.scan
- CombinedMutation: composes DefaultMutation + ModuleDuplication
- DefaultGenome: accepts duplication_rate parameter
- Tests for standalone duplication, combined mutation, and rate=0 no-op

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
anthonyrawlins
2026-02-28 10:40:54 +11:00
parent 7e872c7191
commit ec48447670
5 changed files with 322 additions and 12 deletions

View File

@@ -7,7 +7,12 @@ import sympy as sp
from .base import BaseGenome
from .gene import DefaultNode, DefaultConn
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
from .operations import (
DefaultMutation,
DefaultCrossover,
DefaultDistance,
CombinedMutation,
)
from .utils import unflatten_conns, extract_gene_attrs, extract_gene_attrs
from tensorneat.common import (
@@ -40,8 +45,12 @@ class DefaultGenome(BaseGenome):
output_transform=None,
input_transform=None,
init_hidden_layers=(),
duplication_rate: float = 0.0,
):
# If a duplication_rate is specified, wrap the default mutation with the CombinedMutation
if duplication_rate > 0.0:
# Use CombinedMutation which includes standard mutation plus module duplication
mutation = CombinedMutation(duplication_rate=duplication_rate)
super().__init__(
num_inputs,
num_outputs,
@@ -66,7 +75,6 @@ class DefaultGenome(BaseGenome):
return seqs, nodes, conns, u_conns
def forward(self, state, transformed, inputs):
if self.input_transform is not None:
inputs = self.input_transform(inputs)
@@ -133,9 +141,7 @@ class DefaultGenome(BaseGenome):
network["topo_order"] = topo_order
network["topo_layers"] = topo_layers
network["useful_nodes"] = find_useful_nodes(
set(network["nodes"]),
set(network["conns"]),
set(self.output_idx)
set(network["nodes"]), set(network["conns"]), set(self.output_idx)
)
return network
@@ -147,7 +153,6 @@ class DefaultGenome(BaseGenome):
sympy_output_transform=None,
backend="jax",
):
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
if sympy_input_transform is None and self.input_transform is not None:
@@ -187,7 +192,6 @@ class DefaultGenome(BaseGenome):
nodes_exprs = {}
args_symbols = {}
for i in order:
if i in input_idx:
nodes_exprs[symbols[-i - 1]] = symbols[
-i - 1
@@ -258,7 +262,7 @@ class DefaultGenome(BaseGenome):
output_exprs,
forward_func,
network["topo_order"],
network["useful_nodes"]
network["useful_nodes"],
)
def visualize(
@@ -298,7 +302,7 @@ class DefaultGenome(BaseGenome):
nodes_y = []
for node in nodes:
node_y = 0
for y, last_node in enumerate(topo_layers[layer-1]):
for y, last_node in enumerate(topo_layers[layer - 1]):
if (last_node, node) in conns_list:
node_y += y
nodes_y.append(node_y)
@@ -395,7 +399,7 @@ class DefaultGenome(BaseGenome):
edgecolors=edgecolors,
arrowstyle=arrowstyle,
arrowsize=arrowsize,
edge_color=edge_color
edge_color=edge_color,
)
plt.savefig(save_path, dpi=save_dpi)
plt.close()

View File

@@ -1,3 +1,3 @@
from .crossover import BaseCrossover, DefaultCrossover
from .mutation import BaseMutation, DefaultMutation
from .mutation import BaseMutation, DefaultMutation, CombinedMutation, ModuleDuplication
from .distance import BaseDistance, DefaultDistance

View File

@@ -1,2 +1,3 @@
from .base import BaseMutation
from .default import DefaultMutation
from .module_duplication import CombinedMutation, ModuleDuplication

View File

@@ -0,0 +1,141 @@
"""Module Duplication Mutation
Implements the MEND *module duplication* operator. It duplicates a hidden
module (a hidden node together with all its incoming and outgoing connections).
The operator fires with probability ``duplication_rate``.
"""
import jax
import jax.numpy as jnp
from .base import BaseMutation
from .default import DefaultMutation
from ...utils import add_node, add_conn, extract_gene_attrs
from tensorneat.common import I_INF
class ModuleDuplication(BaseMutation):
    """Duplicate a hidden module (a hidden node plus its connections).

    Parameters
    ----------
    duplication_rate: float, optional
        Probability that the duplication operator fires during a mutation
        step. Must be in ``[0, 1]``. Default ``0.0`` (disabled).
    """

    def __init__(self, duplication_rate: float = 0.0):
        self.duplication_rate = duplication_rate
        # A DefaultMutation instance is kept only to reuse its node sampler.
        self._helper = DefaultMutation()

    def _choose_node_key(self, key, nodes, input_idx, output_idx):
        """Sample a node through ``DefaultMutation.choose_node_key``.

        Input and output keys are both disallowed, so only hidden nodes can
        be selected. Returns whatever the helper returns; ``__call__``
        unpacks it as ``(node_key, node_idx)``.
        """
        return self._helper.choose_node_key(
            key,
            nodes,
            input_idx,
            output_idx,
            allow_input_keys=False,
            allow_output_keys=False,
        )

    @staticmethod
    def _append_copies(conns_acc, source_conns, selected, rewire):
        """Append a rewired copy of every selected row to ``conns_acc``.

        ``source_conns`` and the boolean mask ``selected`` are scanned in
        lockstep with ``jax.lax.scan``. For each selected row,
        ``rewire(row)`` is written into the first row of the accumulator
        whose first column is NaN. If no such row exists the copy is dropped.
        """

        def step(acc, scanned):
            conn, is_selected = scanned
            candidate = rewire(conn)
            # argmax returns 0 when nothing is NaN, so re-check the slot.
            pos = jnp.argmax(jnp.isnan(acc[:, 0]))
            has_space = jnp.isnan(acc[pos, 0])
            acc = jnp.where(
                is_selected & has_space,
                acc.at[pos].set(candidate),
                acc,
            )
            return acc, None

        result, _ = jax.lax.scan(step, conns_acc, (source_conns, selected))
        return result

    def __call__(
        self, state, genome, randkey, nodes, conns, new_node_key, new_conn_key
    ):
        """Apply module duplication.

        With probability ``duplication_rate`` a hidden node is sampled, a new
        node keyed ``new_node_key`` is added, and every connection into or
        out of the sampled node is copied so that it points at / leaves the
        new node instead. ``new_conn_key`` is accepted for interface
        compatibility and is not used.

        The genome is returned unchanged when any of the following holds:

        * the random draw does not fire;
        * no hidden node could be sampled (``node_idx == I_INF``);
        * ``nodes`` has no NaN (free) row left;
        * ``conns`` has fewer free rows than connections to copy, so the
          module is duplicated completely or not at all;
        * ``new_node_key`` already appears in ``nodes`` (for example because
          an earlier operator in the same step consumed it), which would
          otherwise yield two nodes sharing one key.

        Returns
        -------
        tuple
            ``(nodes, conns)`` after the (possible) duplication.
        """
        # Separate keys for the fire decision and the node choice; reusing
        # one key for both would correlate the two random draws.
        fire_key, choose_key = jax.random.split(randkey)
        fire = jax.random.uniform(fire_key) < self.duplication_rate
        n_fixed = len(genome.conn_gene.fixed_attrs)

        def unchanged():
            return nodes, conns

        def duplicate():
            node_key, node_idx = self._choose_node_key(
                choose_key, nodes, genome.input_idx, genome.output_idx
            )

            live = ~jnp.isnan(conns[:, 0])
            is_incoming = live & (conns[:, 1] == node_key)
            is_outgoing = live & (conns[:, 0] == node_key)

            free_nodes = jnp.isnan(nodes[:, 0]).sum()
            free_conns = jnp.isnan(conns[:, 0]).sum()
            needed_conns = is_incoming.sum() + is_outgoing.sum()
            key_in_use = jnp.any(nodes[:, 0] == new_node_key)
            skip = (
                (node_idx == I_INF)
                | (free_nodes < 1)
                | (free_conns < needed_conns)
                | key_in_use
            )

            def do_dup():
                # NOTE(review): the duplicate receives identity attributes,
                # not the sampled node's own attributes - confirm intended.
                new_nodes = add_node(
                    nodes,
                    jnp.array([new_node_key]),
                    genome.node_gene.new_identity_attrs(state),
                )

                # NOTE(review): the rewired rows are built from exactly two
                # leading columns (source, target) followed by
                # conn[n_fixed:]; this presumes the connection gene has two
                # fixed attributes - confirm for other gene types.
                def as_incoming(conn):
                    head = jnp.array([conn[0], new_node_key])
                    return jnp.concatenate((head, conn[n_fixed:]))

                def as_outgoing(conn):
                    head = jnp.array([new_node_key, conn[1]])
                    return jnp.concatenate((head, conn[n_fixed:]))

                # Both passes read the ORIGINAL conns so freshly written
                # copies are never themselves duplicated; the second pass
                # accumulates into the result of the first.
                new_conns = self._append_copies(
                    conns, conns, is_incoming, as_incoming
                )
                new_conns = self._append_copies(
                    new_conns, conns, is_outgoing, as_outgoing
                )
                return new_nodes, new_conns

            return jax.lax.cond(skip, unchanged, do_dup)

        return jax.lax.cond(fire, duplicate, unchanged)
class CombinedMutation(BaseMutation):
    """Run ``DefaultMutation`` followed by ``ModuleDuplication``.

    The incoming random key is split in two: the first half drives the
    standard mutation, the second half drives the duplication operator, which
    is constructed with the given ``duplication_rate``. Any extra keyword
    arguments are forwarded to ``DefaultMutation``.

    NOTE(review): both operators are handed the same ``new_node_key`` and
    ``new_conn_key`` values.
    """

    def __init__(self, duplication_rate: float = 0.0, **default_kwargs):
        self.default_mut = DefaultMutation(**default_kwargs)
        self.dup_mut = ModuleDuplication(duplication_rate)

    def __call__(
        self, state, genome, randkey, nodes, conns, new_node_key, new_conn_key
    ):
        """Apply both operators in sequence and return ``(nodes, conns)``."""
        default_key, dup_key = jax.random.split(randkey)
        # Same operators, same order: default mutation first, duplication second.
        for operator, op_key in (
            (self.default_mut, default_key),
            (self.dup_mut, dup_key),
        ):
            nodes, conns = operator(
                state, genome, op_key, nodes, conns, new_node_key, new_conn_key
            )
        return nodes, conns

View File

@@ -0,0 +1,164 @@
"""Tests for the MEND ModuleDuplication mutation.
Tests verify that:
1. The ModuleDuplication operator alone adds exactly one hidden node.
2. The duplicated node has connections copied from the source.
3. CombinedMutation integrates both default and duplication mutations.
4. A duplication_rate of 0.0 leaves the hidden-node count unchanged.
"""
import jax
import jax.numpy as jnp
from tensorneat.genome import DefaultGenome
from tensorneat.genome.operations.mutation import ModuleDuplication
def _count_hidden(nodes, genome):
    """Count non-NaN hidden nodes (excluding inputs/outputs).

    A row counts when its first column is not NaN and, converted to ``int``,
    is neither in ``genome.input_idx`` nor in ``genome.output_idx``.
    """
    # Build the input/output key set once instead of once per node row.
    io_keys = set(genome.input_idx.tolist()) | set(genome.output_idx.tolist())
    return sum(
        1
        for n in nodes
        # The NaN test must come first: int(NaN) would raise.
        if not jnp.isnan(n[0]) and int(n[0]) not in io_keys
    )
def _count_valid_conns(conns):
    """Return how many connection rows have a non-NaN first column."""
    live_rows = [row for row in conns if not jnp.isnan(row[0])]
    return len(live_rows)
def test_module_duplication_standalone():
    """Exercise ModuleDuplication on its own, without CombinedMutation.

    The genome is initialised with ``init_hidden_layers=(1,)`` and the
    operator is applied with ``duplication_rate=1.0``.
    """
    genome = DefaultGenome(
        num_inputs=2,
        num_outputs=1,
        max_nodes=10,
        max_conns=30,
        init_hidden_layers=(1,),  # a single hidden layer of width 1
    )
    state = genome.setup()
    seed_key = jax.random.PRNGKey(42)
    nodes, conns = genome.initialize(state, seed_key)

    init_hidden = _count_hidden(nodes, genome)
    init_conns = _count_valid_conns(conns)
    assert init_hidden == 1, f"Expected 1 initial hidden node, got {init_hidden}"

    # Keys outside the ranges used by the freshly initialised genome.
    fresh_node_key = jnp.array(genome.max_nodes + 1)
    fresh_conn_keys = jnp.array(
        [genome.max_conns + offset for offset in (1, 2, 3)]
    )
    mutation_key = jax.random.split(seed_key)[0]

    # Only the duplication operator is applied here.
    operator = ModuleDuplication(duplication_rate=1.0)
    nodes_mut, conns_mut = operator(
        state, genome, mutation_key, nodes, conns, fresh_node_key, fresh_conn_keys
    )

    post_hidden = _count_hidden(nodes_mut, genome)
    post_conns = _count_valid_conns(conns_mut)

    assert post_hidden == init_hidden + 1, (
        f"Module duplication should add exactly one hidden node, "
        f"got {post_hidden} (was {init_hidden})"
    )
    # Only growth of the connection count is asserted, not an exact number.
    assert post_conns > init_conns, (
        f"Duplication should add connections, got {post_conns} (was {init_conns})"
    )
def test_duplicated_node_has_connections():
    """After duplication the new node has at least one edge in each direction."""
    genome = DefaultGenome(
        num_inputs=2,
        num_outputs=1,
        max_nodes=10,
        max_conns=30,
        init_hidden_layers=(1,),
    )
    state = genome.setup()
    seed_key = jax.random.PRNGKey(7)
    nodes, conns = genome.initialize(state, seed_key)

    fresh_node_key = jnp.array(genome.max_nodes + 1)
    fresh_conn_keys = jnp.array(
        [genome.max_conns + offset for offset in (1, 2, 3)]
    )
    mutation_key = jax.random.split(seed_key)[0]

    operator = ModuleDuplication(duplication_rate=1.0)
    nodes_mut, conns_mut = operator(
        state, genome, mutation_key, nodes, conns, fresh_node_key, fresh_conn_keys
    )

    target_key = float(fresh_node_key)
    # Ignore empty (NaN) rows before inspecting the endpoints.
    live_conns = [c for c in conns_mut if not jnp.isnan(c[0])]
    incoming = sum(1 for c in live_conns if float(c[1]) == target_key)
    outgoing = sum(1 for c in live_conns if float(c[0]) == target_key)

    assert incoming > 0, "Duplicated node must have incoming connections"
    assert outgoing > 0, "Duplicated node must have outgoing connections"
def test_combined_mutation_with_duplication():
    """Smoke test: mutating a genome built with ``duplication_rate=1.0`` runs."""
    genome = DefaultGenome(
        num_inputs=2,
        num_outputs=1,
        max_nodes=15,
        max_conns=40,
        duplication_rate=1.0,
    )
    state = genome.setup()
    seed_key = jax.random.PRNGKey(99)
    nodes, conns = genome.initialize(state, seed_key)

    def live_node_count(table):
        # Rows whose first column is NaN are empty slots.
        return sum(1 for row in table if not jnp.isnan(row[0]))

    fresh_node_key = jnp.array(genome.max_nodes + 1)
    fresh_conn_keys = jnp.array(
        [genome.max_conns + offset for offset in (1, 2, 3)]
    )
    mutation_key = jax.random.split(seed_key)[0]

    nodes_mut, conns_mut = genome.execute_mutation(
        state, mutation_key, nodes, conns, fresh_node_key, fresh_conn_keys
    )

    # The call must not crash, and the node count must not shrink.
    init_valid = live_node_count(nodes)
    post_valid = live_node_count(nodes_mut)
    assert post_valid >= init_valid, "Mutation should not lose nodes"
def test_duplication_disabled_when_rate_zero():
    """With duplication_rate=0.0, no duplication should occur.

    Both the hidden-node count and the valid-connection count must be the
    same before and after the operator is applied.
    """
    genome = DefaultGenome(
        num_inputs=2,
        num_outputs=1,
        max_nodes=10,
        max_conns=30,
        init_hidden_layers=(1,),
    )
    state = genome.setup()
    randkey = jax.random.PRNGKey(0)
    nodes, conns = genome.initialize(state, randkey)

    dup = ModuleDuplication(duplication_rate=0.0)
    k1, _ = jax.random.split(randkey)
    new_node_key = jnp.array(genome.max_nodes + 1)
    new_conn_key = jnp.array(
        [genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
    )
    nodes_mut, conns_mut = dup(
        state, genome, k1, nodes, conns, new_node_key, new_conn_key
    )

    assert _count_hidden(nodes_mut, genome) == _count_hidden(nodes, genome), (
        "No duplication should occur with rate=0.0"
    )
    # A no-op must leave the connection table's occupancy untouched as well.
    assert _count_valid_conns(conns_mut) == _count_valid_conns(conns), (
        "No connections should be added with rate=0.0"
    )