add input_transform and update_input_transform;
change the args for genome.forward. Origin: (state, inputs, transformed) New: (state, transformed, inputs)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import unflatten_conns, flatten_conns
|
||||
from utils import unflatten_conns
|
||||
|
||||
from . import BaseGenome
|
||||
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
||||
@@ -47,11 +47,10 @@ class RecurrentGenome(BaseGenome):
|
||||
|
||||
def transform(self, state, nodes, conns):
|
||||
u_conns = unflatten_conns(nodes, conns)
|
||||
return nodes, u_conns
|
||||
return nodes, conns, u_conns
|
||||
|
||||
def restore(self, state, transformed):
|
||||
nodes, u_conns = transformed
|
||||
conns = flatten_conns(nodes, u_conns, C=self.max_conns)
|
||||
nodes, conns, u_conns = transformed
|
||||
return nodes, conns
|
||||
|
||||
def forward(self, state, inputs, transformed):
|
||||
|
||||
Reference in New Issue
Block a user