diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index d86bc27..6c87b3f 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -352,5 +352,7 @@ def sign(x: Array, /) -> Array: raise TypeError("Only numeric dtypes are allowed in sign") # Special treatment to work around non-compliant NumPy 1.x behaviour if x.dtype in _complex_floating_dtypes: - return x/abs(x) + _x = x._array + _result = _x / np.abs(np.where(_x != 0, _x, np.asarray(1.0, dtype=_x.dtype))) + return Array._new(_result, device=x.device) return Array._new(np.sign(x._array), device=x.device)