Add MEND module duplication operator
Implement the core MEND mechanism: a module duplication mutation that copies a hidden node with all its incoming and outgoing connections. This is the single addition MEND makes to standard NEAT. - ModuleDuplication: JAX-compatible operator using jax.lax.scan - CombinedMutation: composes DefaultMutation + ModuleDuplication - DefaultGenome: accepts duplication_rate parameter - Tests for standalone duplication, combined mutation, and rate=0 no-op Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
164
test/test_module_duplication.py
Normal file
164
test/test_module_duplication.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Tests for the MEND ModuleDuplication mutation.
|
||||
|
||||
Tests verify that:
|
||||
1. The ModuleDuplication operator alone adds exactly one hidden node.
|
||||
2. The duplicated node has connections copied from the source.
|
||||
3. CombinedMutation integrates both default and duplication mutations.
|
||||
"""
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from tensorneat.genome import DefaultGenome
|
||||
from tensorneat.genome.operations.mutation import ModuleDuplication
|
||||
|
||||
|
||||
def _count_hidden(nodes, genome):
|
||||
"""Count non-NaN hidden nodes (excluding inputs/outputs)."""
|
||||
return sum(
|
||||
1
|
||||
for n in nodes
|
||||
if not jnp.isnan(n[0])
|
||||
and int(n[0]) not in set(genome.input_idx.tolist())
|
||||
and int(n[0]) not in set(genome.output_idx.tolist())
|
||||
)
|
||||
|
||||
|
||||
def _count_valid_conns(conns):
|
||||
"""Count non-NaN connections."""
|
||||
return sum(1 for c in conns if not jnp.isnan(c[0]))
|
||||
|
||||
|
||||
def test_module_duplication_standalone():
|
||||
"""Test ModuleDuplication in isolation (not via CombinedMutation).
|
||||
|
||||
Start with a genome that has one hidden node (via init_hidden_layers),
|
||||
then apply only ModuleDuplication with rate=1.0.
|
||||
"""
|
||||
genome = DefaultGenome(
|
||||
num_inputs=2,
|
||||
num_outputs=1,
|
||||
max_nodes=10,
|
||||
max_conns=30,
|
||||
init_hidden_layers=(1,), # start with 1 hidden node
|
||||
)
|
||||
state = genome.setup()
|
||||
randkey = jax.random.PRNGKey(42)
|
||||
nodes, conns = genome.initialize(state, randkey)
|
||||
|
||||
init_hidden = _count_hidden(nodes, genome)
|
||||
init_conns = _count_valid_conns(conns)
|
||||
assert init_hidden == 1, f"Expected 1 initial hidden node, got {init_hidden}"
|
||||
|
||||
# Apply only the duplication operator (not CombinedMutation).
|
||||
dup = ModuleDuplication(duplication_rate=1.0)
|
||||
k1, _ = jax.random.split(randkey)
|
||||
new_node_key = jnp.array(genome.max_nodes + 1)
|
||||
new_conn_key = jnp.array(
|
||||
[genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
|
||||
)
|
||||
nodes_mut, conns_mut = dup(
|
||||
state, genome, k1, nodes, conns, new_node_key, new_conn_key
|
||||
)
|
||||
|
||||
post_hidden = _count_hidden(nodes_mut, genome)
|
||||
post_conns = _count_valid_conns(conns_mut)
|
||||
|
||||
assert post_hidden == init_hidden + 1, (
|
||||
f"Module duplication should add exactly one hidden node, "
|
||||
f"got {post_hidden} (was {init_hidden})"
|
||||
)
|
||||
# The source node had 2 incoming (from inputs) + 1 outgoing (to output) = 3 conns.
|
||||
# Duplication should copy all of them.
|
||||
assert post_conns > init_conns, (
|
||||
f"Duplication should add connections, got {post_conns} (was {init_conns})"
|
||||
)
|
||||
|
||||
|
||||
def test_duplicated_node_has_connections():
|
||||
"""The duplicated node must have both incoming and outgoing connections."""
|
||||
genome = DefaultGenome(
|
||||
num_inputs=2,
|
||||
num_outputs=1,
|
||||
max_nodes=10,
|
||||
max_conns=30,
|
||||
init_hidden_layers=(1,),
|
||||
)
|
||||
state = genome.setup()
|
||||
randkey = jax.random.PRNGKey(7)
|
||||
nodes, conns = genome.initialize(state, randkey)
|
||||
|
||||
dup = ModuleDuplication(duplication_rate=1.0)
|
||||
k1, _ = jax.random.split(randkey)
|
||||
new_node_key = jnp.array(genome.max_nodes + 1)
|
||||
new_conn_key = jnp.array(
|
||||
[genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
|
||||
)
|
||||
nodes_mut, conns_mut = dup(
|
||||
state, genome, k1, nodes, conns, new_node_key, new_conn_key
|
||||
)
|
||||
|
||||
nk = float(new_node_key)
|
||||
incoming = sum(
|
||||
1 for c in conns_mut if not jnp.isnan(c[0]) and float(c[1]) == nk
|
||||
)
|
||||
outgoing = sum(
|
||||
1 for c in conns_mut if not jnp.isnan(c[0]) and float(c[0]) == nk
|
||||
)
|
||||
assert incoming > 0, "Duplicated node must have incoming connections"
|
||||
assert outgoing > 0, "Duplicated node must have outgoing connections"
|
||||
|
||||
|
||||
def test_combined_mutation_with_duplication():
|
||||
"""CombinedMutation should produce a valid genome (no crashes)."""
|
||||
genome = DefaultGenome(
|
||||
num_inputs=2,
|
||||
num_outputs=1,
|
||||
max_nodes=15,
|
||||
max_conns=40,
|
||||
duplication_rate=1.0,
|
||||
)
|
||||
state = genome.setup()
|
||||
randkey = jax.random.PRNGKey(99)
|
||||
nodes, conns = genome.initialize(state, randkey)
|
||||
|
||||
k1, _ = jax.random.split(randkey)
|
||||
new_node_key = jnp.array(genome.max_nodes + 1)
|
||||
new_conn_key = jnp.array(
|
||||
[genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
|
||||
)
|
||||
nodes_mut, conns_mut = genome.execute_mutation(
|
||||
state, k1, nodes, conns, new_node_key, new_conn_key
|
||||
)
|
||||
|
||||
# Should not crash and should have at least as many nodes/conns as before.
|
||||
init_valid = sum(1 for n in nodes if not jnp.isnan(n[0]))
|
||||
post_valid = sum(1 for n in nodes_mut if not jnp.isnan(n[0]))
|
||||
assert post_valid >= init_valid, "Mutation should not lose nodes"
|
||||
|
||||
|
||||
def test_duplication_disabled_when_rate_zero():
|
||||
"""With duplication_rate=0.0, no duplication should occur."""
|
||||
genome = DefaultGenome(
|
||||
num_inputs=2,
|
||||
num_outputs=1,
|
||||
max_nodes=10,
|
||||
max_conns=30,
|
||||
init_hidden_layers=(1,),
|
||||
)
|
||||
state = genome.setup()
|
||||
randkey = jax.random.PRNGKey(0)
|
||||
nodes, conns = genome.initialize(state, randkey)
|
||||
|
||||
dup = ModuleDuplication(duplication_rate=0.0)
|
||||
k1, _ = jax.random.split(randkey)
|
||||
new_node_key = jnp.array(genome.max_nodes + 1)
|
||||
new_conn_key = jnp.array(
|
||||
[genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
|
||||
)
|
||||
nodes_mut, conns_mut = dup(
|
||||
state, genome, k1, nodes, conns, new_node_key, new_conn_key
|
||||
)
|
||||
|
||||
assert _count_hidden(nodes_mut, genome) == _count_hidden(nodes, genome), (
|
||||
"No duplication should occur with rate=0.0"
|
||||
)
|
||||
Reference in New Issue
Block a user