diff --git a/src/tensorneat/pipeline.py b/src/tensorneat/pipeline.py index 1e00729..86ee90f 100644 --- a/src/tensorneat/pipeline.py +++ b/src/tensorneat/pipeline.py @@ -106,7 +106,7 @@ class Pipeline(StatefulBaseClass): if self.show_problem_details: self.compiled_pop_transform_func = ( jax.jit(jax.vmap(self.algorithm.transform, in_axes=(None, 0))) - .lower(self.algorithm.ask(state)) + .lower(state, self.algorithm.ask(state)) .compile() )