from typing import Callable, Optional

import jax
import jax.numpy as jnp

from utils import (
    unflatten_conns,
    topological_sort,
    I_INF,
    extract_node_attrs,
    extract_conn_attrs,
    set_node_attrs,
    set_conn_attrs,
    attach_with_inf,
)
from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover


class DefaultGenome(BaseGenome):
    """Default genome class, with the same behavior as the NEAT-Python.

    Networks are evaluated as feedforward graphs: each node is visited once,
    in the order produced by ``topological_sort`` in :meth:`transform`.
    """

    network_type = "feedforward"

    def __init__(
        self,
        num_inputs: int,
        num_outputs: int,
        max_nodes=5,
        max_conns=4,
        node_gene: Optional[BaseNodeGene] = None,
        conn_gene: Optional[BaseConnGene] = None,
        mutation: Optional[BaseMutation] = None,
        crossover: Optional[BaseCrossover] = None,
        output_transform: Optional[Callable] = None,
    ):
        """Create a feedforward genome.

        Args:
            num_inputs: number of input nodes.
            num_outputs: number of output nodes.
            max_nodes: capacity of the node table.
            max_conns: capacity of the connection table.
            node_gene: node gene implementation; a fresh ``DefaultNodeGene``
                is created when omitted.
            conn_gene: connection gene implementation; a fresh
                ``DefaultConnGene`` is created when omitted.
            mutation: mutation operator; a fresh ``DefaultMutation`` is
                created when omitted.
            crossover: crossover operator; a fresh ``DefaultCrossover`` is
                created when omitted.
            output_transform: optional callable applied to the output-node
                values of :meth:`forward` (and per sample in
                :meth:`update_by_batch`).

        Raises:
            ValueError: if ``output_transform`` raises when called on a
                zero vector of length ``num_outputs``.
        """
        # Instantiate defaults here (not in the signature) so that genomes
        # never share one gene/operator object through a default argument.
        if node_gene is None:
            node_gene = DefaultNodeGene()
        if conn_gene is None:
            conn_gene = DefaultConnGene()
        if mutation is None:
            mutation = DefaultMutation()
        if crossover is None:
            crossover = DefaultCrossover()

        super().__init__(
            num_inputs,
            num_outputs,
            max_nodes,
            max_conns,
            node_gene,
            conn_gene,
            mutation,
            crossover,
        )

        if output_transform is not None:
            # Fail fast with a clear error instead of deep inside a jitted call.
            try:
                _ = output_transform(jnp.zeros(num_outputs))
            except Exception as e:
                raise ValueError(f"Output transform function failed: {e}") from e

        self.output_transform = output_transform

    def transform(self, state, nodes, conns):
        """Precompute the data needed to evaluate this genome.

        Returns:
            ``(seqs, nodes, conns, u_conns)`` where ``u_conns`` is the
            unflattened connection-index matrix (``I_INF`` marks a missing
            connection) and ``seqs`` is the node-visit order returned by
            ``topological_sort``; evaluation stops at its first ``I_INF``.
        """
        u_conns = unflatten_conns(nodes, conns)
        conn_exist = u_conns != I_INF
        seqs = topological_sort(nodes, conn_exist)
        return seqs, nodes, conns, u_conns

    def restore(self, state, transformed):
        """Return ``(nodes, conns)`` from the tuple built by :meth:`transform`."""
        _, nodes, conns, _ = transformed
        return nodes, conns

    def forward(self, state, transformed, inputs):
        """Evaluate the network on one input vector.

        Args:
            state: state object forwarded unchanged to the gene callbacks.
            transformed: tuple returned by :meth:`transform`.
            inputs: values written to the nodes at ``self.input_idx``.

        Returns:
            The values of the nodes at ``self.output_idx``, passed through
            ``self.output_transform`` when it is set.
        """
        cal_seqs, nodes, conns, u_conns = transformed

        # Every node starts as NaN; only input nodes are seeded.
        ini_vals = jnp.full((self.max_nodes,), jnp.nan)
        ini_vals = ini_vals.at[self.input_idx].set(inputs)
        nodes_attrs = jax.vmap(extract_node_attrs)(nodes)
        conns_attrs = jax.vmap(extract_conn_attrs)(conns)

        def cond_fun(carry):
            idx = carry[1]
            # `&` does not short-circuit: at idx == max_nodes the gather is
            # still evaluated, but JAX clamps out-of-bounds gather indices and
            # the first operand is already False.
            return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF)

        def body_func(carry):
            values, idx = carry
            i = cal_seqs[idx]

            def input_node():
                z = self.node_gene.input_transform(state, nodes_attrs[i], values[i])
                return values.at[i].set(z)

            def otherwise():
                # Column i of u_conns is used as the connection indices feeding
                # node i, aligned element-wise with `values`; I_INF entries
                # denote absent connections.
                conn_indices = u_conns[:, i]
                hit_attrs = attach_with_inf(conns_attrs, conn_indices)
                ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
                    state, hit_attrs, values
                )
                z = self.node_gene.forward(
                    state,
                    nodes_attrs[i],
                    ins,
                    is_output_node=jnp.isin(i, self.output_idx),
                )
                return values.at[i].set(z)

            values = jax.lax.cond(jnp.isin(i, self.input_idx), input_node, otherwise)
            return values, idx + 1

        vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))

        outputs = vals[self.output_idx]
        if self.output_transform is None:
            return outputs
        return self.output_transform(outputs)

    def update_by_batch(self, state, batch_input, transformed):
        """Evaluate a batch of inputs and let genes update their attributes.

        Args:
            state: state object forwarded unchanged to the gene callbacks.
            batch_input: array whose first axis is the batch; each row is
                written to the nodes at ``self.input_idx``.
            transformed: tuple returned by :meth:`transform`.

        Returns:
            ``(outputs, new_transformed)`` where ``outputs`` holds the
            output-node values per sample (each row passed through
            ``self.output_transform`` when set), and ``new_transformed`` is
            ``(cal_seqs, nodes, conns, u_conns)`` with node/connection
            attributes replaced by the values returned from the genes'
            ``update_*`` callbacks.
        """
        cal_seqs, nodes, conns, u_conns = transformed

        batch_size = batch_input.shape[0]
        batch_ini_vals = jnp.full((batch_size, self.max_nodes), jnp.nan)
        batch_ini_vals = batch_ini_vals.at[:, self.input_idx].set(batch_input)
        nodes_attrs = jax.vmap(extract_node_attrs)(nodes)
        conns_attrs = jax.vmap(extract_conn_attrs)(conns)

        def cond_fun(carry):
            idx = carry[3]
            # See forward(): the out-of-range gather at idx == max_nodes is
            # clamped by JAX and masked by the first operand.
            return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF)

        def body_func(carry):
            batch_values, nodes_attrs_, conns_attrs_, idx = carry
            i = cal_seqs[idx]

            def input_node():
                batch, new_attrs = self.node_gene.update_input_transform(
                    state, nodes_attrs_[i], batch_values[:, i]
                )
                return (
                    batch_values.at[:, i].set(batch),
                    nodes_attrs_.at[i].set(new_attrs),
                    conns_attrs_,
                )

            def otherwise():
                conn_indices = u_conns[:, i]
                # Read from the loop-carried attributes, like nodes_attrs_.
                hit_attrs = attach_with_inf(conns_attrs_, conn_indices)
                # vmap over the node axis (axis 1) of batch_values so each
                # incoming connection sees its source node's batch column.
                batch_ins, new_conn_attrs = jax.vmap(
                    self.conn_gene.update_by_batch,
                    in_axes=(None, 0, 1),
                    out_axes=(1, 0),
                )(state, hit_attrs, batch_values)
                batch_z, new_node_attrs = self.node_gene.update_by_batch(
                    state,
                    nodes_attrs_[i],
                    batch_ins,
                    is_output_node=jnp.isin(i, self.output_idx),
                )
                return (
                    batch_values.at[:, i].set(batch_z),
                    nodes_attrs_.at[i].set(new_node_attrs),
                    conns_attrs_.at[conn_indices].set(new_conn_attrs),
                )

            # the val of input nodes is obtained by the task, not by calculation
            batch_values, nodes_attrs_, conns_attrs_ = jax.lax.cond(
                jnp.isin(i, self.input_idx),
                input_node,
                otherwise,
            )
            return batch_values, nodes_attrs_, conns_attrs_, idx + 1

        batch_vals, nodes_attrs, conns_attrs, _ = jax.lax.while_loop(
            cond_fun, body_func, (batch_ini_vals, nodes_attrs, conns_attrs, 0)
        )

        nodes = jax.vmap(set_node_attrs)(nodes, nodes_attrs)
        conns = jax.vmap(set_conn_attrs)(conns, conns_attrs)
        new_transformed = (cal_seqs, nodes, conns, u_conns)

        batch_outputs = batch_vals[:, self.output_idx]
        if self.output_transform is None:
            return batch_outputs, new_transformed
        return jax.vmap(self.output_transform)(batch_outputs), new_transformed