@@ -815,23 +815,6 @@ def is_writeable_array(x):
815815 return True
816816
817817
818- def _parse_copy_param (x , copy : bool | None | Literal ["_force_false" ]) -> bool :
819- """Preprocess and validate a copy parameter, in line with the same
820- parameter in np.asarray(), np.astype(), etc.
821- """
822- if copy is True :
823- return True
824- elif copy is False :
825- if not is_writeable_array (x ):
826- raise ValueError ("Cannot avoid modifying parameter in place" )
827- return False
828- elif copy is None :
829- return not is_writeable_array (x )
830- elif copy == "_force_false" :
831- return False
832- raise ValueError (f"Invalid value for copy: { copy !r} " )
833-
834-
835818_undef = object ()
836819
837820
@@ -947,7 +930,15 @@ def _common(
947930 "(same for all other methods)."
948931 )
949932
950- copy = _parse_copy_param (self .x , copy )
933+ if copy is False :
934+ if not is_writeable_array (self .x ):
935+ raise ValueError ("Cannot avoid modifying parameter in place" )
936+ elif copy is None :
937+ copy = not is_writeable_array (self .x )
938+ elif copy == "_force_false" :
939+ copy = False
940+ elif copy is not True :
941+ raise ValueError (f"Invalid value for copy: { copy !r} " )
951942
952943 if copy and is_jax_array (self .x ):
953944 # Use JAX's at[]
@@ -956,6 +947,9 @@ def _common(
956947 return getattr (at_ , at_op )(* args , ** kwargs ), None
957948
958949 # Emulate at[] behaviour for non-JAX arrays
950+ # FIXME We blindly expect the output of x.copy() to be always writeable.
951+ # This holds true for read-only numpy arrays, but not necessarily for
952+ # other backends.
959953 x = self .x .copy () if copy else self .x
960954 return None , x
961955
@@ -1047,35 +1041,6 @@ def max(self, y, /, **kwargs):
10471041 return self ._iop ("max" , xp .maximum , y , ** kwargs )
10481042
10491043
1050- def where (condition , x = None , y = None , / , copy : bool | None = True ):
1051- """Return elements from x when condition is True and from y when
1052- it is False.
1053-
1054- This is a wrapper around xp.where that adds the copy parameter:
1055-
1056- None
1057- x *may* be modified in place if it is possible and beneficial
1058- for performance. You should not use x after calling this function.
1059- True
1060- Ensure that the inputs are not modified.
1061- This is the default, in line with np.where.
1062- False
1063- Raise ValueError if a copy cannot be avoided.
1064- """
1065- if x is None and y is None :
1066- xp = array_namespace (condition , use_compat = False )
1067- return xp .where (condition )
1068-
1069- copy = _parse_copy_param (x , copy )
1070- xp = array_namespace (condition , x , y , use_compat = False )
1071- if copy :
1072- return xp .where (condition , x , y )
1073- else :
1074- condition , x , y = xp .broadcast_arrays (condition , x , y )
1075- x [condition ] = y [condition ]
1076- return x
1077-
1078-
10791044__all__ = [
10801045 "array_namespace" ,
10811046 "device" ,
@@ -1100,7 +1065,6 @@ def where(condition, x=None, y=None, /, copy: bool | None = True):
11001065 "size" ,
11011066 "to_device" ,
11021067 "at" ,
1103- "where" ,
11041068]
11051069
11061070_all_ignore = ['inspect' , 'math' , 'operator' , 'warnings' , 'sys' ]
0 commit comments