Add MEND module duplication operator

Implement the core MEND mechanism: a module duplication mutation that
copies a hidden node with all its incoming and outgoing connections.
This is the single addition MEND makes to standard NEAT.

- ModuleDuplication: JAX-compatible operator using jax.lax.scan
- CombinedMutation: composes DefaultMutation + ModuleDuplication
- DefaultGenome: accepts duplication_rate parameter
- Tests for standalone duplication, combined mutation, and rate=0 no-op

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
anthonyrawlins
2026-02-28 10:40:54 +11:00
parent 7e872c7191
commit ec48447670
5 changed files with 322 additions and 12 deletions

View File

@@ -0,0 +1,164 @@
"""Tests for the MEND ModuleDuplication mutation.
Tests verify that:
1. The ModuleDuplication operator alone adds exactly one hidden node.
2. The duplicated node has connections copied from the source.
3. CombinedMutation integrates both default and duplication mutations.
"""
import jax
import jax.numpy as jnp
from tensorneat.genome import DefaultGenome
from tensorneat.genome.operations.mutation import ModuleDuplication
def _count_hidden(nodes, genome):
"""Count non-NaN hidden nodes (excluding inputs/outputs)."""
return sum(
1
for n in nodes
if not jnp.isnan(n[0])
and int(n[0]) not in set(genome.input_idx.tolist())
and int(n[0]) not in set(genome.output_idx.tolist())
)
def _count_valid_conns(conns):
"""Count non-NaN connections."""
return sum(1 for c in conns if not jnp.isnan(c[0]))
def test_module_duplication_standalone():
"""Test ModuleDuplication in isolation (not via CombinedMutation).
Start with a genome that has one hidden node (via init_hidden_layers),
then apply only ModuleDuplication with rate=1.0.
"""
genome = DefaultGenome(
num_inputs=2,
num_outputs=1,
max_nodes=10,
max_conns=30,
init_hidden_layers=(1,), # start with 1 hidden node
)
state = genome.setup()
randkey = jax.random.PRNGKey(42)
nodes, conns = genome.initialize(state, randkey)
init_hidden = _count_hidden(nodes, genome)
init_conns = _count_valid_conns(conns)
assert init_hidden == 1, f"Expected 1 initial hidden node, got {init_hidden}"
# Apply only the duplication operator (not CombinedMutation).
dup = ModuleDuplication(duplication_rate=1.0)
k1, _ = jax.random.split(randkey)
new_node_key = jnp.array(genome.max_nodes + 1)
new_conn_key = jnp.array(
[genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
)
nodes_mut, conns_mut = dup(
state, genome, k1, nodes, conns, new_node_key, new_conn_key
)
post_hidden = _count_hidden(nodes_mut, genome)
post_conns = _count_valid_conns(conns_mut)
assert post_hidden == init_hidden + 1, (
f"Module duplication should add exactly one hidden node, "
f"got {post_hidden} (was {init_hidden})"
)
# The source node had 2 incoming (from inputs) + 1 outgoing (to output) = 3 conns.
# Duplication should copy all of them.
assert post_conns > init_conns, (
f"Duplication should add connections, got {post_conns} (was {init_conns})"
)
def test_duplicated_node_has_connections():
"""The duplicated node must have both incoming and outgoing connections."""
genome = DefaultGenome(
num_inputs=2,
num_outputs=1,
max_nodes=10,
max_conns=30,
init_hidden_layers=(1,),
)
state = genome.setup()
randkey = jax.random.PRNGKey(7)
nodes, conns = genome.initialize(state, randkey)
dup = ModuleDuplication(duplication_rate=1.0)
k1, _ = jax.random.split(randkey)
new_node_key = jnp.array(genome.max_nodes + 1)
new_conn_key = jnp.array(
[genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
)
nodes_mut, conns_mut = dup(
state, genome, k1, nodes, conns, new_node_key, new_conn_key
)
nk = float(new_node_key)
incoming = sum(
1 for c in conns_mut if not jnp.isnan(c[0]) and float(c[1]) == nk
)
outgoing = sum(
1 for c in conns_mut if not jnp.isnan(c[0]) and float(c[0]) == nk
)
assert incoming > 0, "Duplicated node must have incoming connections"
assert outgoing > 0, "Duplicated node must have outgoing connections"
def test_combined_mutation_with_duplication():
"""CombinedMutation should produce a valid genome (no crashes)."""
genome = DefaultGenome(
num_inputs=2,
num_outputs=1,
max_nodes=15,
max_conns=40,
duplication_rate=1.0,
)
state = genome.setup()
randkey = jax.random.PRNGKey(99)
nodes, conns = genome.initialize(state, randkey)
k1, _ = jax.random.split(randkey)
new_node_key = jnp.array(genome.max_nodes + 1)
new_conn_key = jnp.array(
[genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
)
nodes_mut, conns_mut = genome.execute_mutation(
state, k1, nodes, conns, new_node_key, new_conn_key
)
# Should not crash and should have at least as many nodes/conns as before.
init_valid = sum(1 for n in nodes if not jnp.isnan(n[0]))
post_valid = sum(1 for n in nodes_mut if not jnp.isnan(n[0]))
assert post_valid >= init_valid, "Mutation should not lose nodes"
def test_duplication_disabled_when_rate_zero():
"""With duplication_rate=0.0, no duplication should occur."""
genome = DefaultGenome(
num_inputs=2,
num_outputs=1,
max_nodes=10,
max_conns=30,
init_hidden_layers=(1,),
)
state = genome.setup()
randkey = jax.random.PRNGKey(0)
nodes, conns = genome.initialize(state, randkey)
dup = ModuleDuplication(duplication_rate=0.0)
k1, _ = jax.random.split(randkey)
new_node_key = jnp.array(genome.max_nodes + 1)
new_conn_key = jnp.array(
[genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
)
nodes_mut, conns_mut = dup(
state, genome, k1, nodes, conns, new_node_key, new_conn_key
)
assert _count_hidden(nodes_mut, genome) == _count_hidden(nodes, genome), (
"No duplication should occur with rate=0.0"
)