4040isdtype = get_xp (np )(_aliases .isdtype )
4141unstack = get_xp (da )(_aliases .unstack )
4242
43+ # da.astype doesn't respect copy=True
4344def astype (
4445 x : Array ,
4546 dtype : Dtype ,
4647 / ,
4748 * ,
4849 copy : bool = True ,
49- device : Device | None = None
50+ device : Optional [ Device ] = None
5051) -> Array :
52+ """
53+ Array API compatibility wrapper for astype().
54+
55+ See the corresponding documentation in the array library and/or the array API
56+ specification for more details.
57+ """
5158 # TODO: respect device keyword?
59+
5260 if not copy and dtype == x .dtype :
5361 return x
54- # dask astype doesn't respect copy=True,
55- # so call copy manually afterwards
5662 x = x .astype (dtype )
5763 return x .copy () if copy else x
5864
@@ -61,20 +67,24 @@ def astype(
6167# This arange func is modified from the common one to
6268# not pass stop/step as keyword arguments, which will cause
6369# an error with dask
64-
65- # TODO: delete the xp stuff, it shouldn't be necessary
66- def _dask_arange (
70+ def arange (
6771 start : Union [int , float ],
6872 / ,
6973 stop : Optional [Union [int , float ]] = None ,
7074 step : Union [int , float ] = 1 ,
7175 * ,
72- xp ,
7376 dtype : Optional [Dtype ] = None ,
7477 device : Optional [Device ] = None ,
7578 ** kwargs ,
7679) -> Array :
77- _check_device (xp , device )
80+ """
81+ Array API compatibility wrapper for arange().
82+
83+ See the corresponding documentation in the array library and/or the array API
84+ specification for more details.
85+ """
86+ # TODO: respect device keyword?
87+
7888 args = [start ]
7989 if stop is not None :
8090 args .append (stop )
@@ -83,13 +93,12 @@ def _dask_arange(
8393 # prepend the default value for start which is 0
8494 args .insert (0 , 0 )
8595 args .append (step )
86- return xp .arange (* args , dtype = dtype , ** kwargs )
8796
88- arange = get_xp ( da )( _dask_arange )
89- eye = get_xp ( da )( _aliases . eye )
97+ return da . arange ( * args , dtype = dtype , ** kwargs )
98+
9099
91- linspace = get_xp (da )(_aliases .linspace )
92100eye = get_xp (da )(_aliases .eye )
101+ linspace = get_xp (da )(_aliases .linspace )
93102UniqueAllResult = get_xp (da )(_aliases .UniqueAllResult )
94103UniqueCountsResult = get_xp (da )(_aliases .UniqueCountsResult )
95104UniqueInverseResult = get_xp (da )(_aliases .UniqueInverseResult )
@@ -112,7 +121,6 @@ def _dask_arange(
112121reshape = get_xp (da )(_aliases .reshape )
113122matrix_transpose = get_xp (da )(_aliases .matrix_transpose )
114123vecdot = get_xp (da )(_aliases .vecdot )
115-
116124nonzero = get_xp (da )(_aliases .nonzero )
117125ceil = get_xp (np )(_aliases .ceil )
118126floor = get_xp (np )(_aliases .floor )
@@ -121,6 +129,7 @@ def _dask_arange(
121129tensordot = get_xp (np )(_aliases .tensordot )
122130sign = get_xp (np )(_aliases .sign )
123131
132+
124133# asarray also adds the copy keyword, which is not present in numpy 1.0.
125134def asarray (
126135 obj : Union [
@@ -135,7 +144,7 @@ def asarray(
135144 * ,
136145 dtype : Optional [Dtype ] = None ,
137146 device : Optional [Device ] = None ,
138- copy : " Optional[Union[bool, np._CopyMode]]" = None ,
147+ copy : Optional [Union [bool , np ._CopyMode ]] = None ,
139148 ** kwargs ,
140149) -> Array :
141150 """
@@ -144,6 +153,8 @@ def asarray(
144153 See the corresponding documentation in the array library and/or the array API
145154 specification for more details.
146155 """
156+ # TODO: respect device keyword?
157+
147158 if isinstance (obj , da .Array ):
148159 if dtype is not None and dtype != obj .dtype :
149160 if copy is False :
@@ -183,15 +194,18 @@ def asarray(
183194# Furthermore, the masking workaround in common._aliases.clip cannot work with
184195# dask (meaning uint64 promoting to float64 is going to just be unfixed for
185196# now).
186- @get_xp (da )
187197def clip (
188198 x : Array ,
189199 / ,
190200 min : Optional [Union [int , float , Array ]] = None ,
191201 max : Optional [Union [int , float , Array ]] = None ,
192- * ,
193- xp ,
194202) -> Array :
203+ """
204+ Array API compatibility wrapper for clip().
205+
206+ See the corresponding documentation in the array library and/or the array API
207+ specification for more details.
208+ """
195209 def _isscalar (a ):
196210 return isinstance (a , (int , float , type (None )))
197211 min_shape = () if _isscalar (min ) else min .shape
@@ -201,19 +215,19 @@ def _isscalar(a):
201215 result_shape = np .broadcast_shapes (x .shape , min_shape , max_shape )
202216
203217 if min is not None :
204- min = xp .broadcast_to (xp .asarray (min ), result_shape )
218+ min = da .broadcast_to (da .asarray (min ), result_shape )
205219 if max is not None :
206- max = xp .broadcast_to (xp .asarray (max ), result_shape )
220+ max = da .broadcast_to (da .asarray (max ), result_shape )
207221
208222 if min is None and max is None :
209- return xp .positive (x )
223+ return da .positive (x )
210224
211225 if min is None :
212- return astype (xp .minimum (x , max ), x .dtype )
226+ return astype (da .minimum (x , max ), x .dtype )
213227 if max is None :
214- return astype (xp .maximum (x , min ), x .dtype )
228+ return astype (da .maximum (x , min ), x .dtype )
215229
216- return astype (xp .minimum (xp .maximum (x , min ), max ), x .dtype )
230+ return astype (da .minimum (da .maximum (x , min ), max ), x .dtype )
217231
218232# exclude these from all since dask.array has no sorting functions
219233_da_unsupported = ['sort' , 'argsort' ]
0 commit comments