From f17f31bb2a52572d2b3d424c1466e6ae3e6dbf1d Mon Sep 17 00:00:00 2001 From: wls2002 Date: Fri, 14 Feb 2025 12:27:32 +0800 Subject: [PATCH] fix bug in pipeline.py for show_problem_details https://github.com/EMI-Group/tensorneat/issues/15 --- src/tensorneat/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensorneat/pipeline.py b/src/tensorneat/pipeline.py index 1e00729..86ee90f 100644 --- a/src/tensorneat/pipeline.py +++ b/src/tensorneat/pipeline.py @@ -106,7 +106,7 @@ class Pipeline(StatefulBaseClass): if self.show_problem_details: self.compiled_pop_transform_func = ( jax.jit(jax.vmap(self.algorithm.transform, in_axes=(None, 0))) - .lower(self.algorithm.ask(state)) + .lower(state, self.algorithm.ask(state)) .compile() )