@@ -482,11 +482,13 @@ def test_tanh_saturation_parameterization_transformation(self, x, b, c):
482482 y2 = tanh_saturation_baselined (x , * param_x0 ).eval ()
483483 y3 = tanh_saturation_baselined (x , * param_x1 ).eval ()
484484 y4 = tanh_saturation (x , * param_classic1 ).eval ()
485- np .testing .assert_allclose (y1 , y2 )
486- np .testing .assert_allclose (y2 , y3 )
487- np .testing .assert_allclose (y3 , y4 )
488- np .testing .assert_allclose (param_classic1 .b .eval (), b )
489- np .testing .assert_allclose (param_classic1 .c .eval (), c , rtol = 1e-06 )
485+ # Use consistent tolerances for all comparisons to account for
486+ # accumulated floating-point errors in round-trip transformations
487+ np .testing .assert_allclose (y1 , y2 , rtol = 1e-6 )
488+ np .testing .assert_allclose (y2 , y3 , rtol = 1e-6 )
489+ np .testing .assert_allclose (y3 , y4 , rtol = 1e-6 )
490+ np .testing .assert_allclose (param_classic1 .b .eval (), b , rtol = 1e-6 )
491+ np .testing .assert_allclose (param_classic1 .c .eval (), c , rtol = 1e-6 )
490492
491493 @pytest .mark .parametrize (
492494 "x, alpha, lam, expected" ,
0 commit comments