11from __future__ import annotations
22
33from functools import wraps
4- from builtins import all as builtin_all
4+ from builtins import all as builtin_all , any as builtin_any
55
66from ..common ._aliases import (UniqueAllResult , UniqueCountsResult ,
77 UniqueInverseResult ,
1818import torch
1919array = torch .Tensor
2020
21- _array_api_dtypes = {
22- torch .bool ,
21+ _int_dtypes = {
2322 torch .uint8 ,
2423 torch .int8 ,
2524 torch .int16 ,
2625 torch .int32 ,
2726 torch .int64 ,
27+ }
28+
29+ _array_api_dtypes = {
30+ torch .bool ,
31+ * _int_dtypes ,
2832 torch .float32 ,
2933 torch .float64 ,
3034}
@@ -602,6 +606,43 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
602606 x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
603607 return torch .tensordot (x1 , x2 , dims = axes , ** kwargs )
604608
609+
610+ def isdtype (
611+ dtype : Dtype , kind : Union [Dtype , str , Tuple [Union [Dtype , str ], ...]],
612+ * , _tuple = True , # Disallow nested tuples
613+ ) -> bool :
614+ """
615+ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
616+
617+ Note that outside of this function, this compat library does not yet fully
618+ support complex numbers.
619+
620+ See
621+ https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
622+ for more details
623+ """
624+ if isinstance (kind , tuple ) and _tuple :
625+ return builtin_any (isdtype (dtype , k , _tuple = False ) for k in kind )
626+ elif isinstance (kind , str ):
627+ if kind == 'bool' :
628+ return dtype == torch .bool
629+ elif kind == 'signed integer' :
630+ return dtype in _int_dtypes and dtype .is_signed
631+ elif kind == 'unsigned integer' :
632+ return dtype in _int_dtypes and not dtype .is_signed
633+ elif kind == 'integral' :
634+ return dtype in _int_dtypes
635+ elif kind == 'real floating' :
636+ return dtype .is_floating_point
637+ elif kind == 'complex floating' :
638+ return dtype .is_complex
639+ elif kind == 'numeric' :
640+ return isdtype (dtype , ('integral' , 'real floating' , 'complex floating' ))
641+ else :
642+ raise ValueError (f"Unrecognized data type kind: { kind !r} " )
643+ else :
644+ return dtype == kind
645+
605646__all__ = ['result_type' , 'can_cast' , 'permute_dims' , 'bitwise_invert' , 'add' ,
606647 'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
607648 'bitwise_right_shift' , 'bitwise_xor' , 'divide' , 'equal' ,
@@ -612,4 +653,4 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
612653 'nonzero' , 'where' , 'arange' , 'eye' , 'linspace' , 'full' , 'ones' ,
613654 'zeros' , 'empty' , 'expand_dims' , 'astype' , 'broadcast_arrays' ,
614655 'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
615- 'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' ]
656+ 'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' , 'isdtype' ]
0 commit comments