61 lines
1.8 KiB
Python
61 lines
1.8 KiB
Python
import jax, jax.numpy as jnp
|
|
from utils import unflatten_conns
|
|
|
|
from . import BaseGenome
|
|
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
|
|
|
|
|
class RecurrentGenome(BaseGenome):
    """Genome evaluated as a recurrent network.

    ``forward`` does not follow any topological order. At every step it applies
    the full node-to-node connection matrix to all nodes at once, and it repeats
    this synchronous update ``activate_time`` times. Connection graphs that
    contain cycles can therefore be evaluated.
    """

    network_type = 'recurrent'

    def __init__(self,
                 num_inputs: int,
                 num_outputs: int,
                 max_nodes: int,
                 max_conns: int,
                 node_gene: "BaseNodeGene | None" = None,
                 conn_gene: "BaseConnGene | None" = None,
                 activate_time: int = 10,
                 ):
        """
        :param num_inputs: passed through to ``BaseGenome.__init__``.
        :param num_outputs: passed through to ``BaseGenome.__init__``.
        :param max_nodes: passed through to ``BaseGenome.__init__``.
        :param max_conns: passed through to ``BaseGenome.__init__``.
        :param node_gene: node gene implementation; a fresh
            ``DefaultNodeGene()`` is created when omitted.
        :param conn_gene: connection gene implementation; a fresh
            ``DefaultConnGene()`` is created when omitted.
        :param activate_time: number of synchronous update steps run by
            ``forward``.
        """
        # Build the default genes per instance. A default-argument expression
        # is evaluated only once, at definition time, so every genome created
        # without explicit genes would otherwise share the same gene objects.
        if node_gene is None:
            node_gene = DefaultNodeGene()
        if conn_gene is None:
            conn_gene = DefaultConnGene()

        super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
        self.activate_time = activate_time

    def transform(self, nodes, conns):
        """Convert the genome encoding into the dense form used by ``forward``.

        :param nodes: node array, returned unchanged.
        :param conns: connection array, expanded by ``unflatten_conns``.
        :return: ``(nodes, u_conns)``. ``u_conns`` holds the connection
            attributes that follow the enable flag, with NaN wherever the
            connection is not enabled.
        """
        # NOTE(review): assumes unflatten_conns returns a dense
        # (num_attrs, N, N) array whose first attribute row is the enable
        # flag -- confirm against utils.unflatten_conns.
        u_conns = unflatten_conns(nodes, conns)

        # Mask out disabled connections (enable flag != 1) by setting them to
        # NaN, and drop the enable row itself.
        conn_enabled = u_conns[0] == 1
        u_conns = jnp.where(conn_enabled, u_conns[1:, :], jnp.nan)

        return nodes, u_conns

    def forward(self, inputs, transformed):
        """Run the recurrent network on one input vector.

        :param inputs: values written to the input nodes (``self.input_idx``)
            at the start of every step.
        :param transformed: the ``(nodes, conns)`` pair returned by
            ``transform``.
        :return: values of the output nodes (``self.output_idx``) after
            ``self.activate_time`` steps.
        """
        nodes, conns = transformed

        num_nodes = nodes.shape[0]
        # All node values start as NaN before the first step.
        vals = jnp.full((num_nodes,), jnp.nan)
        # Drop column 0 of each node row (presumably the node key -- confirm);
        # node_gene.forward only receives the remaining attributes.
        nodes_attrs = nodes[:, 1:]

        # node_ins[i, j] = conn_gene.forward(conns[:, i, j], values[i]).
        # The outer vmap walks source nodes i (conns axis 1, values axis 0);
        # the inner vmap walks destination nodes j.
        conns_forward = jax.vmap(
            jax.vmap(self.conn_gene.forward, in_axes=(1, None)),
            in_axes=(1, 0),
        )
        nodes_forward = jax.vmap(self.node_gene.forward)

        def body_func(_, values):
            # Re-clamp the input nodes on every step. The node update of the
            # previous step overwrote every entry, including the inputs.
            values = values.at[self.input_idx].set(inputs)

            node_ins = conns_forward(conns, values)

            # Row j of node_ins.T gathers the contributions of every node to
            # node j. All nodes are updated synchronously from the previous
            # step's values. How NaN (disabled) entries are aggregated is up
            # to node_gene.forward -- not visible here.
            return nodes_forward(nodes_attrs, node_ins.T)

        vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)

        return vals[self.output_idx]
|