fix bug in pipeline.py for show_problem_details

https://github.com/EMI-Group/tensorneat/issues/15
This commit is contained in:
wls2002
2025-02-14 12:27:32 +08:00
parent ede30b424c
commit f17f31bb2a

View File

@@ -106,7 +106,7 @@ class Pipeline(StatefulBaseClass):
if self.show_problem_details: if self.show_problem_details:
self.compiled_pop_transform_func = ( self.compiled_pop_transform_func = (
jax.jit(jax.vmap(self.algorithm.transform, in_axes=(None, 0))) jax.jit(jax.vmap(self.algorithm.transform, in_axes=(None, 0)))
.lower(self.algorithm.ask(state)) .lower(state, self.algorithm.ask(state))
.compile() .compile()
) )