|
16 | 16 |
|
17 | 17 | """ |
18 | 18 |
|
| 19 | +__all__ = [] |
| 20 | + |
19 | 21 | # Warning: __array_api_version__ could change globally with |
20 | 22 | # set_array_api_strict_flags(). This should always be accessed as an |
21 | 23 | # attribute, like xp.__array_api_version__, or using |
22 | 24 | # array_api_strict.get_array_api_strict_flags()['api_version']. |
23 | 25 | from ._flags import API_VERSION as __array_api_version__ |
24 | 26 |
|
25 | | -__all__ = ["__array_api_version__"] |
| 27 | +__all__ += ["__array_api_version__"] |
26 | 28 |
|
27 | 29 | from ._constants import e, inf, nan, pi, newaxis |
28 | 30 |
|
|
137 | 139 | bitwise_right_shift, |
138 | 140 | bitwise_xor, |
139 | 141 | ceil, |
| 142 | + clip, |
140 | 143 | conj, |
| 144 | + copysign, |
141 | 145 | cos, |
142 | 146 | cosh, |
143 | 147 | divide, |
|
148 | 152 | floor_divide, |
149 | 153 | greater, |
150 | 154 | greater_equal, |
| 155 | + hypot, |
151 | 156 | imag, |
152 | 157 | isfinite, |
153 | 158 | isinf, |
|
163 | 168 | logical_not, |
164 | 169 | logical_or, |
165 | 170 | logical_xor, |
| 171 | + maximum, |
| 172 | + minimum, |
166 | 173 | multiply, |
167 | 174 | negative, |
168 | 175 | not_equal, |
|
172 | 179 | remainder, |
173 | 180 | round, |
174 | 181 | sign, |
| 182 | + signbit, |
175 | 183 | sin, |
176 | 184 | sinh, |
177 | 185 | square, |
|
199 | 207 | "bitwise_right_shift", |
200 | 208 | "bitwise_xor", |
201 | 209 | "ceil", |
| 210 | + "clip", |
202 | 211 | "conj", |
| 212 | + "copysign", |
203 | 213 | "cos", |
204 | 214 | "cosh", |
205 | 215 | "divide", |
|
210 | 220 | "floor_divide", |
211 | 221 | "greater", |
212 | 222 | "greater_equal", |
| 223 | + "hypot", |
213 | 224 | "imag", |
214 | 225 | "isfinite", |
215 | 226 | "isinf", |
|
225 | 236 | "logical_not", |
226 | 237 | "logical_or", |
227 | 238 | "logical_xor", |
| 239 | + "maximum", |
| 240 | + "minimum", |
228 | 241 | "multiply", |
229 | 242 | "negative", |
230 | 243 | "not_equal", |
|
234 | 247 | "remainder", |
235 | 248 | "round", |
236 | 249 | "sign", |
| 250 | + "signbit", |
237 | 251 | "sin", |
238 | 252 | "sinh", |
239 | 253 | "square", |
|
248 | 262 |
|
249 | 263 | __all__ += ["take"] |
250 | 264 |
|
251 | | -# linalg is an extension in the array API spec, which is a sub-namespace. Only |
252 | | -# a subset of functions in it are imported into the top-level namespace. |
253 | | -from . import linalg |
| 265 | +from ._info import __array_namespace_info__ |
254 | 266 |
|
255 | | -__all__ += ["linalg"] |
| 267 | +__all__ += [ |
| 268 | + "__array_namespace_info__", |
| 269 | +] |
256 | 270 |
|
257 | 271 | from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot |
258 | 272 |
|
259 | 273 | __all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"] |
260 | 274 |
|
261 | | -from . import fft |
262 | | -__all__ += ["fft"] |
263 | | - |
264 | 275 | from ._manipulation_functions import ( |
265 | 276 | concat, |
266 | 277 | expand_dims, |
267 | 278 | flip, |
| 279 | + moveaxis, |
268 | 280 | permute_dims, |
| 281 | + repeat, |
269 | 282 | reshape, |
270 | 283 | roll, |
271 | 284 | squeeze, |
272 | 285 | stack, |
| 286 | + tile, |
| 287 | + unstack, |
273 | 288 | ) |
274 | 289 |
|
275 | | -__all__ += ["concat", "expand_dims", "flip", "permute_dims", "reshape", "roll", "squeeze", "stack"] |
| 290 | +__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile", "unstack"] |
276 | 291 |
|
277 | | -from ._searching_functions import argmax, argmin, nonzero, where |
| 292 | +from ._searching_functions import argmax, argmin, nonzero, searchsorted, where |
278 | 293 |
|
279 | | -__all__ += ["argmax", "argmin", "nonzero", "where"] |
| 294 | +__all__ += ["argmax", "argmin", "nonzero", "searchsorted", "where"] |
280 | 295 |
|
281 | 296 | from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values |
282 | 297 |
|
|
286 | 301 |
|
287 | 302 | __all__ += ["argsort", "sort"] |
288 | 303 |
|
289 | | -from ._statistical_functions import max, mean, min, prod, std, sum, var |
| 304 | +from ._statistical_functions import cumulative_sum, max, mean, min, prod, std, sum, var |
290 | 305 |
|
291 | | -__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"] |
| 306 | +__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"] |
292 | 307 |
|
293 | 308 | from ._utility_functions import all, any |
294 | 309 |
|
|
308 | 323 | from . import _version |
309 | 324 | __version__ = _version.get_versions()['version'] |
310 | 325 | del _version |
| 326 | + |
| 327 | + |
| 328 | +# Extensions can be enabled or disabled dynamically. In order to make |
| 329 | +# "array_api_strict.linalg" give an AttributeError when it is disabled, we |
| 330 | +# use __getattr__. Note that linalg and fft are dynamically added and removed |
| 331 | +# from __all__ in set_array_api_strict_flags. |
| 332 | + |
| 333 | +def __getattr__(name): |
| 334 | + if name in ['linalg', 'fft']: |
| 335 | + if name in get_array_api_strict_flags()['enabled_extensions']: |
| 336 | + if name == 'linalg': |
| 337 | + from . import _linalg |
| 338 | + return _linalg |
| 339 | + elif name == 'fft': |
| 340 | + from . import _fft |
| 341 | + return _fft |
| 342 | + else: |
| 343 | + raise AttributeError(f"The {name!r} extension has been disabled for array_api_strict") |
| 344 | + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") |
0 commit comments