@@ -859,12 +859,13 @@ def __rshift__(self: Array, other: Union[int, Array], /) -> Array:
859859 """
860860 Performs the operation __rshift__.
861861 """
862+ other = self ._check_device (other )
862863 other = self ._check_allowed_dtypes (other , "integer" , "__rshift__" )
863864 if other is NotImplemented :
864865 return other
865866 self , other = self ._normalize_two_args (self , other )
866867 res = self ._array .__rshift__ (other ._array )
867- return self .__class__ ._new (res )
868+ return self .__class__ ._new (res , device = self . device )
868869
869870 def __setitem__ (
870871 self ,
@@ -889,41 +890,45 @@ def __sub__(self: Array, other: Union[int, float, Array], /) -> Array:
889890 """
890891 Performs the operation __sub__.
891892 """
893+ other = self ._check_device (other )
892894 other = self ._check_allowed_dtypes (other , "numeric" , "__sub__" )
893895 if other is NotImplemented :
894896 return other
895897 self , other = self ._normalize_two_args (self , other )
896898 res = self ._array .__sub__ (other ._array )
897- return self .__class__ ._new (res )
899+ return self .__class__ ._new (res , device = self . device )
898900
899901 # PEP 484 requires int to be a subtype of float, but __truediv__ should
900902 # not accept int.
901903 def __truediv__ (self : Array , other : Union [float , Array ], / ) -> Array :
902904 """
903905 Performs the operation __truediv__.
904906 """
907+ other = self ._check_device (other )
905908 other = self ._check_allowed_dtypes (other , "floating-point" , "__truediv__" )
906909 if other is NotImplemented :
907910 return other
908911 self , other = self ._normalize_two_args (self , other )
909912 res = self ._array .__truediv__ (other ._array )
910- return self .__class__ ._new (res )
913+ return self .__class__ ._new (res , device = self . device )
911914
912915 def __xor__ (self : Array , other : Union [int , bool , Array ], / ) -> Array :
913916 """
914917 Performs the operation __xor__.
915918 """
919+ other = self ._check_device (other )
916920 other = self ._check_allowed_dtypes (other , "integer or boolean" , "__xor__" )
917921 if other is NotImplemented :
918922 return other
919923 self , other = self ._normalize_two_args (self , other )
920924 res = self ._array .__xor__ (other ._array )
921- return self .__class__ ._new (res )
925+ return self .__class__ ._new (res , device = self . device )
922926
923927 def __iadd__ (self : Array , other : Union [int , float , Array ], / ) -> Array :
924928 """
925929 Performs the operation __iadd__.
926930 """
931+ other = self ._check_device (other )
927932 other = self ._check_allowed_dtypes (other , "numeric" , "__iadd__" )
928933 if other is NotImplemented :
929934 return other
@@ -934,17 +939,19 @@ def __radd__(self: Array, other: Union[int, float, Array], /) -> Array:
934939 """
935940 Performs the operation __radd__.
936941 """
942+ other = self ._check_device (other )
937943 other = self ._check_allowed_dtypes (other , "numeric" , "__radd__" )
938944 if other is NotImplemented :
939945 return other
940946 self , other = self ._normalize_two_args (self , other )
941947 res = self ._array .__radd__ (other ._array )
942- return self .__class__ ._new (res )
948+ return self .__class__ ._new (res , device = self . device )
943949
944950 def __iand__ (self : Array , other : Union [int , bool , Array ], / ) -> Array :
945951 """
946952 Performs the operation __iand__.
947953 """
954+ other = self ._check_device (other )
948955 other = self ._check_allowed_dtypes (other , "integer or boolean" , "__iand__" )
949956 if other is NotImplemented :
950957 return other
@@ -955,17 +962,19 @@ def __rand__(self: Array, other: Union[int, bool, Array], /) -> Array:
955962 """
956963 Performs the operation __rand__.
957964 """
965+ other = self ._check_device (other )
958966 other = self ._check_allowed_dtypes (other , "integer or boolean" , "__rand__" )
959967 if other is NotImplemented :
960968 return other
961969 self , other = self ._normalize_two_args (self , other )
962970 res = self ._array .__rand__ (other ._array )
963- return self .__class__ ._new (res )
971+ return self .__class__ ._new (res , device = self . device )
964972
965973 def __ifloordiv__ (self : Array , other : Union [int , float , Array ], / ) -> Array :
966974 """
967975 Performs the operation __ifloordiv__.
968976 """
977+ other = self ._check_device (other )
969978 other = self ._check_allowed_dtypes (other , "real numeric" , "__ifloordiv__" )
970979 if other is NotImplemented :
971980 return other
@@ -976,17 +985,19 @@ def __rfloordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
976985 """
977986 Performs the operation __rfloordiv__.
978987 """
988+ other = self ._check_device (other )
979989 other = self ._check_allowed_dtypes (other , "real numeric" , "__rfloordiv__" )
980990 if other is NotImplemented :
981991 return other
982992 self , other = self ._normalize_two_args (self , other )
983993 res = self ._array .__rfloordiv__ (other ._array )
984- return self .__class__ ._new (res )
994+ return self .__class__ ._new (res , device = self . device )
985995
986996 def __ilshift__ (self : Array , other : Union [int , Array ], / ) -> Array :
987997 """
988998 Performs the operation __ilshift__.
989999 """
1000+ other = self ._check_device (other )
9901001 other = self ._check_allowed_dtypes (other , "integer" , "__ilshift__" )
9911002 if other is NotImplemented :
9921003 return other
@@ -997,17 +1008,19 @@ def __rlshift__(self: Array, other: Union[int, Array], /) -> Array:
9971008 """
9981009 Performs the operation __rlshift__.
9991010 """
1011+ other = self ._check_device (other )
10001012 other = self ._check_allowed_dtypes (other , "integer" , "__rlshift__" )
10011013 if other is NotImplemented :
10021014 return other
10031015 self , other = self ._normalize_two_args (self , other )
10041016 res = self ._array .__rlshift__ (other ._array )
1005- return self .__class__ ._new (res )
1017+ return self .__class__ ._new (res , device = self . device )
10061018
10071019 def __imatmul__ (self : Array , other : Array , / ) -> Array :
10081020 """
10091021 Performs the operation __imatmul__.
10101022 """
1023+ other = self ._check_device (other )
10111024 # matmul is not defined for scalars, but without this, we may get
10121025 # the wrong error message from asarray.
10131026 other = self ._check_allowed_dtypes (other , "numeric" , "__imatmul__" )
0 commit comments