1212import math
1313import sys
1414import warnings
15- from collections .abc import Collection
15+ from collections .abc import Collection , Hashable
1616from functools import lru_cache
1717from typing import (
1818 TYPE_CHECKING ,
@@ -83,7 +83,8 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
8383 dtype = x .dtype # type: ignore[attr-defined]
8484 except AttributeError :
8585 return False
86- if not _issubclass_fast (type (dtype ), "numpy.dtypes" , "VoidDType" ):
86+ cls = cast (Hashable , type (dtype ))
87+ if not _issubclass_fast (cls , "numpy.dtypes" , "VoidDType" ):
8788 return False
8889
8990 if "jax" not in sys .modules :
@@ -116,7 +117,7 @@ def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
116117 is_pydata_sparse_array
117118 """
118119 # TODO: Should we reject ndarray subclasses?
119- cls = type (x )
120+ cls = cast ( Hashable , type (x ) )
120121 return (
121122 _issubclass_fast (cls , "numpy" , "ndarray" )
122123 or _issubclass_fast (cls , "numpy" , "generic" )
@@ -144,7 +145,8 @@ def is_cupy_array(x: object) -> bool:
144145 is_jax_array
145146 is_pydata_sparse_array
146147 """
147- return _issubclass_fast (type (x ), "cupy" , "ndarray" )
148+ cls = cast (Hashable , type (x ))
149+ return _issubclass_fast (cls , "cupy" , "ndarray" )
148150
149151
150152def is_torch_array (x : object ) -> TypeIs [torch .Tensor ]:
@@ -165,7 +167,8 @@ def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
165167 is_jax_array
166168 is_pydata_sparse_array
167169 """
168- return _issubclass_fast (type (x ), "torch" , "Tensor" )
170+ cls = cast (Hashable , type (x ))
171+ return _issubclass_fast (cls , "torch" , "Tensor" )
169172
170173
171174def is_ndonnx_array (x : object ) -> TypeIs [ndx .Array ]:
@@ -187,7 +190,8 @@ def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
187190 is_jax_array
188191 is_pydata_sparse_array
189192 """
190- return _issubclass_fast (type (x ), "ndonnx" , "Array" )
193+ cls = cast (Hashable , type (x ))
194+ return _issubclass_fast (cls , "ndonnx" , "Array" )
191195
192196
193197def is_dask_array (x : object ) -> TypeIs [da .Array ]:
@@ -209,7 +213,8 @@ def is_dask_array(x: object) -> TypeIs[da.Array]:
209213 is_jax_array
210214 is_pydata_sparse_array
211215 """
212- return _issubclass_fast (type (x ), "dask.array" , "Array" )
216+ cls = cast (Hashable , type (x ))
217+ return _issubclass_fast (cls , "dask.array" , "Array" )
213218
214219
215220def is_jax_array (x : object ) -> TypeIs [jax .Array ]:
@@ -232,7 +237,8 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]:
232237 is_dask_array
233238 is_pydata_sparse_array
234239 """
235- return _issubclass_fast (type (x ), "jax" , "Array" ) or _is_jax_zero_gradient_array (x )
240+ cls = cast (Hashable , type (x ))
241+ return _issubclass_fast (cls , "jax" , "Array" ) or _is_jax_zero_gradient_array (x )
236242
237243
238244def is_pydata_sparse_array (x : object ) -> TypeIs [sparse .SparseArray ]:
@@ -256,7 +262,8 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
256262 is_jax_array
257263 """
258264 # TODO: Account for other backends.
259- return _issubclass_fast (type (x ), "sparse" , "SparseArray" )
265+ cls = cast (Hashable , type (x ))
266+ return _issubclass_fast (cls , "sparse" , "SparseArray" )
260267
261268
262269def is_array_api_obj (x : object ) -> TypeIs [_ArrayApiObj ]: # pyright: ignore[reportUnknownParameterType]
@@ -274,7 +281,10 @@ def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[repo
274281 is_dask_array
275282 is_jax_array
276283 """
277- return hasattr (x , '__array_namespace__' ) or _is_array_api_cls (type (x ))
284+ return (
285+ hasattr (x , '__array_namespace__' )
286+ or _is_array_api_cls (cast (Hashable , type (x )))
287+ )
278288
279289
280290@lru_cache (100 )
@@ -946,9 +956,9 @@ def is_writeable_array(x: object) -> bool:
946956 As there is no standard way to check if an array is writeable without actually
947957 writing to it, this function blindly returns True for all unknown array types.
948958 """
949- cls = type (x )
959+ cls = cast ( Hashable , type (x ) )
950960 if _issubclass_fast (cls , "numpy" , "ndarray" ):
951- return x .flags .writeable
961+ return cast ( npt . NDArray , x ) .flags .writeable
952962 res = _is_writeable_cls (cls )
953963 if res is not None :
954964 return res
@@ -998,7 +1008,8 @@ def is_lazy_array(x: object) -> bool:
9981008
9991009 # Note: skipping reclassification of JAX zero gradient arrays, as one will
10001010 # exclusively get them once they leave a jax.grad JIT context.
1001- res = _is_lazy_cls (type (x ))
1011+ cls = cast (Hashable , type (x ))
1012+ res = _is_lazy_cls (cls )
10021013 if res is not None :
10031014 return res
10041015
0 commit comments