add input_transform and update_input_transform;

change the args for genome.forward.
Origin: (state, inputs, transformed)
New: (state, transformed, inputs)
This commit is contained in:
wls2002
2024-06-03 10:53:15 +08:00
parent a07a3b1cb2
commit edfb0596e7
16 changed files with 185 additions and 221 deletions

View File

@@ -54,8 +54,8 @@ class HyperNEAT(BaseAlgorithm):
def transform(self, state, individual):
transformed = self.neat.transform(state, individual)
query_res = jax.vmap(self.neat.forward, in_axes=(None, 0, None))(
state, self.substrate.query_coors, transformed
query_res = jax.vmap(self.neat.forward, in_axes=(None, None, 0))(
state, transformed, self.substrate.query_coors
)
# mute the connection with weight below threshold
query_res = jnp.where(