Skip to content

Commit 5ef2a2f

Browse files
committed
BUG: fix sign(0+0j) to be zero, not nan
closes gh-166
1 parent bfc524c commit 5ef2a2f

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

array_api_strict/_elementwise_functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,5 +352,7 @@ def sign(x: Array, /) -> Array:
352352
raise TypeError("Only numeric dtypes are allowed in sign")
353353
# Special treatment to work around non-compliant NumPy 1.x behaviour
354354
if x.dtype in _complex_floating_dtypes:
355-
return x/abs(x)
355+
_x = x._array
356+
_result = _x / np.abs(np.where(_x != 0, _x, np.asarray(1.0, dtype=_x.dtype)))
357+
return Array._new(_result, device=x.device)
356358
return Array._new(np.sign(x._array), device=x.device)

0 commit comments

Comments
 (0)