@@ -152,7 +152,13 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None
152152 # spec in places where it either deviates from or is more strict than
153153 # NumPy behavior
154154
155- def _check_allowed_dtypes (self , other : bool | int | float | Array , dtype_category : str , op : str ) -> Array :
155+ def _check_allowed_dtypes (
156+ self ,
157+ other : bool | int | float | Array ,
158+ dtype_category : str ,
159+ op : str ,
160+ check_promotion : bool = True ,
161+ ) -> Array :
156162 """
157163 Helper function for operators to only allow specific input dtypes
158164
@@ -176,7 +182,8 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor
176182 # This will raise TypeError for type combinations that are not allowed
177183 # to promote in the spec (even if the NumPy array operator would
178184 # promote them).
179- res_dtype = _result_type (self .dtype , other .dtype )
185+ if check_promotion :
186+ res_dtype = _result_type (self .dtype , other .dtype )
180187 if op .startswith ("__i" ):
181188 # Note: NumPy will allow in-place operators in some cases where
182189 # the type promoted operator does not match the left-hand side
@@ -604,7 +611,7 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array:
604611 """
605612 Performs the operation __ge__.
606613 """
607- other = self ._check_allowed_dtypes (other , "real numeric" , "__ge__" )
614+ other = self ._check_allowed_dtypes (other , "real numeric" , "__ge__" , check_promotion = False )
608615 if other is NotImplemented :
609616 return other
610617 self , other = self ._normalize_two_args (self , other )
@@ -638,7 +645,7 @@ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array:
638645 """
639646 Performs the operation __gt__.
640647 """
641- other = self ._check_allowed_dtypes (other , "real numeric" , "__gt__" )
648+ other = self ._check_allowed_dtypes (other , "real numeric" , "__gt__" , check_promotion = False )
642649 if other is NotImplemented :
643650 return other
644651 self , other = self ._normalize_two_args (self , other )
@@ -692,7 +699,7 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
692699 """
693700 Performs the operation __le__.
694701 """
695- other = self ._check_allowed_dtypes (other , "real numeric" , "__le__" )
702+ other = self ._check_allowed_dtypes (other , "real numeric" , "__le__" , check_promotion = False )
696703 if other is NotImplemented :
697704 return other
698705 self , other = self ._normalize_two_args (self , other )
@@ -714,7 +721,7 @@ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array:
714721 """
715722 Performs the operation __lt__.
716723 """
717- other = self ._check_allowed_dtypes (other , "real numeric" , "__lt__" )
724+ other = self ._check_allowed_dtypes (other , "real numeric" , "__lt__" , check_promotion = False )
718725 if other is NotImplemented :
719726 return other
720727 self , other = self ._normalize_two_args (self , other )
0 commit comments