Add MEND module duplication operator
Implement the core MEND mechanism: a module duplication mutation that copies a hidden node together with all of its incoming and outgoing connections. This is the single addition MEND makes to standard NEAT.

- ModuleDuplication: JAX-compatible operator using jax.lax.scan
- CombinedMutation: composes DefaultMutation + ModuleDuplication
- DefaultGenome: accepts a duplication_rate parameter
- Tests for standalone duplication, combined mutation, and the rate=0 no-op

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -7,7 +7,12 @@ import sympy as sp
|
||||
|
||||
from .base import BaseGenome
|
||||
from .gene import DefaultNode, DefaultConn
|
||||
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
|
||||
from .operations import (
|
||||
DefaultMutation,
|
||||
DefaultCrossover,
|
||||
DefaultDistance,
|
||||
CombinedMutation,
|
||||
)
|
||||
from .utils import unflatten_conns, extract_gene_attrs, extract_gene_attrs
|
||||
|
||||
from tensorneat.common import (
|
||||
@@ -40,8 +45,12 @@ class DefaultGenome(BaseGenome):
|
||||
output_transform=None,
|
||||
input_transform=None,
|
||||
init_hidden_layers=(),
|
||||
duplication_rate: float = 0.0,
|
||||
):
|
||||
|
||||
# If a duplication_rate is specified, wrap the default mutation with the CombinedMutation
|
||||
if duplication_rate > 0.0:
|
||||
# Use CombinedMutation which includes standard mutation plus module duplication
|
||||
mutation = CombinedMutation(duplication_rate=duplication_rate)
|
||||
super().__init__(
|
||||
num_inputs,
|
||||
num_outputs,
|
||||
@@ -66,7 +75,6 @@ class DefaultGenome(BaseGenome):
|
||||
return seqs, nodes, conns, u_conns
|
||||
|
||||
def forward(self, state, transformed, inputs):
|
||||
|
||||
if self.input_transform is not None:
|
||||
inputs = self.input_transform(inputs)
|
||||
|
||||
@@ -133,9 +141,7 @@ class DefaultGenome(BaseGenome):
|
||||
network["topo_order"] = topo_order
|
||||
network["topo_layers"] = topo_layers
|
||||
network["useful_nodes"] = find_useful_nodes(
|
||||
set(network["nodes"]),
|
||||
set(network["conns"]),
|
||||
set(self.output_idx)
|
||||
set(network["nodes"]), set(network["conns"]), set(self.output_idx)
|
||||
)
|
||||
return network
|
||||
|
||||
@@ -147,7 +153,6 @@ class DefaultGenome(BaseGenome):
|
||||
sympy_output_transform=None,
|
||||
backend="jax",
|
||||
):
|
||||
|
||||
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
|
||||
|
||||
if sympy_input_transform is None and self.input_transform is not None:
|
||||
@@ -187,7 +192,6 @@ class DefaultGenome(BaseGenome):
|
||||
nodes_exprs = {}
|
||||
args_symbols = {}
|
||||
for i in order:
|
||||
|
||||
if i in input_idx:
|
||||
nodes_exprs[symbols[-i - 1]] = symbols[
|
||||
-i - 1
|
||||
@@ -258,7 +262,7 @@ class DefaultGenome(BaseGenome):
|
||||
output_exprs,
|
||||
forward_func,
|
||||
network["topo_order"],
|
||||
network["useful_nodes"]
|
||||
network["useful_nodes"],
|
||||
)
|
||||
|
||||
def visualize(
|
||||
@@ -395,7 +399,7 @@ class DefaultGenome(BaseGenome):
|
||||
edgecolors=edgecolors,
|
||||
arrowstyle=arrowstyle,
|
||||
arrowsize=arrowsize,
|
||||
edge_color=edge_color
|
||||
edge_color=edge_color,
|
||||
)
|
||||
plt.savefig(save_path, dpi=save_dpi)
|
||||
plt.close()
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .crossover import BaseCrossover, DefaultCrossover
|
||||
from .mutation import BaseMutation, DefaultMutation
|
||||
from .mutation import BaseMutation, DefaultMutation, CombinedMutation, ModuleDuplication
|
||||
from .distance import BaseDistance, DefaultDistance
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from .base import BaseMutation
|
||||
from .default import DefaultMutation
|
||||
from .module_duplication import CombinedMutation, ModuleDuplication
|
||||
|
||||
141
src/tensorneat/genome/operations/mutation/module_duplication.py
Normal file
141
src/tensorneat/genome/operations/mutation/module_duplication.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Module Duplication Mutation
|
||||
|
||||
Implements the MEND *module duplication* operator. It duplicates a hidden
|
||||
module (a hidden node together with all its incoming and outgoing connections).
|
||||
The operator fires with probability ``duplication_rate``.
|
||||
"""
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from .base import BaseMutation
|
||||
from .default import DefaultMutation
|
||||
from ...utils import add_node, add_conn, extract_gene_attrs
|
||||
from tensorneat.common import I_INF
|
||||
|
||||
|
||||
class ModuleDuplication(BaseMutation):
    """Duplicate a hidden module (a hidden node plus its connections).

    When the operator fires, one hidden node is picked at random, a new node
    is inserted under ``new_node_key`` and every connection entering or
    leaving the picked node is copied so that it enters or leaves the new
    node instead.

    The duplication is all-or-nothing: it is skipped when no hidden node
    exists, when the node table has no free (NaN) row, or when the connection
    table has fewer free rows than the module has connections.

    Parameters
    ----------
    duplication_rate: float, optional
        Probability that the duplication operator fires during a mutation
        step. Must be in ``[0, 1]``. Default ``0.0`` (disabled).
    """

    def __init__(self, duplication_rate: float = 0.0):
        self.duplication_rate = duplication_rate
        # Only used for its node-selection routine.
        self._helper = DefaultMutation()

    def _choose_node_key(self, key, nodes, input_idx, output_idx):
        """Pick a random node that is neither an input nor an output.

        Delegates to ``DefaultMutation.choose_node_key`` and returns its
        ``(node_key, node_idx)`` pair; ``node_idx == I_INF`` is treated by
        the caller as "no candidate node".
        """
        return self._helper.choose_node_key(
            key,
            nodes,
            input_idx,
            output_idx,
            allow_input_keys=False,
            allow_output_keys=False,
        )

    @staticmethod
    def _copy_conns(conns_acc, source_conns, selected, column, new_node_key):
        """Append rewired copies of the selected rows of ``source_conns``.

        Each selected row is copied with the entry at ``column`` (0 = source
        endpoint, 1 = target endpoint) replaced by ``new_node_key`` and is
        written into the first free (NaN-keyed) row of ``conns_acc``.
        """

        def step(acc, xs):
            row, is_selected = xs
            # Copy the whole row and overwrite a single endpoint, so the row
            # length is right regardless of how many fixed columns exist.
            # NOTE(review): any fixed attribute beyond the two endpoints is
            # copied verbatim - confirm that is desired for connection genes
            # that carry per-connection markers.
            rewired = row.at[column].set(new_node_key)
            pos = jnp.argmax(jnp.isnan(acc[:, 0]))
            # argmax returns 0 when nothing is NaN, so re-check the slot.
            has_space = jnp.isnan(acc[pos, 0])
            acc = jnp.where(is_selected & has_space, acc.at[pos].set(rewired), acc)
            return acc, None

        updated, _ = jax.lax.scan(step, conns_acc, (source_conns, selected))
        return updated

    def __call__(
        self, state, genome, randkey, nodes, conns, new_node_key, new_conn_key
    ):
        """Apply module duplication.

        If the random draw is lower than ``duplication_rate`` a hidden node is
        selected and duplicated (its connections are copied). Otherwise the
        genome is returned unchanged.

        ``new_conn_key`` is accepted for interface compatibility and is not
        used by this operator.

        Returns
        -------
        tuple
            ``(nodes, conns)`` with the same shapes as the inputs.
        """
        # Independent keys for the fire decision and the node choice; using
        # one key for both would correlate the two draws.
        k_fire, k_pick = jax.random.split(randkey)
        fire = jax.random.uniform(k_fire) < self.duplication_rate

        def duplicate(_):
            node_key, node_idx = self._choose_node_key(
                k_pick, nodes, genome.input_idx, genome.output_idx
            )

            valid = ~jnp.isnan(conns[:, 0])
            is_incoming = valid & (conns[:, 1] == node_key)
            is_outgoing = valid & (conns[:, 0] == node_key)

            # Require room for the new node and for the complete set of
            # copied connections, so a module is never partially duplicated
            # and no connection refers to a node that could not be stored.
            n_needed = is_incoming.sum() + is_outgoing.sum()
            free_nodes = jnp.isnan(nodes[:, 0]).sum()
            free_conns = jnp.isnan(conns[:, 0]).sum()
            can_dup = (
                (node_idx != I_INF) & (free_nodes >= 1) & (free_conns >= n_needed)
            )

            def keep():
                return nodes, conns

            def do_dup():
                # The new node gets the attributes produced by
                # ``new_identity_attrs``, not those of the picked node.
                new_nodes = add_node(
                    nodes,
                    jnp.array([new_node_key]),
                    genome.node_gene.new_identity_attrs(state),
                )
                # Incoming copies first (target rewired), then outgoing
                # copies (source rewired). Both passes read the ORIGINAL
                # conns so freshly written copies are never re-copied.
                new_conns = self._copy_conns(
                    conns, conns, is_incoming, 1, new_node_key
                )
                new_conns = self._copy_conns(
                    new_conns, conns, is_outgoing, 0, new_node_key
                )
                return new_nodes, new_conns

            return jax.lax.cond(can_dup, do_dup, keep)

        return jax.lax.cond(fire, duplicate, lambda _: (nodes, conns), operand=None)
|
||||
|
||||
|
||||
class CombinedMutation(BaseMutation):
    """Combine ``DefaultMutation`` with optional ``ModuleDuplication``.

    The combined mutation first runs the standard NEAT structural/value
    mutations and then (with the configured ``duplication_rate``) may duplicate
    a module.

    Parameters
    ----------
    duplication_rate: float, optional
        Forwarded to ``ModuleDuplication``. Default ``0.0``.
    **default_kwargs
        Forwarded unchanged to ``DefaultMutation``.
    """

    def __init__(self, duplication_rate: float = 0.0, **default_kwargs):
        self.default_mut = DefaultMutation(**default_kwargs)
        self.dup_mut = ModuleDuplication(duplication_rate)

    def __call__(
        self, state, genome, randkey, nodes, conns, new_node_key, new_conn_key
    ):
        """Run the default mutation, then module duplication.

        Both operators are handed the same ``new_node_key``. If a node with
        that key is already present after the default mutation, duplication
        is skipped for this step so the genome never ends up with two nodes
        sharing one key.

        Returns
        -------
        tuple
            The mutated ``(nodes, conns)``.
        """
        k_default, k_dup = jax.random.split(randkey)
        nodes, conns = self.default_mut(
            state, genome, k_default, nodes, conns, new_node_key, new_conn_key
        )

        # NaN rows compare unequal, so only live nodes can match the key.
        key_taken = jnp.any(nodes[:, 0] == new_node_key)

        def skip():
            return nodes, conns

        def duplicate():
            return self.dup_mut(
                state, genome, k_dup, nodes, conns, new_node_key, new_conn_key
            )

        return jax.lax.cond(key_taken, skip, duplicate)
|
||||
164
test/test_module_duplication.py
Normal file
164
test/test_module_duplication.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Tests for the MEND ModuleDuplication mutation.
|
||||
|
||||
Tests verify that:
|
||||
1. The ModuleDuplication operator alone adds exactly one hidden node.
|
||||
2. The duplicated node has connections copied from the source.
|
||||
3. CombinedMutation integrates both default and duplication mutations.
|
||||
"""
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from tensorneat.genome import DefaultGenome
|
||||
from tensorneat.genome.operations.mutation import ModuleDuplication
|
||||
|
||||
|
||||
def _count_hidden(nodes, genome):
    """Count non-NaN hidden nodes (excluding inputs/outputs).

    A row of ``nodes`` counts when its first column (the node key) is not NaN
    and is listed in neither ``genome.input_idx`` nor ``genome.output_idx``.
    Returns a plain Python ``int``.
    """
    keys = jnp.asarray(nodes)[:, 0]
    # Build the reserved-key list once instead of once per row.
    reserved = jnp.concatenate(
        [
            jnp.ravel(jnp.asarray(genome.input_idx)),
            jnp.ravel(jnp.asarray(genome.output_idx)),
        ]
    )
    is_hidden = ~jnp.isnan(keys) & ~jnp.isin(keys, reserved)
    return int(jnp.sum(is_hidden))
