Add MEND module duplication operator
Implement the core MEND mechanism: a module duplication mutation that copies a hidden node with all its incoming and outgoing connections. This is the single addition MEND makes to standard NEAT. - ModuleDuplication: JAX-compatible operator using jax.lax.scan - CombinedMutation: composes DefaultMutation + ModuleDuplication - DefaultGenome: accepts duplication_rate parameter - Tests for standalone duplication, combined mutation, and rate=0 no-op Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -7,7 +7,12 @@ import sympy as sp
|
|||||||
|
|
||||||
from .base import BaseGenome
|
from .base import BaseGenome
|
||||||
from .gene import DefaultNode, DefaultConn
|
from .gene import DefaultNode, DefaultConn
|
||||||
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
|
from .operations import (
|
||||||
|
DefaultMutation,
|
||||||
|
DefaultCrossover,
|
||||||
|
DefaultDistance,
|
||||||
|
CombinedMutation,
|
||||||
|
)
|
||||||
from .utils import unflatten_conns, extract_gene_attrs, extract_gene_attrs
|
from .utils import unflatten_conns, extract_gene_attrs, extract_gene_attrs
|
||||||
|
|
||||||
from tensorneat.common import (
|
from tensorneat.common import (
|
||||||
@@ -40,8 +45,12 @@ class DefaultGenome(BaseGenome):
|
|||||||
output_transform=None,
|
output_transform=None,
|
||||||
input_transform=None,
|
input_transform=None,
|
||||||
init_hidden_layers=(),
|
init_hidden_layers=(),
|
||||||
|
duplication_rate: float = 0.0,
|
||||||
):
|
):
|
||||||
|
# If a duplication_rate is specified, wrap the default mutation with the CombinedMutation
|
||||||
|
if duplication_rate > 0.0:
|
||||||
|
# Use CombinedMutation which includes standard mutation plus module duplication
|
||||||
|
mutation = CombinedMutation(duplication_rate=duplication_rate)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inputs,
|
num_inputs,
|
||||||
num_outputs,
|
num_outputs,
|
||||||
@@ -66,7 +75,6 @@ class DefaultGenome(BaseGenome):
|
|||||||
return seqs, nodes, conns, u_conns
|
return seqs, nodes, conns, u_conns
|
||||||
|
|
||||||
def forward(self, state, transformed, inputs):
|
def forward(self, state, transformed, inputs):
|
||||||
|
|
||||||
if self.input_transform is not None:
|
if self.input_transform is not None:
|
||||||
inputs = self.input_transform(inputs)
|
inputs = self.input_transform(inputs)
|
||||||
|
|
||||||
@@ -133,9 +141,7 @@ class DefaultGenome(BaseGenome):
|
|||||||
network["topo_order"] = topo_order
|
network["topo_order"] = topo_order
|
||||||
network["topo_layers"] = topo_layers
|
network["topo_layers"] = topo_layers
|
||||||
network["useful_nodes"] = find_useful_nodes(
|
network["useful_nodes"] = find_useful_nodes(
|
||||||
set(network["nodes"]),
|
set(network["nodes"]), set(network["conns"]), set(self.output_idx)
|
||||||
set(network["conns"]),
|
|
||||||
set(self.output_idx)
|
|
||||||
)
|
)
|
||||||
return network
|
return network
|
||||||
|
|
||||||
@@ -147,7 +153,6 @@ class DefaultGenome(BaseGenome):
|
|||||||
sympy_output_transform=None,
|
sympy_output_transform=None,
|
||||||
backend="jax",
|
backend="jax",
|
||||||
):
|
):
|
||||||
|
|
||||||
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
|
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
|
||||||
|
|
||||||
if sympy_input_transform is None and self.input_transform is not None:
|
if sympy_input_transform is None and self.input_transform is not None:
|
||||||
@@ -187,7 +192,6 @@ class DefaultGenome(BaseGenome):
|
|||||||
nodes_exprs = {}
|
nodes_exprs = {}
|
||||||
args_symbols = {}
|
args_symbols = {}
|
||||||
for i in order:
|
for i in order:
|
||||||
|
|
||||||
if i in input_idx:
|
if i in input_idx:
|
||||||
nodes_exprs[symbols[-i - 1]] = symbols[
|
nodes_exprs[symbols[-i - 1]] = symbols[
|
||||||
-i - 1
|
-i - 1
|
||||||
@@ -258,7 +262,7 @@ class DefaultGenome(BaseGenome):
|
|||||||
output_exprs,
|
output_exprs,
|
||||||
forward_func,
|
forward_func,
|
||||||
network["topo_order"],
|
network["topo_order"],
|
||||||
network["useful_nodes"]
|
network["useful_nodes"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def visualize(
|
def visualize(
|
||||||
@@ -395,7 +399,7 @@ class DefaultGenome(BaseGenome):
|
|||||||
edgecolors=edgecolors,
|
edgecolors=edgecolors,
|
||||||
arrowstyle=arrowstyle,
|
arrowstyle=arrowstyle,
|
||||||
arrowsize=arrowsize,
|
arrowsize=arrowsize,
|
||||||
edge_color=edge_color
|
edge_color=edge_color,
|
||||||
)
|
)
|
||||||
plt.savefig(save_path, dpi=save_dpi)
|
plt.savefig(save_path, dpi=save_dpi)
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
from .crossover import BaseCrossover, DefaultCrossover
|
from .crossover import BaseCrossover, DefaultCrossover
|
||||||
from .mutation import BaseMutation, DefaultMutation
|
from .mutation import BaseMutation, DefaultMutation, CombinedMutation, ModuleDuplication
|
||||||
from .distance import BaseDistance, DefaultDistance
|
from .distance import BaseDistance, DefaultDistance
|
||||||
|
|||||||
@@ -1,2 +1,3 @@
|
|||||||
from .base import BaseMutation
|
from .base import BaseMutation
|
||||||
from .default import DefaultMutation
|
from .default import DefaultMutation
|
||||||
|
from .module_duplication import CombinedMutation, ModuleDuplication
|
||||||
|
|||||||
141
src/tensorneat/genome/operations/mutation/module_duplication.py
Normal file
141
src/tensorneat/genome/operations/mutation/module_duplication.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""Module Duplication Mutation
|
||||||
|
|
||||||
|
Implements the MEND *module duplication* operator. It duplicates a hidden
|
||||||
|
module (a hidden node together with all its incoming and outgoing connections).
|
||||||
|
The operator fires with probability ``duplication_rate``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from .base import BaseMutation
|
||||||
|
from .default import DefaultMutation
|
||||||
|
from ...utils import add_node, add_conn, extract_gene_attrs
|
||||||
|
from tensorneat.common import I_INF
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleDuplication(BaseMutation):
|
||||||
|
"""Duplicate a hidden module (node + its connections).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
duplication_rate: float, optional
|
||||||
|
Probability that the duplication operator fires during a mutation
|
||||||
|
step. Must be in ``[0, 1]``. Default ``0.0`` (disabled).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, duplication_rate: float = 0.0):
|
||||||
|
self.duplication_rate = duplication_rate
|
||||||
|
self._helper = DefaultMutation()
|
||||||
|
|
||||||
|
def _choose_node_key(self, key, nodes, input_idx, output_idx):
|
||||||
|
return self._helper.choose_node_key(
|
||||||
|
key,
|
||||||
|
nodes,
|
||||||
|
input_idx,
|
||||||
|
output_idx,
|
||||||
|
allow_input_keys=False,
|
||||||
|
allow_output_keys=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, state, genome, randkey, nodes, conns, new_node_key, new_conn_key
|
||||||
|
):
|
||||||
|
"""Apply module duplication.
|
||||||
|
|
||||||
|
If the random draw is lower than ``duplication_rate`` a hidden node is
|
||||||
|
selected and duplicated (its connections are copied). Otherwise the genome
|
||||||
|
is returned unchanged.
|
||||||
|
"""
|
||||||
|
fire = jax.random.uniform(randkey) < self.duplication_rate
|
||||||
|
|
||||||
|
n_fixed = len(genome.conn_gene.fixed_attrs)
|
||||||
|
|
||||||
|
def duplicate(_):
|
||||||
|
node_key, node_idx = self._choose_node_key(
|
||||||
|
randkey, nodes, genome.input_idx, genome.output_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
def no_node():
|
||||||
|
return nodes, conns
|
||||||
|
|
||||||
|
def do_dup():
|
||||||
|
# Add a new hidden node with identity attributes.
|
||||||
|
new_nodes = add_node(
|
||||||
|
nodes,
|
||||||
|
jnp.array([new_node_key]),
|
||||||
|
genome.node_gene.new_identity_attrs(state),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use jax.lax.scan to iterate over all connections and
|
||||||
|
# conditionally copy incoming ones (target == node_key).
|
||||||
|
def copy_incoming(carry, conn):
|
||||||
|
conns_acc = carry
|
||||||
|
is_incoming = (conn[1] == node_key) & ~jnp.isnan(conn[0])
|
||||||
|
attrs = conn[n_fixed:]
|
||||||
|
src = conn[0]
|
||||||
|
fix = jnp.array([src, new_node_key])
|
||||||
|
new_conn = jnp.concatenate((fix, attrs))
|
||||||
|
pos = jnp.argmax(jnp.isnan(conns_acc[:, 0]))
|
||||||
|
has_space = jnp.isnan(conns_acc[pos, 0])
|
||||||
|
should_add = is_incoming & has_space
|
||||||
|
conns_acc = jnp.where(
|
||||||
|
should_add,
|
||||||
|
conns_acc.at[pos].set(new_conn),
|
||||||
|
conns_acc,
|
||||||
|
)
|
||||||
|
return conns_acc, None
|
||||||
|
|
||||||
|
new_conns, _ = jax.lax.scan(copy_incoming, conns, conns)
|
||||||
|
|
||||||
|
# Copy outgoing connections (source == node_key).
|
||||||
|
def copy_outgoing(carry, conn):
|
||||||
|
conns_acc = carry
|
||||||
|
is_outgoing = (conn[0] == node_key) & ~jnp.isnan(conn[0])
|
||||||
|
attrs = conn[n_fixed:]
|
||||||
|
tgt = conn[1]
|
||||||
|
fix = jnp.array([new_node_key, tgt])
|
||||||
|
new_conn = jnp.concatenate((fix, attrs))
|
||||||
|
pos = jnp.argmax(jnp.isnan(conns_acc[:, 0]))
|
||||||
|
has_space = jnp.isnan(conns_acc[pos, 0])
|
||||||
|
should_add = is_outgoing & has_space
|
||||||
|
conns_acc = jnp.where(
|
||||||
|
should_add,
|
||||||
|
conns_acc.at[pos].set(new_conn),
|
||||||
|
conns_acc,
|
||||||
|
)
|
||||||
|
return conns_acc, None
|
||||||
|
|
||||||
|
# Scan over the ORIGINAL conns for outgoing check, but
|
||||||
|
# accumulate into the already-updated new_conns.
|
||||||
|
new_conns, _ = jax.lax.scan(copy_outgoing, new_conns, conns)
|
||||||
|
|
||||||
|
return new_nodes, new_conns
|
||||||
|
|
||||||
|
return jax.lax.cond(node_idx == I_INF, no_node, do_dup)
|
||||||
|
|
||||||
|
return jax.lax.cond(fire, duplicate, lambda _: (nodes, conns), operand=None)
|
||||||
|
|
||||||
|
|
||||||
|
class CombinedMutation(BaseMutation):
|
||||||
|
"""Combine ``DefaultMutation`` with optional ``ModuleDuplication``.
|
||||||
|
|
||||||
|
The combined mutation first runs the standard NEAT structural/value
|
||||||
|
mutations and then (with the configured ``duplication_rate``) may duplicate a
|
||||||
|
module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, duplication_rate: float = 0.0, **default_kwargs):
|
||||||
|
self.default_mut = DefaultMutation(**default_kwargs)
|
||||||
|
self.dup_mut = ModuleDuplication(duplication_rate)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, state, genome, randkey, nodes, conns, new_node_key, new_conn_key
|
||||||
|
):
|
||||||
|
k1, k2 = jax.random.split(randkey)
|
||||||
|
nodes, conns = self.default_mut(
|
||||||
|
state, genome, k1, nodes, conns, new_node_key, new_conn_key
|
||||||
|
)
|
||||||
|
nodes, conns = self.dup_mut(
|
||||||
|
state, genome, k2, nodes, conns, new_node_key, new_conn_key
|
||||||
|
)
|
||||||
|
return nodes, conns
|
||||||
164
test/test_module_duplication.py
Normal file
164
test/test_module_duplication.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
"""Tests for the MEND ModuleDuplication mutation.
|
||||||
|
|
||||||
|
Tests verify that:
|
||||||
|
1. The ModuleDuplication operator alone adds exactly one hidden node.
|
||||||
|
2. The duplicated node has connections copied from the source.
|
||||||
|
3. CombinedMutation integrates both default and duplication mutations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from tensorneat.genome import DefaultGenome
|
||||||
|
from tensorneat.genome.operations.mutation import ModuleDuplication
|
||||||
|
|
||||||
|
|
||||||
|
def _count_hidden(nodes, genome):
|
||||||
|
"""Count non-NaN hidden nodes (excluding inputs/outputs)."""
|
||||||
|
return sum(
|
||||||
|
1
|
||||||
|
for n in nodes
|
||||||
|
if not jnp.isnan(n[0])
|
||||||
|
and int(n[0]) not in set(genome.input_idx.tolist())
|
||||||
|
and int(n[0]) not in set(genome.output_idx.tolist())
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _count_valid_conns(conns):
|
||||||
|
"""Count non-NaN connections."""
|
||||||
|
return sum(1 for c in conns if not jnp.isnan(c[0]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_module_duplication_standalone():
|
||||||
|
"""Test ModuleDuplication in isolation (not via CombinedMutation).
|
||||||
|
|
||||||
|
Start with a genome that has one hidden node (via init_hidden_layers),
|
||||||
|
then apply only ModuleDuplication with rate=1.0.
|
||||||
|
"""
|
||||||
|
genome = DefaultGenome(
|
||||||
|
num_inputs=2,
|
||||||
|
num_outputs=1,
|
||||||
|
max_nodes=10,
|
||||||
|
max_conns=30,
|
||||||
|
init_hidden_layers=(1,), # start with 1 hidden node
|
||||||
|
)
|
||||||
|
state = genome.setup()
|
||||||
|
randkey = jax.random.PRNGKey(42)
|
||||||
|
nodes, conns = genome.initialize(state, randkey)
|
||||||
|
|
||||||
|
init_hidden = _count_hidden(nodes, genome)
|
||||||
|
init_conns = _count_valid_conns(conns)
|
||||||
|
assert init_hidden == 1, f"Expected 1 initial hidden node, got {init_hidden}"
|
||||||
|
|
||||||
|
# Apply only the duplication operator (not CombinedMutation).
|
||||||
|
dup = ModuleDuplication(duplication_rate=1.0)
|
||||||
|
k1, _ = jax.random.split(randkey)
|
||||||
|
new_node_key = jnp.array(genome.max_nodes + 1)
|
||||||
|
new_conn_key = jnp.array(
|
||||||
|
[genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
|
||||||
|
)
|
||||||
|
nodes_mut, conns_mut = dup(
|
||||||
|
state, genome, k1, nodes, conns, new_node_key, new_conn_key
|
||||||
|
)
|
||||||
|
|
||||||
|
post_hidden = _count_hidden(nodes_mut, genome)
|
||||||
|
post_conns = _count_valid_conns(conns_mut)
|
||||||
|
|
||||||
|
assert post_hidden == init_hidden + 1, (
|
||||||
|
f"Module duplication should add exactly one hidden node, "
|
||||||
|
f"got {post_hidden} (was {init_hidden})"
|
||||||
|
)
|
||||||
|
# The source node had 2 incoming (from inputs) + 1 outgoing (to output) = 3 conns.
|
||||||
|
# Duplication should copy all of them.
|
||||||
|
assert post_conns > init_conns, (
|
||||||
|
f"Duplication should add connections, got {post_conns} (was {init_conns})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_duplicated_node_has_connections():
|
||||||
|
"""The duplicated node must have both incoming and outgoing connections."""
|
||||||
|
genome = DefaultGenome(
|
||||||
|
num_inputs=2,
|
||||||
|
num_outputs=1,
|
||||||
|
max_nodes=10,
|
||||||
|
max_conns=30,
|
||||||
|
init_hidden_layers=(1,),
|
||||||
|
)
|
||||||
|
state = genome.setup()
|
||||||
|
randkey = jax.random.PRNGKey(7)
|
||||||
|
nodes, conns = genome.initialize(state, randkey)
|
||||||
|
|
||||||
|
dup = ModuleDuplication(duplication_rate=1.0)
|
||||||
|
k1, _ = jax.random.split(randkey)
|
||||||
|
new_node_key = jnp.array(genome.max_nodes + 1)
|
||||||
|
new_conn_key = jnp.array(
|
||||||
|
[genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
|
||||||
|
)
|
||||||
|
nodes_mut, conns_mut = dup(
|
||||||
|
state, genome, k1, nodes, conns, new_node_key, new_conn_key
|
||||||
|
)
|
||||||
|
|
||||||
|
nk = float(new_node_key)
|
||||||
|
incoming = sum(
|
||||||
|
1 for c in conns_mut if not jnp.isnan(c[0]) and float(c[1]) == nk
|
||||||
|
)
|
||||||
|
outgoing = sum(
|
||||||
|
1 for c in conns_mut if not jnp.isnan(c[0]) and float(c[0]) == nk
|
||||||
|
)
|
||||||
|
assert incoming > 0, "Duplicated node must have incoming connections"
|
||||||
|
assert outgoing > 0, "Duplicated node must have outgoing connections"
|
||||||
|
|
||||||
|
|
||||||
|
def test_combined_mutation_with_duplication():
|
||||||
|
"""CombinedMutation should produce a valid genome (no crashes)."""
|
||||||
|
genome = DefaultGenome(
|
||||||
|
num_inputs=2,
|
||||||
|
num_outputs=1,
|
||||||
|
max_nodes=15,
|
||||||
|
max_conns=40,
|
||||||
|
duplication_rate=1.0,
|
||||||
|
)
|
||||||
|
state = genome.setup()
|
||||||
|
randkey = jax.random.PRNGKey(99)
|
||||||
|
nodes, conns = genome.initialize(state, randkey)
|
||||||
|
|
||||||
|
k1, _ = jax.random.split(randkey)
|
||||||
|
new_node_key = jnp.array(genome.max_nodes + 1)
|
||||||
|
new_conn_key = jnp.array(
|
||||||
|
[genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
|
||||||
|
)
|
||||||
|
nodes_mut, conns_mut = genome.execute_mutation(
|
||||||
|
state, k1, nodes, conns, new_node_key, new_conn_key
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not crash and should have at least as many nodes/conns as before.
|
||||||
|
init_valid = sum(1 for n in nodes if not jnp.isnan(n[0]))
|
||||||
|
post_valid = sum(1 for n in nodes_mut if not jnp.isnan(n[0]))
|
||||||
|
assert post_valid >= init_valid, "Mutation should not lose nodes"
|
||||||
|
|
||||||
|
|
||||||
|
def test_duplication_disabled_when_rate_zero():
|
||||||
|
"""With duplication_rate=0.0, no duplication should occur."""
|
||||||
|
genome = DefaultGenome(
|
||||||
|
num_inputs=2,
|
||||||
|
num_outputs=1,
|
||||||
|
max_nodes=10,
|
||||||
|
max_conns=30,
|
||||||
|
init_hidden_layers=(1,),
|
||||||
|
)
|
||||||
|
state = genome.setup()
|
||||||
|
randkey = jax.random.PRNGKey(0)
|
||||||
|
nodes, conns = genome.initialize(state, randkey)
|
||||||
|
|
||||||
|
dup = ModuleDuplication(duplication_rate=0.0)
|
||||||
|
k1, _ = jax.random.split(randkey)
|
||||||
|
new_node_key = jnp.array(genome.max_nodes + 1)
|
||||||
|
new_conn_key = jnp.array(
|
||||||
|
[genome.max_conns + 1, genome.max_conns + 2, genome.max_conns + 3]
|
||||||
|
)
|
||||||
|
nodes_mut, conns_mut = dup(
|
||||||
|
state, genome, k1, nodes, conns, new_node_key, new_conn_key
|
||||||
|
)
|
||||||
|
|
||||||
|
assert _count_hidden(nodes_mut, genome) == _count_hidden(nodes, genome), (
|
||||||
|
"No duplication should occur with rate=0.0"
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user