Add RecurrentMutation class for more control over recurrent mutations

Extends DefaultMutation with a specialized implementation that:
- Adds a bias parameter (p_recur) for controling the likelihood of
  a new connection mutation to be a recurrent connection
- Implements an vectorized algorithm for connection addition with retry
  logic
- Preserves all other mutation operations from the base implementation
- Maintains API compatibility with DefaultMutation
This commit is contained in:
massy
2025-04-22 10:35:47 -04:00
parent 96923be0b5
commit 7b2ae5c8a2

View File

@@ -0,0 +1,298 @@
# recurrent_mutation.py
import jax
import jax.numpy as jnp
# same helpers as DefaultMutation
from tensorneat.common import fetch_first, I_INF, check_cycles, fetch_random
from tensorneat.genome.operations.mutation.default import DefaultMutation
from tensorneat.genome.utils import (
extract_gene_attrs,
add_conn,
add_node,
unflatten_conns,
delete_conn_by_pos,
delete_node_by_pos,
set_gene_attrs,
)
class RecurrentMutation(DefaultMutation):
"""
A DefaultMutation variant that *biases* (not forces) evolution toward
recurrent connections, controlled by the parameter "p_recur".
Applies add/delete node and delete connection mutations identically to
the user's original implementation provided, but uses an optimized
add-connection strategy with cycle bias and retries.
Parameters
----------
p_recur : float, default 0.1
Probability (0 to 1) that any single *successful* add-connection
mutation *must* result in forming a cycle in the graph.
max_conn_tries : int, default 20
Maximum number of distinct (source, target) node pairs to sample
when attempting the add-connection mutation before giving up for
this genome in this generation.
conn_add : float
Probability of attempting connection addition. Inherited from
DefaultMutation.
conn_delete : float
Probability of attempting connection deletion. Inherited from
DefaultMutation.
node_add : float
Probability of attempting node addition. Inherited from DefaultMutation.
node_delete : float
Probability of attempting node deletion. Inherited from DefaultMutation.
kwargs : Other arguments accepted by DefaultMutation.__init__
"""
# constructor
def __init__(
self,
*,
p_recur: float = 0.1,
max_conn_tries: int = 20,
**kwargs # Pass conn_add, conn_delete etc. through to base class
):
# Initialize base class probabilities from DefaultMutation
super().__init__(**kwargs)
if not 0.0 <= p_recur <= 1.0:
raise ValueError("p_recur must be in [0, 1]")
if max_conn_tries < 1:
raise ValueError("max_conn_tries must be >= 1")
self.p_recur = float(p_recur)
self.max_conn_tries = int(max_conn_tries)
# structural mutation (all valuemutation logic is inherited intact)
def mutate_structure(
self, state, genome, randkey, nodes, conns, new_node_key, new_conn_key
):
"""
Apply ONE structural mutation (node-add/del, conn-add/del) chosen
according to the probabilities stored in this object.
Uses the custom `_mutate_add_conn_recurrent_optimized`.
Uses user's original implementations for add_node, delete_node,
delete_conn.
"""
# Helper functions (kept as per user request, structure preserved
def mutate_add_node(key_, nodes_, conns_):
remain_node_space = jnp.isnan(nodes_[:, 0]).sum()
remain_conn_space = jnp.isnan(conns_[:, 0]).sum()
i_key, o_key, idx = self.choose_connection_key(key_, conns_)
def successful_add_node():
# remove the original connection and record its attrs
original_attrs = extract_gene_attrs(
genome.conn_gene, conns_[idx]
)
new_conns = delete_conn_by_pos(conns_, idx)
# add a new node with identity attrs
new_nodes = add_node(
nodes_,
jnp.array([new_node_key]),
genome.node_gene.new_identity_attrs(state)
)
# build two replacement connections
if "historical_marker" in genome.conn_gene.fixed_attrs:
f1 = jnp.array([i_key, new_node_key, new_conn_key[0]])
f2 = jnp.array([new_node_key, o_key, new_conn_key[1]])
else:
f1 = jnp.array([i_key, new_node_key])
f2 = jnp.array([new_node_key, o_key])
new_conns = add_conn(
new_conns,
f1,
genome.conn_gene.new_identity_attrs(state)
)
new_conns = add_conn(new_conns, f2, original_attrs)
return new_nodes, new_conns
cond_do_nothing = (idx == I_INF) | \
(remain_node_space < 1) | \
(remain_conn_space < 2)
return jax.lax.cond(
cond_do_nothing, # condition for doing nothing
lambda: (nodes_, conns_), # do nothing branch (if cond is true)
successful_add_node # do add branch (if cond is false)
)
def mutate_delete_node(key_, nodes_, conns_):
k, idx = self.choose_node_key(
key_, nodes_, genome.input_idx, genome.output_idx,
allow_input_keys=False, allow_output_keys=False,
)
def do():
# delete the node
new_nodes = delete_node_by_pos(nodes_, idx)
# delete all connections
new_conns = jnp.where(
((conns_[:, 0] == k) | (conns_[:, 1] == k))[:, None],
jnp.nan, conns_
)
return new_nodes, new_conns
return jax.lax.cond(
idx == I_INF, # cond to determine "doing nothing"
lambda: (nodes_, conns_), # Do nothing branch
do # do delete branch
)
# NEW biased addconnection with optimizations and p_recur bias
def _mutate_add_conn_recurrent_optimized(key_, nodes_, conns_):
"""
Optimized: Attempts to insert one connection with recurrent bias.
Makes up to "self.max_conn_tries" draws.
Computes graph structure once before vmap.
Checks prerequisites before cycle check.
"""
remain_conn_space = jnp.isnan(conns_[:, 0]).sum()
space_available = remain_conn_space >= 1
# Optimization: Calculate graph structure ONCE before vmap
u_conns = unflatten_conns(nodes_, conns_)
conns_exist_matrix = u_conns != I_INF
# Function for a single attempt
def attempt(k_triplet):
"""One sampling attempt; returns (accept?, i_key, o_key)."""
k1_node, k2_node, k_recur = jax.random.split(k_triplet, 3)
# 1. Sample endpoints
i_key, from_idx = self.choose_node_key(
k1_node, nodes_, genome.input_idx, genome.output_idx,
allow_input_keys=True, allow_output_keys=True,
)
o_key, to_idx = self.choose_node_key(
k2_node, nodes_, genome.input_idx, genome.output_idx,
allow_input_keys=False, allow_output_keys=True,
)
# 2. Basic checks: nodes selected? space available?
nodes_chosen = (from_idx != I_INF) & (to_idx != I_INF)
prereqs_met = nodes_chosen & space_available
# --- Nested conditional logic for efficiency ---
def check_existence_and_cycle():
# 3. Duplicate check (only if prereqs met)
exists = fetch_first(
(conns_[:, 0] == i_key) & (conns_[:, 1] == o_key)
) != I_INF
not_duplicate = ~exists
def check_cycle_logic():
# 4. Cycle check (only if valid, non-duplicate candidate)
forms_cy = check_cycles(
nodes_,
conns_exist_matrix,
from_idx, to_idx
)
# 5. Decide if we *require* a cycle
force_recur = jax.random.uniform(k_recur) < self.p_recur
# 6. Final check: cycle requirement satisfaction
cycle_req_satisfied = (force_recur & forms_cy) | \
(~force_recur)
# Attempt is valid if cycle requirement is met
is_valid = cycle_req_satisfied
return is_valid, i_key, o_key
# If not duplicate, check cycle logic, else invalid
is_valid, final_i_key, final_o_key = jax.lax.cond(
not_duplicate,
check_cycle_logic,
lambda: (jnp.array(False), jnp.nan, jnp.nan) # is_valid=False if duplicate
)
return is_valid, final_i_key, final_o_key
# If basic prereqs met, check existence/cycle, else invalid
is_valid_attempt, i_key_attempt, o_key_attempt = jax.lax.cond(
prereqs_met,
check_existence_and_cycle,
lambda: (jnp.array(False), jnp.nan, jnp.nan) # is_valid=False if prereqs not met
)
return is_valid_attempt, i_key_attempt, o_key_attempt # bool, float, float
# Vectorize attempts and find first success
keys = jax.random.split(key_, self.max_conn_tries)
accept_flags, i_keys_all, o_keys_all = jax.vmap(attempt)(keys)
# find FIRST successful draw (if any)
first_success_idx = fetch_first(accept_flags)
found_valid_candidate = first_success_idx != I_INF
# Conditionally add the connection
def do_accept():
# Get the keys from the first successful attempt
i_key_chosen = i_keys_all[first_success_idx]
o_key_chosen = o_keys_all[first_success_idx]
# Create fixed attributes (using 3rd marker for add_conn)
if "historical_marker" in genome.conn_gene.fixed_attrs:
fix = jnp.array([i_key_chosen, o_key_chosen, new_conn_key[2]])
else:
fix = jnp.array([i_key_chosen, o_key_chosen])
# Add the connection
new_conns = add_conn(
conns_,
fix,
genome.conn_gene.new_zero_attrs(state) # Add with zero/default attrs
)
return nodes_, new_conns
# If a valid candidate was found, add it; otherwise return original
# arrays
return jax.lax.cond(
found_valid_candidate, # Condition is True if we should accept
do_accept, # Do accept branch (if cond is true)
lambda: (nodes_, conns_) # Do nothing branch (if cond is false)
)
def mutate_delete_conn(key_, nodes_, conns_):
i_key, o_key, idx = self.choose_connection_key(key_, conns_)
return jax.lax.cond(
idx == I_INF,
lambda: (nodes_, conns_),
lambda: (nodes_, delete_conn_by_pos(conns_, idx)),
)
# --- Scheduling Logic (unchanged from user's code) ---
k_node_add, k_node_del, k_conn_add, k_conn_del, k_schedule = \
jax.random.split(randkey, 5)
probs = jax.random.uniform(k_schedule, (4,))
def nothing(_, n_, c_): return n_, c_
# Apply mutations conditionally using original helper functions
nodes, conns = jax.lax.cond(
(self.node_add > 0) & (probs[0] < self.node_add),
mutate_add_node, nothing, k_node_add, nodes, conns
)
nodes, conns = jax.lax.cond(
(self.node_delete > 0) & (probs[1] < self.node_delete),
mutate_delete_node, nothing, k_node_del, nodes, conns
)
# Apply connection addition using the OPTIMIZED version
nodes, conns = jax.lax.cond(
(self.conn_add > 0) & (probs[2] < self.conn_add),
_mutate_add_conn_recurrent_optimized, # <<< Use optimized add_conn
nothing, k_conn_add, nodes, conns
)
# Apply connection deletion using original helper function
nodes, conns = jax.lax.cond(
(self.conn_delete > 0) & (probs[3] < self.conn_delete),
mutate_delete_conn, nothing, k_conn_del, nodes, conns
)
return nodes, conns