remove create_func....

This commit is contained in:
wls2002
2023-08-02 15:02:08 +08:00
parent 1499e062fe
commit c7fb1ddabe
22 changed files with 425 additions and 21 deletions

View File

@@ -0,0 +1,57 @@
from dataclasses import dataclass
import jax
from jax import numpy as jnp, vmap
from .normal import NormalGene, NormalGeneConfig
from core import State, Genome
from utils import unflatten_conns, act, agg
@dataclass(frozen=True)
class RecurrentGeneConfig(NormalGeneConfig):
activate_times: int = 10
def __post_init__(self):
super().__post_init__()
assert self.activate_times > 0
class RecurrentGene(NormalGene):
def __init__(self, config: RecurrentGeneConfig):
self.config = config
super().__init__(config)
def forward_transform(self, state: State, genome: Genome):
u_conns = unflatten_conns(genome.nodes, genome.conns)
# remove un-enable connections and remove enable attr
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
return genome.nodes, u_conns
def forward(self, state: State, inputs, transformed):
nodes, conns = transformed
batch_act, batch_agg = vmap(act, in_axes=(0, 0, None)), vmap(agg, in_axes=(0, 0, None))
input_idx = state.input_idx
output_idx = state.output_idx
N = nodes.shape[0]
vals = jnp.full((N,), 0.)
weights = conns[0, :]
def body_func(i, values):
values = values.at[input_idx].set(inputs)
nodes_ins = values * weights.T
values = batch_agg(nodes[:, 4], nodes_ins, self.agg_funcs) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(nodes[:, 3], values, self.act_funcs) # z = act(z)
return values
vals = jax.lax.fori_loop(0, self.config.activate_times, body_func, vals)
return vals[output_idx]