|
8 | 8 |
|
9 | 9 | from __future__ import annotations |
10 | 10 |
|
| 11 | +import enum |
11 | 12 | import inspect |
12 | 13 | import math |
13 | 14 | import sys |
@@ -481,6 +482,86 @@ def _check_api_version(api_version: str | None) -> None: |
481 | 482 | ) |
482 | 483 |
|
483 | 484 |
|
| 485 | +class _ClsToXPInfo(enum.Enum): |
| 486 | + SCALAR = 0 |
| 487 | + MAYBE_JAX_ZERO_GRADIENT = 1 |
| 488 | + |
| 489 | + |
| 490 | +@lru_cache(100) |
| 491 | +def _cls_to_namespace( |
| 492 | + cls: type, |
| 493 | + api_version: str | None, |
| 494 | + use_compat: bool | None, |
| 495 | +) -> tuple[Namespace | None, _ClsToXPInfo | None]: |
| 496 | + if use_compat not in (None, True, False): |
| 497 | + raise ValueError("use_compat must be None, True, or False") |
| 498 | + _use_compat = use_compat in (None, True) |
| 499 | + cls_ = cast(Hashable, cls) # Make mypy happy |
| 500 | + |
| 501 | + if ( |
| 502 | + _issubclass_fast(cls_, "numpy", "ndarray") |
| 503 | + or _issubclass_fast(cls_, "numpy", "generic") |
| 504 | + ): |
| 505 | + if use_compat is True: |
| 506 | + _check_api_version(api_version) |
| 507 | + from .. import numpy as xp |
| 508 | + elif use_compat is False: |
| 509 | + import numpy as xp # type: ignore[no-redef] |
| 510 | + else: |
| 511 | + # NumPy 2.0+ have __array_namespace__; however they are not |
| 512 | + # yet fully array API compatible. |
| 513 | + from .. import numpy as xp # type: ignore[no-redef] |
| 514 | + return xp, _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT |
| 515 | + |
| 516 | + # Note: this must happen _after_ the test for np.generic, |
| 517 | + # because np.float64 and np.complex128 are subclasses of float and complex. |
| 518 | + if issubclass(cls, int | float | complex | type(None)): |
| 519 | + return None, _ClsToXPInfo.SCALAR |
| 520 | + |
| 521 | + if _issubclass_fast(cls_, "cupy", "ndarray"): |
| 522 | + if _use_compat: |
| 523 | + _check_api_version(api_version) |
| 524 | + from .. import cupy as xp # type: ignore[no-redef] |
| 525 | + else: |
| 526 | + import cupy as xp # type: ignore[no-redef] |
| 527 | + return xp, None |
| 528 | + |
| 529 | + if _issubclass_fast(cls_, "torch", "Tensor"): |
| 530 | + if _use_compat: |
| 531 | + _check_api_version(api_version) |
| 532 | + from .. import torch as xp # type: ignore[no-redef] |
| 533 | + else: |
| 534 | + import torch as xp # type: ignore[no-redef] |
| 535 | + return xp, None |
| 536 | + |
| 537 | + if _issubclass_fast(cls_, "dask.array", "Array"): |
| 538 | + if _use_compat: |
| 539 | + _check_api_version(api_version) |
| 540 | + from ..dask import array as xp # type: ignore[no-redef] |
| 541 | + else: |
| 542 | + import dask.array as xp # type: ignore[no-redef] |
| 543 | + return xp, None |
| 544 | + |
| 545 | + # Backwards compatibility for jax<0.4.32 |
| 546 | + if _issubclass_fast(cls_, "jax", "Array"): |
| 547 | + return _jax_namespace(api_version, use_compat), None |
| 548 | + |
| 549 | + return None, None |
| 550 | + |
| 551 | + |
| 552 | +def _jax_namespace(api_version: str | None, use_compat: bool | None) -> Namespace: |
| 553 | + if use_compat: |
| 554 | + raise ValueError("JAX does not have an array-api-compat wrapper") |
| 555 | + import jax.numpy as jnp |
| 556 | + if not hasattr(jnp, "__array_namespace_info__"): |
| 557 | + # JAX v0.4.32 and newer implements the array API directly in jax.numpy. |
| 558 | + # For older JAX versions, it is available via jax.experimental.array_api. |
| 559 | + # jnp.Array objects gain the __array_namespace__ method. |
| 560 | + import jax.experimental.array_api # noqa: F401 |
| 561 | + # Test api_version |
| 562 | + return jnp.empty(0).__array_namespace__(api_version=api_version) |
| 563 | + |
| 564 | + |
484 | 565 | def array_namespace( |
485 | 566 | *xs: Array | complex | None, |
486 | 567 | api_version: str | None = None, |
@@ -549,105 +630,40 @@ def your_function(x, y): |
549 | 630 | is_pydata_sparse_array |
550 | 631 |
|
551 | 632 | """ |
552 | | - if use_compat not in [None, True, False]: |
553 | | - raise ValueError("use_compat must be None, True, or False") |
554 | | - |
555 | | - _use_compat = use_compat in [None, True] |
556 | | - |
557 | 633 | namespaces: set[Namespace] = set() |
558 | 634 | for x in xs: |
559 | | - if is_numpy_array(x): |
560 | | - import numpy as np |
561 | | - |
562 | | - from .. import numpy as numpy_namespace |
563 | | - |
564 | | - if use_compat is True: |
565 | | - _check_api_version(api_version) |
566 | | - namespaces.add(numpy_namespace) |
567 | | - elif use_compat is False: |
568 | | - namespaces.add(np) |
569 | | - else: |
570 | | - # numpy 2.0+ have __array_namespace__, however, they are not yet fully array API |
571 | | - # compatible. |
572 | | - namespaces.add(numpy_namespace) |
573 | | - elif is_cupy_array(x): |
574 | | - if _use_compat: |
575 | | - _check_api_version(api_version) |
576 | | - from .. import cupy as cupy_namespace |
577 | | - |
578 | | - namespaces.add(cupy_namespace) |
579 | | - else: |
580 | | - import cupy as cp |
581 | | - |
582 | | - namespaces.add(cp) |
583 | | - elif is_torch_array(x): |
584 | | - if _use_compat: |
585 | | - _check_api_version(api_version) |
586 | | - from .. import torch as torch_namespace |
587 | | - |
588 | | - namespaces.add(torch_namespace) |
589 | | - else: |
590 | | - import torch |
591 | | - |
592 | | - namespaces.add(torch) |
593 | | - elif is_dask_array(x): |
594 | | - if _use_compat: |
595 | | - _check_api_version(api_version) |
596 | | - from ..dask import array as dask_namespace |
597 | | - |
598 | | - namespaces.add(dask_namespace) |
599 | | - else: |
600 | | - import dask.array as da |
601 | | - |
602 | | - namespaces.add(da) |
603 | | - elif is_jax_array(x): |
604 | | - if use_compat is True: |
605 | | - _check_api_version(api_version) |
606 | | - raise ValueError("JAX does not have an array-api-compat wrapper") |
607 | | - elif use_compat is False: |
608 | | - import jax.numpy as jnp |
609 | | - else: |
610 | | - # JAX v0.4.32 and newer implements the array API directly in jax.numpy. |
611 | | - # For older JAX versions, it is available via jax.experimental.array_api. |
612 | | - import jax.numpy |
613 | | - |
614 | | - if hasattr(jax.numpy, "__array_api_version__"): |
615 | | - jnp = jax.numpy |
616 | | - else: |
617 | | - import jax.experimental.array_api as jnp # type: ignore[no-redef] |
618 | | - namespaces.add(jnp) |
619 | | - elif is_pydata_sparse_array(x): |
620 | | - if use_compat is True: |
621 | | - _check_api_version(api_version) |
622 | | - raise ValueError("`sparse` does not have an array-api-compat wrapper") |
623 | | - else: |
624 | | - import sparse |
625 | | - # `sparse` is already an array namespace. We do not have a wrapper |
626 | | - # submodule for it. |
627 | | - namespaces.add(sparse) |
628 | | - elif hasattr(x, "__array_namespace__"): |
629 | | - if use_compat is True: |
| 635 | + xp, info = _cls_to_namespace(cast(Hashable, type(x)), api_version, use_compat) |
| 636 | + if info is _ClsToXPInfo.SCALAR: |
| 637 | + continue |
| 638 | + |
| 639 | + if ( |
| 640 | + info is _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT |
| 641 | + and _is_jax_zero_gradient_array(x) |
| 642 | + ): |
| 643 | + xp = _jax_namespace(api_version, use_compat) |
| 644 | + |
| 645 | + if xp is None: |
| 646 | + get_ns = getattr(x, "__array_namespace__", None) |
| 647 | + if get_ns is None: |
| 648 | + raise TypeError(f"{type(x).__name__} is not a supported array type") |
| 649 | + if use_compat: |
630 | 650 | raise ValueError( |
631 | 651 | "The given array does not have an array-api-compat wrapper" |
632 | 652 | ) |
633 | | - x = cast("SupportsArrayNamespace[Any]", x) |
634 | | - namespaces.add(x.__array_namespace__(api_version=api_version)) |
635 | | - elif isinstance(x, int | float | complex) or x is None: |
636 | | - continue |
637 | | - else: |
638 | | - # TODO: Support Python scalars? |
639 | | - raise TypeError(f"{type(x).__name__} is not a supported array type") |
| 653 | + xp = get_ns(api_version=api_version) |
640 | 654 |
|
641 | | - if not namespaces: |
642 | | - raise TypeError("Unrecognized array input") |
| 655 | + namespaces.add(xp) |
643 | 656 |
|
644 | | - if len(namespaces) != 1: |
| 657 | + try: |
| 658 | + (xp,) = namespaces |
| 659 | + return xp |
| 660 | + except ValueError: |
| 661 | + if not namespaces: |
| 662 | + raise TypeError( |
| 663 | + "array_namespace requires at least one non-scalar array input" |
| 664 | + ) |
645 | 665 | raise TypeError(f"Multiple namespaces for array inputs: {namespaces}") |
646 | 666 |
|
647 | | - (xp,) = namespaces |
648 | | - |
649 | | - return xp |
650 | | - |
651 | 667 |
|
652 | 668 | # backwards compatibility alias |
653 | 669 | get_namespace = array_namespace |
|
0 commit comments