from typing import Callable

import jax, jax.numpy as jnp
from utils import unflatten_conns, topological_sort, I_INF

from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover


class DefaultGenome(BaseGenome):
    """Default genome class, with the same behavior as the NEAT-Python.

    The genome is evaluated as a feedforward network: ``transform`` derives a
    node evaluation order from the enabled connections, and ``forward`` walks
    that order to compute node values.
    """

    network_type = "feedforward"

    def __init__(
        self,
        num_inputs: int,
        num_outputs: int,
        max_nodes=5,
        max_conns=4,
        # NOTE: these default gene/operator objects are created once, when the
        # function is defined, and are therefore shared by every genome that
        # relies on the defaults.
        node_gene: BaseNodeGene = DefaultNodeGene(),
        conn_gene: BaseConnGene = DefaultConnGene(),
        mutation: BaseMutation = DefaultMutation(),
        crossover: BaseCrossover = DefaultCrossover(),
        output_transform: Callable = None,
    ):
        """Build the genome and validate the optional output transform.

        Args:
            num_inputs: Number of input nodes (forwarded to ``BaseGenome``).
            num_outputs: Number of output nodes (forwarded to ``BaseGenome``).
            max_nodes: Node capacity (forwarded to ``BaseGenome``).
            max_conns: Connection capacity (forwarded to ``BaseGenome``).
            node_gene: Node gene implementation; its ``forward`` is used to
                compute node values.
            conn_gene: Connection gene implementation; its ``forward`` is used
                to compute per-connection inputs.
            mutation: Mutation operator (forwarded to ``BaseGenome``).
            crossover: Crossover operator (forwarded to ``BaseGenome``).
            output_transform: Optional callable applied to the output values
                returned by ``forward``. ``None`` disables the transform.

        Raises:
            ValueError: If ``output_transform`` raises when called on a zero
                vector of length ``num_outputs``. The original exception is
                attached as ``__cause__``.
        """
        super().__init__(
            num_inputs,
            num_outputs,
            max_nodes,
            max_conns,
            node_gene,
            conn_gene,
            mutation,
            crossover,
        )

        if output_transform is not None:
            # Fail fast: probe the transform once with a dummy output vector so
            # a broken callable is reported at construction time rather than
            # in the middle of an evaluation.
            try:
                _ = output_transform(jnp.zeros(num_outputs))
            except Exception as e:
                # Chain explicitly so the original traceback is preserved as
                # the direct cause of the ValueError.
                raise ValueError(f"Output transform function failed: {e}") from e
        self.output_transform = output_transform

    def transform(self, state, nodes, conns):
        """Pre-process a genome into the form consumed by ``forward``.

        Args:
            state: Unused here; kept for interface compatibility.
            nodes: Node array of the genome.
            conns: Connection array of the genome.

        Returns:
            A tuple ``(seqs, nodes, u_conns)`` where ``seqs`` is the result of
            ``topological_sort`` over the enabled connections, ``nodes`` is
            returned unchanged, and ``u_conns`` is the unflattened connection
            array without its first (enable) slice, with entries of disabled
            connections replaced by NaN.
        """
        u_conns = unflatten_conns(nodes, conns)
        # Slice 0 of the unflattened array holds the enable flag.
        conn_enable = u_conns[0] == 1

        # remove enable attr
        u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
        seqs = topological_sort(nodes, conn_enable)

        return seqs, nodes, u_conns

    def forward(self, state, inputs, transformed):
        """Evaluate the network on ``inputs``.

        Args:
            state: Passed through to ``conn_gene.forward`` and
                ``node_gene.forward``.
            inputs: Values written to the input node slots.
            transformed: The ``(seqs, nodes, u_conns)`` tuple produced by
                ``transform``.

        Returns:
            The values at ``self.output_idx``, passed through
            ``self.output_transform`` when one was supplied.
        """
        cal_seqs, nodes, conns = transformed

        N = nodes.shape[0]
        # Every node value starts as NaN; only the input slots are seeded.
        # NOTE(review): self.input_idx / self.output_idx are not defined in
        # this class -- presumably set by BaseGenome.__init__; confirm.
        ini_vals = jnp.full((N,), jnp.nan)
        ini_vals = ini_vals.at[self.input_idx].set(inputs)
        # Drop the first node column; the remaining columns are handed to the
        # node gene as its attributes.
        nodes_attrs = nodes[:, 1:]

        def cond_fun(carry):
            # Continue while inside the sequence and before the I_INF
            # sentinel that marks the end of the valid evaluation order.
            values, idx = carry
            return (idx < N) & (cal_seqs[idx] != I_INF)

        def body_func(carry):
            values, idx = carry
            i = cal_seqs[idx]

            def hit():
                # Apply the connection gene across axis 1 of conns[:, :, i],
                # pairing each column with the matching entry of `values`.
                ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))(
                    state, conns[:, :, i], values
                )
                z = self.node_gene.forward(
                    state,
                    nodes_attrs[i],
                    ins,
                    is_output_node=jnp.isin(i, self.output_idx),
                )

                new_values = values.at[i].set(z)
                return new_values

            # the val of input nodes is obtained by the task, not by calculation
            values = jax.lax.cond(jnp.isin(i, self.input_idx), lambda: values, hit)

            return values, idx + 1

        vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))

        if self.output_transform is None:
            return vals[self.output_idx]
        else:
            return self.output_transform(vals[self.output_idx])