add input_transform and update_input_transform;

change the args for genome.forward.
Origin: (state, inputs, transformed)
New: (state, transformed, inputs)
This commit is contained in:
wls2002
2024-06-03 10:53:15 +08:00
parent a07a3b1cb2
commit edfb0596e7
16 changed files with 185 additions and 221 deletions

View File

@@ -1,7 +1,7 @@
from typing import Callable
import jax, jax.numpy as jnp
from utils import unflatten_conns, flatten_conns
from utils import unflatten_conns
from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
@@ -47,11 +47,10 @@ class RecurrentGenome(BaseGenome):
def transform(self, state, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
return nodes, u_conns
return nodes, conns, u_conns
def restore(self, state, transformed):
nodes, u_conns = transformed
conns = flatten_conns(nodes, u_conns, C=self.max_conns)
nodes, conns, u_conns = transformed
return nodes, conns
def forward(self, state, inputs, transformed):