|
19 | 19 | from ._lib._utils._typing import Array, DType |
20 | 20 |
|
21 | 21 | __all__ = [ |
| 22 | + "atleast_nd", |
22 | 23 | "cov", |
23 | 24 | "expand_dims", |
24 | 25 | "isclose", |
|
29 | 30 | ] |
30 | 31 |
|
31 | 32 |
|
| 33 | +def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array: |
| 34 | + """ |
| 35 | + Recursively expand the dimension of an array to at least `ndim`. |
| 36 | +
|
| 37 | + Parameters |
| 38 | + ---------- |
| 39 | + x : array |
| 40 | + Input array. |
| 41 | + ndim : int |
| 42 | + The minimum number of dimensions for the result. |
| 43 | + xp : array_namespace, optional |
| 44 | + The standard-compatible namespace for `x`. Default: infer. |
| 45 | +
|
| 46 | + Returns |
| 47 | + ------- |
| 48 | + array |
| 49 | + An array with ``res.ndim`` >= `ndim`. |
| 50 | + If ``x.ndim`` >= `ndim`, `x` is returned. |
| 51 | + If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes |
| 52 | + until ``res.ndim`` equals `ndim`. |
| 53 | +
|
| 54 | + Examples |
| 55 | + -------- |
| 56 | + >>> import array_api_strict as xp |
| 57 | + >>> import array_api_extra as xpx |
| 58 | + >>> x = xp.asarray([1]) |
| 59 | + >>> xpx.atleast_nd(x, ndim=3, xp=xp) |
| 60 | + Array([[[1]]], dtype=array_api_strict.int64) |
| 61 | +
|
| 62 | + >>> x = xp.asarray([[[1, 2], |
| 63 | + ... [3, 4]]]) |
| 64 | + >>> xpx.atleast_nd(x, ndim=1, xp=xp) is x |
| 65 | + True |
| 66 | + """ |
| 67 | + if xp is None: |
| 68 | + xp = array_namespace(x) |
| 69 | + |
| 70 | + if 1 <= ndim <= 3 and ( |
| 71 | + is_numpy_namespace(xp) |
| 72 | + or is_jax_namespace(xp) |
| 73 | + or is_dask_namespace(xp) |
| 74 | + or is_cupy_namespace(xp) |
| 75 | + or is_torch_namespace(xp) |
| 76 | + ): |
| 77 | + return getattr(xp, f"atleast_{ndim}d")(x) |
| 78 | + |
| 79 | + return _funcs.atleast_nd(x, ndim=ndim, xp=xp) |
| 80 | + |
| 81 | + |
32 | 82 | def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: |
33 | 83 | """ |
34 | 84 | Estimate a covariance matrix. |
@@ -197,55 +247,6 @@ def expand_dims( |
197 | 247 | return _funcs.expand_dims(a, axis=axis, xp=xp) |
198 | 248 |
|
199 | 249 |
|
200 | | -def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array: |
201 | | - """ |
202 | | - Recursively expand the dimension of an array to at least `ndim`. |
203 | | -
|
204 | | - Parameters |
205 | | - ---------- |
206 | | - x : array |
207 | | - Input array. |
208 | | - ndim : int |
209 | | - The minimum number of dimensions for the result. |
210 | | - xp : array_namespace, optional |
211 | | - The standard-compatible namespace for `x`. Default: infer. |
212 | | -
|
213 | | - Returns |
214 | | - ------- |
215 | | - array |
216 | | - An array with ``res.ndim`` >= `ndim`. |
217 | | - If ``x.ndim`` >= `ndim`, `x` is returned. |
218 | | - If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes |
219 | | - until ``res.ndim`` equals `ndim`. |
220 | | -
|
221 | | - Examples |
222 | | - -------- |
223 | | - >>> import array_api_strict as xp |
224 | | - >>> import array_api_extra as xpx |
225 | | - >>> x = xp.asarray([1]) |
226 | | - >>> xpx.atleast_nd(x, ndim=3, xp=xp) |
227 | | - Array([[[1]]], dtype=array_api_strict.int64) |
228 | | -
|
229 | | - >>> x = xp.asarray([[[1, 2], |
230 | | - ... [3, 4]]]) |
231 | | - >>> xpx.atleast_nd(x, ndim=1, xp=xp) is x |
232 | | - True |
233 | | - """ |
234 | | - if xp is None: |
235 | | - xp = array_namespace(x) |
236 | | - |
237 | | - if 1 <= ndim <= 3 and ( |
238 | | - is_numpy_namespace(xp) |
239 | | - or is_jax_namespace(xp) |
240 | | - or is_dask_namespace(xp) |
241 | | - or is_cupy_namespace(xp) |
242 | | - or is_torch_namespace(xp) |
243 | | - ): |
244 | | - return getattr(xp, f"atleast_{ndim}d")(x) |
245 | | - |
246 | | - return _funcs.atleast_nd(x, ndim=ndim, xp=xp) |
247 | | - |
248 | | - |
249 | 250 | def isclose( |
250 | 251 | a: Array | complex, |
251 | 252 | b: Array | complex, |
|
0 commit comments