|
||||
|
||||
|
||||
def _count_valid_conns(conns):
    """Count non-NaN connections.

    A row counts when its first column is not NaN. Returns a plain Python
    ``int``; an empty ``conns`` yields ``0``.
    """
    conns = jnp.asarray(conns)
    if conns.size == 0:
        return 0
    # One vectorised pass instead of a host-side loop over device rows.
    return int(jnp.sum(~jnp.isnan(conns[:, 0])))
|
||||
|
||||
|
||||
def test_module_duplication_standalone():
    """Test ModuleDuplication in isolation (not via CombinedMutation).

    Start with a genome that has one hidden node (via init_hidden_layers),
    then apply only ModuleDuplication with rate=1.0. Exactly one hidden node
    must be added, and every connection touching the source node must be
    copied.
    """
    genome = DefaultGenome(
        num_inputs=2,
        num_outputs=1,
        max_nodes=10,
        max_conns=30,
        init_hidden_layers=(1,),  # start with 1 hidden node
    )
    state = genome.setup()
    randkey = jax.random.PRNGKey(42)
    nodes, conns = genome.initialize(state, randkey)

    init_hidden = _count_hidden(nodes, genome)
    init_conns = _count_valid_conns(conns)
    assert init_hidden == 1, f"Expected 1 initial hidden node, got {init_hidden}"

    # Work out how many connections belong to the single hidden node, so the
    # expected growth is derived from the data rather than hard-coded.
    node_keys = nodes[:, 0]
    reserved = jnp.concatenate(
        [jnp.asarray(genome.input_idx), jnp.asarray(genome.output_idx)]
    )
    hidden_mask = ~jnp.isnan(node_keys) & ~jnp.isin(node_keys, reserved)
    hidden_key = node_keys[jnp.argmax(hidden_mask)]
    live = ~jnp.isnan(conns[:, 0])
    n_incoming = int(jnp.sum(live & (conns[:, 1] == hidden_key)))
    n_outgoing = int(jnp.sum(live & (conns[:, 0] == hidden_key)))
    module_conns = n_incoming + n_outgoing
    assert module_conns > 0, "The initial hidden node should be connected"

    # Apply only the duplication operator (not CombinedMutation).
    dup = ModuleDuplication(duplication_rate=1.0)
    k1, _ = jax.random.split(randkey)
    new_node_key = jnp.array(genome.max_nodes + 1)
    new_conn_key = jnp.array(
        [genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
    )
    nodes_mut, conns_mut = dup(
        state, genome, k1, nodes, conns, new_node_key, new_conn_key
    )

    post_hidden = _count_hidden(nodes_mut, genome)
    post_conns = _count_valid_conns(conns_mut)

    assert post_hidden == init_hidden + 1, (
        f"Module duplication should add exactly one hidden node, "
        f"got {post_hidden} (was {init_hidden})"
    )
    # Duplication should copy all of the source node's connections.
    assert post_conns == init_conns + module_conns, (
        f"Duplication should add {module_conns} connections, "
        f"got {post_conns} (was {init_conns})"
    )
|
||||
|
||||
|
||||
def test_duplicated_node_has_connections():
    """After duplication the new node is wired on both sides.

    With one hidden node and rate=1.0, at least one connection must end at
    the new node key and at least one must start from it.
    """
    genome = DefaultGenome(
        num_inputs=2,
        num_outputs=1,
        max_nodes=10,
        max_conns=30,
        init_hidden_layers=(1,),
    )
    state = genome.setup()
    seed_key = jax.random.PRNGKey(7)
    nodes, conns = genome.initialize(state, seed_key)

    mutate_key, _ = jax.random.split(seed_key)
    new_node_key = jnp.array(genome.max_nodes + 1)
    base = genome.max_conns
    new_conn_key = jnp.array([base + 1, base + 2, base + 3])

    operator = ModuleDuplication(duplication_rate=1.0)
    _, conns_mut = operator(
        state, genome, mutate_key, nodes, conns, new_node_key, new_conn_key
    )

    # Count live rows whose target / source column equals the new key.
    target_key = float(new_node_key)
    sources, targets = conns_mut[:, 0], conns_mut[:, 1]
    live = ~jnp.isnan(sources)
    incoming = int(jnp.sum(live & (targets == target_key)))
    outgoing = int(jnp.sum(live & (sources == target_key)))

    assert incoming > 0, "Duplicated node must have incoming connections"
    assert outgoing > 0, "Duplicated node must have outgoing connections"
|
||||
|
||||
|
||||
def test_combined_mutation_with_duplication():
    """Smoke test: a genome built with duplication_rate=1.0 can be mutated.

    Runs one ``execute_mutation`` step and checks that the number of live
    (non-NaN) nodes did not shrink.
    """
    genome = DefaultGenome(
        num_inputs=2,
        num_outputs=1,
        max_nodes=15,
        max_conns=40,
        duplication_rate=1.0,
    )
    state = genome.setup()
    seed_key = jax.random.PRNGKey(99)
    nodes, conns = genome.initialize(state, seed_key)

    mutate_key, _ = jax.random.split(seed_key)
    new_node_key = jnp.array(genome.max_nodes + 1)
    base = genome.max_conns
    new_conn_key = jnp.array([base + 1, base + 2, base + 3])

    nodes_mut, _ = genome.execute_mutation(
        state, mutate_key, nodes, conns, new_node_key, new_conn_key
    )

    # Reaching this point means the mutation ran; now compare live node counts.
    live_before = int(jnp.sum(~jnp.isnan(nodes[:, 0])))
    live_after = int(jnp.sum(~jnp.isnan(nodes_mut[:, 0])))
    assert live_after >= live_before, "Mutation should not lose nodes"
|
||||
|
||||
|
||||
def test_duplication_disabled_when_rate_zero():
    """With duplication_rate=0.0, no duplication should occur.

    Neither the number of hidden nodes nor the number of live connections
    may change.
    """
    genome = DefaultGenome(
        num_inputs=2,
        num_outputs=1,
        max_nodes=10,
        max_conns=30,
        init_hidden_layers=(1,),
    )
    state = genome.setup()
    randkey = jax.random.PRNGKey(0)
    nodes, conns = genome.initialize(state, randkey)

    dup = ModuleDuplication(duplication_rate=0.0)
    k1, _ = jax.random.split(randkey)
    new_node_key = jnp.array(genome.max_nodes + 1)
    new_conn_key = jnp.array(
        [genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
    )
    nodes_mut, conns_mut = dup(
        state, genome, k1, nodes, conns, new_node_key, new_conn_key
    )

    assert _count_hidden(nodes_mut, genome) == _count_hidden(nodes, genome), (
        "No duplication should occur with rate=0.0"
    )
    # A disabled operator must leave the connection table alone as well.
    assert _count_valid_conns(conns_mut) == _count_valid_conns(conns), (
        "No connections should be added with rate=0.0"
    )
|
||||
Reference in New Issue
Block a user