import jax
from jax import vmap, numpy as jnp

from . import BaseMutation
from tensorneat.common import (
    fetch_first,
    fetch_random,
    I_INF,
    check_cycles,
)
from ...utils import (
    unflatten_conns,
    add_node,
    add_conn,
    delete_node_by_pos,
    delete_conn_by_pos,
    extract_node_attrs,
    extract_conn_attrs,
    set_node_attrs,
    set_conn_attrs,
)


class DefaultMutation(BaseMutation):
    """Default mutation operator: structural mutations followed by value mutations.

    Each structural mutation (add node, delete node, add connection, delete
    connection) is attempted independently with its configured probability.
    Afterwards every node/connection attribute is passed through the
    corresponding gene's ``mutate``.

    NOTE(review): ``self.genome`` is read throughout but not set here; it is
    presumably attached by ``BaseMutation`` or by the owning genome — confirm.
    """

    def __init__(
        self,
        conn_add: float = 0.1,
        conn_delete: float = 0,
        node_add: float = 0.1,
        node_delete: float = 0,
    ):
        """
        :param conn_add: probability of attempting to add a connection.
        :param conn_delete: probability of attempting to delete a connection.
        :param node_add: probability of attempting to add a node.
        :param node_delete: probability of attempting to delete a node.
        """
        self.conn_add = conn_add
        self.conn_delete = conn_delete
        self.node_add = node_add
        self.node_delete = node_delete

    def __call__(self, state, randkey, nodes, conns, new_node_key):
        """Mutate one genome: first its structure, then its attribute values.

        :param state: state object forwarded to the gene operations.
        :param randkey: JAX PRNG key.
        :param nodes: node array; a NaN in column 0 marks an empty slot.
        :param conns: connection array; columns 0/1 hold the input/output node
            keys and a NaN in column 0 marks an empty slot.
        :param new_node_key: key given to the node created by add-node.
        :return: ``(nodes, conns)`` after mutation.
        """
        k1, k2 = jax.random.split(randkey)

        nodes, conns = self.mutate_structure(
            state, k1, nodes, conns, new_node_key
        )
        nodes, conns = self.mutate_values(state, k2, nodes, conns)

        return nodes, conns

    def mutate_structure(self, state, randkey, nodes, conns, new_node_key):
        """Apply each enabled structural mutation with its own probability.

        Mutations run in the order add-node, delete-node, add-connection,
        delete-connection; each one operates on the result of the previous.
        Mutations whose probability is 0 are not traced at all.

        :return: ``(nodes, conns)`` after the structural mutations.
        """

        def mutate_add_node(key_, nodes_, conns_):
            """
            Split a random connection ``i -> o`` into ``i -> new -> o``.

            The new node and the ``i -> new`` connection get identity attrs,
            while ``new -> o`` inherits the original connection's attrs, so the
            network output is intended to stay unchanged.
            """
            remain_node_space = jnp.isnan(nodes_[:, 0]).sum()
            remain_conn_space = jnp.isnan(conns_[:, 0]).sum()
            i_key, o_key, idx = self.choose_connection_key(key_, conns_)

            def successful_add_node():
                # remove the original connection and remember its attrs
                original_attrs = extract_conn_attrs(conns_[idx])
                new_conns = delete_conn_by_pos(conns_, idx)

                # add the new node with identity attrs
                new_nodes = add_node(
                    nodes_,
                    new_node_key,
                    self.genome.node_gene.new_identity_attrs(state),
                )

                # i -> new: identity attrs
                new_conns = add_conn(
                    new_conns,
                    i_key,
                    new_node_key,
                    self.genome.conn_gene.new_identity_attrs(state),
                )
                # new -> o: the original connection's attrs
                new_conns = add_conn(
                    new_conns,
                    new_node_key,
                    o_key,
                    original_attrs,
                )

                return new_nodes, new_conns

            return jax.lax.cond(
                (idx == I_INF) | (remain_node_space < 1) | (remain_conn_space < 2),
                lambda: (nodes_, conns_),  # no connection to split / no space
                successful_add_node,
            )

        def mutate_delete_node(key_, nodes_, conns_):
            """
            Delete a random node that is neither an input nor an output node,
            together with every connection entering or leaving it.
            """
            node_key, idx = self.choose_node_key(
                key_,
                nodes_,
                self.genome.input_idx,
                self.genome.output_idx,
                allow_input_keys=False,
                allow_output_keys=False,
            )

            def successful_delete_node():
                new_nodes = delete_node_by_pos(nodes_, idx)

                # blank out every connection that touches the deleted node
                touches_node = (conns_[:, 0] == node_key) | (conns_[:, 1] == node_key)
                new_conns = jnp.where(touches_node[:, None], jnp.nan, conns_)

                return new_nodes, new_conns

            return jax.lax.cond(
                idx == I_INF,  # no deletable node
                lambda: (nodes_, conns_),
                successful_delete_node,
            )

        def mutate_add_conn(key_, nodes_, conns_):
            """
            Add a connection with zero attrs between two random nodes.

            The source may be any node; the target may be any node except an
            input node. The addition is skipped if either endpoint is missing,
            the connection already exists, there is no free connection slot,
            or (for feedforward networks) it would create a cycle.
            """
            remain_conn_space = jnp.isnan(conns_[:, 0]).sum()

            k1_, k2_ = jax.random.split(key_, num=2)
            # source node of the connection can be any node
            i_key, from_idx = self.choose_node_key(
                k1_,
                nodes_,
                self.genome.input_idx,
                self.genome.output_idx,
                allow_input_keys=True,
                allow_output_keys=True,
            )
            # target node can be any node except an input node
            o_key, to_idx = self.choose_node_key(
                k2_,
                nodes_,
                self.genome.input_idx,
                self.genome.output_idx,
                allow_input_keys=False,
                allow_output_keys=True,
            )

            # choose_node_key returns idx == I_INF (and a NaN key) when no
            # candidate exists; never add a connection with a NaN endpoint
            missing_endpoint = (from_idx == I_INF) | (to_idx == I_INF)

            conn_pos = fetch_first((conns_[:, 0] == i_key) & (conns_[:, 1] == o_key))
            is_already_exist = conn_pos != I_INF

            def do_nothing():
                return nodes_, conns_

            def do_add():
                return nodes_, add_conn(
                    conns_, i_key, o_key, self.genome.conn_gene.new_zero_attrs(state)
                )

            if self.genome.network_type == "feedforward":
                u_conns = unflatten_conns(nodes_, conns_)
                conns_exist = u_conns != I_INF
                is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx)

                return jax.lax.cond(
                    missing_endpoint
                    | is_already_exist
                    | is_cycle
                    | (remain_conn_space < 1),
                    do_nothing,
                    do_add,
                )

            elif self.genome.network_type == "recurrent":
                return jax.lax.cond(
                    missing_endpoint | is_already_exist | (remain_conn_space < 1),
                    do_nothing,
                    do_add,
                )

            else:
                raise ValueError(f"Invalid network type: {self.genome.network_type}")

        def mutate_delete_conn(key_, nodes_, conns_):
            """Delete a random existing connection (no-op if there is none)."""
            _, _, idx = self.choose_connection_key(key_, conns_)

            return jax.lax.cond(
                idx == I_INF,
                lambda: (nodes_, conns_),  # nothing to delete
                lambda: (nodes_, delete_conn_by_pos(conns_, idx)),
            )

        # A dedicated key for the probability draws: reusing one of the
        # mutation keys here would correlate "whether to mutate" with "what
        # the mutation picks" (JAX keys must not be reused).
        k_prob, k1, k2, k3, k4 = jax.random.split(randkey, num=5)
        r1, r2, r3, r4 = jax.random.uniform(k_prob, shape=(4,))

        def nothing(_, nodes_, conns_):
            return nodes_, conns_

        if self.node_add > 0:
            nodes, conns = jax.lax.cond(
                r1 < self.node_add, mutate_add_node, nothing, k1, nodes, conns
            )

        if self.node_delete > 0:
            nodes, conns = jax.lax.cond(
                r2 < self.node_delete, mutate_delete_node, nothing, k2, nodes, conns
            )

        if self.conn_add > 0:
            nodes, conns = jax.lax.cond(
                r3 < self.conn_add, mutate_add_conn, nothing, k3, nodes, conns
            )

        if self.conn_delete > 0:
            nodes, conns = jax.lax.cond(
                r4 < self.conn_delete, mutate_delete_conn, nothing, k4, nodes, conns
            )

        return nodes, conns

    def mutate_values(self, state, randkey, nodes, conns):
        """Mutate the attribute values of every node and connection row.

        Each row gets its own PRNG key; the gene ``mutate`` functions are
        vmapped over the rows (so ``nodes``/``conns`` must have
        ``genome.max_nodes``/``genome.max_conns`` rows). Entries that were NaN
        before mutation — in particular every entry of an empty slot — are
        forced back to NaN afterwards.
        """
        k1, k2 = jax.random.split(randkey)
        nodes_randkeys = jax.random.split(k1, num=self.genome.max_nodes)
        conns_randkeys = jax.random.split(k2, num=self.genome.max_conns)

        node_attrs = vmap(extract_node_attrs)(nodes)
        new_node_attrs = vmap(self.genome.node_gene.mutate, in_axes=(None, 0, 0))(
            state, nodes_randkeys, node_attrs
        )
        new_nodes = vmap(set_node_attrs)(nodes, new_node_attrs)

        conn_attrs = vmap(extract_conn_attrs)(conns)
        new_conn_attrs = vmap(self.genome.conn_gene.mutate, in_axes=(None, 0, 0))(
            state, conns_randkeys, conn_attrs
        )
        new_conns = vmap(set_conn_attrs)(conns, new_conn_attrs)

        # keep NaN (empty) entries untouched
        new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
        new_conns = jnp.where(jnp.isnan(conns), jnp.nan, new_conns)

        return new_nodes, new_conns

    def choose_node_key(
        self,
        key,
        nodes,
        input_idx,
        output_idx,
        allow_input_keys: bool = False,
        allow_output_keys: bool = False,
    ):
        """
        Randomly choose an existing node.

        Input nodes are excluded unless ``allow_input_keys`` is True, and
        output nodes are excluded unless ``allow_output_keys`` is True.

        :param key: JAX PRNG key.
        :param nodes: node array; a NaN in column 0 marks an empty slot.
        :param input_idx: keys of the input nodes.
        :param output_idx: keys of the output nodes.
        :param allow_input_keys: whether input nodes may be chosen.
        :param allow_output_keys: whether output nodes may be chosen.
        :return: ``(node_key, idx)``; when no node qualifies, ``idx`` is
            ``I_INF`` and ``node_key`` is NaN.
        """
        node_keys = nodes[:, 0]
        mask = ~jnp.isnan(node_keys)

        if not allow_input_keys:
            mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_idx))

        if not allow_output_keys:
            mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_idx))

        idx = fetch_random(key, mask)
        node_key = jnp.where(idx != I_INF, nodes[idx, 0], jnp.nan)
        return node_key, idx

    def choose_connection_key(self, key, conns):
        """
        Randomly choose an existing connection.

        :param key: JAX PRNG key.
        :param conns: connection array; a NaN in column 0 marks an empty slot.
        :return: ``(i_key, o_key, idx)``; when there is no connection, ``idx``
            is ``I_INF`` and both keys are NaN.
        """
        idx = fetch_random(key, ~jnp.isnan(conns[:, 0]))
        i_key = jnp.where(idx != I_INF, conns[idx, 0], jnp.nan)
        o_key = jnp.where(idx != I_INF, conns[idx, 1], jnp.nan)

        return i_key, o_key, idx