File tree Expand file tree Collapse file tree 1 file changed +17
-4
lines changed Expand file tree Collapse file tree 1 file changed +17
-4
lines changed Original file line number Diff line number Diff line change @@ -111,11 +111,24 @@ def assert_numpy_allclose(
111111 )
112112
113113 if not (is_allclose ):
114- diff = np .abs (
115- np .asarray (output_parameters [i ]) - np .asarray (output_references [i ])
116- )
114+ output_parameters_i_arr = np .asarray (output_parameters [i ])
115+ output_references_i_arr = np .asarray (output_references [i ])
116+
117+ diff = np .abs (output_parameters_i_arr - output_references_i_arr )
117118 abs_diff = np .sum (diff )
118- rel_diff = np .sum (diff / np .abs (output_references [i ]))
119+ rel_diff_dividend = np .max (
120+ np .vstack (
121+ (
122+ np .abs (output_parameters_i_arr ),
123+ np .abs (output_references_i_arr ),
124+ )
125+ ),
126+ axis = 0 ,
127+ )
128+ # when both are zero the diff is also zero, so we set it to 1
129+ # so no division by zero error is raised
130+ rel_diff_dividend [rel_diff_dividend == 0.0 ] = 1.0
131+ rel_diff = np .sum (diff / rel_diff_dividend )
119132
120133 message = (
121134 f"Output is not close to reference absolute difference "
You can’t perform that action at this time.
0 commit comments