Files
tensorneat-mend/tensorneat/genome/operations/mutation/default.py
2024-07-10 16:40:03 +08:00

293 lines
9.4 KiB
Python

import jax
from jax import vmap, numpy as jnp
from . import BaseMutation
from tensorneat.common import (
fetch_first,
fetch_random,
I_INF,
check_cycles,
)
from ...utils import (
unflatten_conns,
add_node,
add_conn,
delete_node_by_pos,
delete_conn_by_pos,
extract_node_attrs,
extract_conn_attrs,
set_node_attrs,
set_conn_attrs,
)
class DefaultMutation(BaseMutation):
def __init__(
self,
conn_add: float = 0.2,
conn_delete: float = 0,
node_add: float = 0.2,
node_delete: float = 0,
):
self.conn_add = conn_add
self.conn_delete = conn_delete
self.node_add = node_add
self.node_delete = node_delete
def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
k1, k2 = jax.random.split(randkey)
nodes, conns = self.mutate_structure(
state, k1, genome, nodes, conns, new_node_key
)
nodes, conns = self.mutate_values(state, k2, genome, nodes, conns)
return nodes, conns
def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key):
def mutate_add_node(key_, nodes_, conns_):
"""
add a node while do not influence the output of the network
"""
remain_node_space = jnp.isnan(nodes_[:, 0]).sum()
remain_conn_space = jnp.isnan(conns_[:, 0]).sum()
i_key, o_key, idx = self.choose_connection_key(
key_, conns_
) # choose a connection
def successful_add_node():
# remove the original connection and record its attrs
original_attrs = extract_conn_attrs(conns_[idx])
new_conns = delete_conn_by_pos(conns_, idx)
# add a new node with identity attrs
new_nodes = add_node(
nodes_, new_node_key, genome.node_gene.new_identity_attrs(state)
)
# add two new connections
# first is with identity attrs
new_conns = add_conn(
new_conns,
i_key,
new_node_key,
genome.conn_gene.new_identity_attrs(state),
)
# second is with the origin attrs
new_conns = add_conn(
new_conns,
new_node_key,
o_key,
original_attrs,
)
return new_nodes, new_conns
return jax.lax.cond(
(idx == I_INF) | (remain_node_space < 1) | (remain_conn_space < 2),
lambda: (nodes_, conns_), # do nothing
successful_add_node,
)
def mutate_delete_node(key_, nodes_, conns_):
"""
delete a node
"""
# randomly choose a node
key, idx = self.choose_node_key(
key_,
nodes_,
genome.input_idx,
genome.output_idx,
allow_input_keys=False,
allow_output_keys=False,
)
def successful_delete_node():
# delete the node
new_nodes = delete_node_by_pos(nodes_, idx)
# delete all connections
new_conns = jnp.where(
((conns_[:, 0] == key) | (conns_[:, 1] == key))[:, None],
jnp.nan,
conns_,
)
return new_nodes, new_conns
return jax.lax.cond(
idx == I_INF, # no available node to delete
lambda: (nodes_, conns_), # do nothing
successful_delete_node,
)
def mutate_add_conn(key_, nodes_, conns_):
"""
add a connection while do not influence the output of the network
"""
remain_conn_space = jnp.isnan(conns_[:, 0]).sum()
# randomly choose two nodes
k1_, k2_ = jax.random.split(key_, num=2)
# input node of the connection can be any node
i_key, from_idx = self.choose_node_key(
k1_,
nodes_,
genome.input_idx,
genome.output_idx,
allow_input_keys=True,
allow_output_keys=True,
)
# output node of the connection can be any node except input node
o_key, to_idx = self.choose_node_key(
k2_,
nodes_,
genome.input_idx,
genome.output_idx,
allow_input_keys=False,
allow_output_keys=True,
)
conn_pos = fetch_first((conns_[:, 0] == i_key) & (conns_[:, 1] == o_key))
is_already_exist = conn_pos != I_INF
def nothing():
return nodes_, conns_
def successful():
# add a connection with zero attrs
return nodes_, add_conn(
conns_, i_key, o_key, genome.conn_gene.new_zero_attrs(state)
)
if genome.network_type == "feedforward":
u_conns = unflatten_conns(nodes_, conns_)
conns_exist = u_conns != I_INF
is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx)
return jax.lax.cond(
is_already_exist | is_cycle | (remain_conn_space < 1),
nothing,
successful,
)
elif genome.network_type == "recurrent":
return jax.lax.cond(
is_already_exist | (remain_conn_space < 1),
nothing,
successful,
)
else:
raise ValueError(f"Invalid network type: {genome.network_type}")
def mutate_delete_conn(key_, nodes_, conns_):
# randomly choose a connection
i_key, o_key, idx = self.choose_connection_key(key_, conns_)
return jax.lax.cond(
idx == I_INF,
lambda: (nodes_, conns_), # nothing
lambda: (nodes_, delete_conn_by_pos(conns_, idx)), # success
)
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
def nothing(_, nodes_, conns_):
return nodes_, conns_
if self.node_add > 0:
nodes, conns = jax.lax.cond(
r1 < self.node_add, mutate_add_node, nothing, k1, nodes, conns
)
if self.node_delete > 0:
nodes, conns = jax.lax.cond(
r2 < self.node_delete, mutate_delete_node, nothing, k2, nodes, conns
)
if self.conn_add > 0:
nodes, conns = jax.lax.cond(
r3 < self.conn_add, mutate_add_conn, nothing, k3, nodes, conns
)
if self.conn_delete > 0:
nodes, conns = jax.lax.cond(
r4 < self.conn_delete, mutate_delete_conn, nothing, k4, nodes, conns
)
return nodes, conns
def mutate_values(self, state, randkey, genome, nodes, conns):
k1, k2 = jax.random.split(randkey)
nodes_randkeys = jax.random.split(k1, num=genome.max_nodes)
conns_randkeys = jax.random.split(k2, num=genome.max_conns)
node_attrs = vmap(extract_node_attrs)(nodes)
new_node_attrs = vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))(
state, nodes_randkeys, node_attrs
)
new_nodes = vmap(set_node_attrs)(nodes, new_node_attrs)
conn_attrs = vmap(extract_conn_attrs)(conns)
new_conn_attrs = vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))(
state, conns_randkeys, conn_attrs
)
new_conns = vmap(set_conn_attrs)(conns, new_conn_attrs)
# nan nodes not changed
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
new_conns = jnp.where(jnp.isnan(conns), jnp.nan, new_conns)
return new_nodes, new_conns
def choose_node_key(
self,
key,
nodes,
input_idx,
output_idx,
allow_input_keys: bool = False,
allow_output_keys: bool = False,
):
"""
Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node.
:param key:
:param nodes:
:param input_idx:
:param output_idx:
:param allow_input_keys:
:param allow_output_keys:
:return: return its key and position(idx)
"""
node_keys = nodes[:, 0]
mask = ~jnp.isnan(node_keys)
if not allow_input_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_idx))
if not allow_output_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_idx))
idx = fetch_random(key, mask)
key = jnp.where(idx != I_INF, nodes[idx, 0], jnp.nan)
return key, idx
def choose_connection_key(self, key, conns):
"""
Randomly choose a connection key from the given connections.
:return: i_key, o_key, idx
"""
idx = fetch_random(key, ~jnp.isnan(conns[:, 0]))
i_key = jnp.where(idx != I_INF, conns[idx, 0], jnp.nan)
o_key = jnp.where(idx != I_INF, conns[idx, 1], jnp.nan)
return i_key, o_key, idx