@@ -839,6 +839,19 @@ def size(x: Array) -> int | None:
839839 return None if math .isnan (out ) else out
840840
841841
842+ @cache
843+ def _is_writeable_cls (cls : type ) -> bool | None :
844+ if (
845+ _issubclass_fast (cls , "numpy" , "generic" )
846+ or _issubclass_fast (cls , "jax" , "Array" )
847+ or _issubclass_fast (cls , "sparse" , "SparseArray" )
848+ ):
849+ return False
850+ if _is_array_api_cls (cls ):
851+ return True
852+ return None
853+
854+
842855def is_writeable_array (x : object ) -> bool :
843856 """
844857 Return False if ``x.__setitem__`` is expected to raise; True otherwise.
@@ -849,11 +862,32 @@ def is_writeable_array(x: object) -> bool:
849862 As there is no standard way to check if an array is writeable without actually
850863 writing to it, this function blindly returns True for all unknown array types.
851864 """
852- if is_numpy_array (x ):
865+ cls = type (x )
866+ if _issubclass_fast (cls , "numpy" , "ndarray" ):
853867 return x .flags .writeable
854- if is_jax_array (x ) or is_pydata_sparse_array (x ):
868+ res = _is_writeable_cls (cls )
869+ if res is not None :
870+ return res
871+ return hasattr (x , '__array_namespace__' )
872+
873+
874+ @cache
875+ def _is_lazy_cls (cls : type ) -> bool | None :
876+ if (
877+ _issubclass_fast (cls , "numpy" , "ndarray" )
878+ or _issubclass_fast (cls , "numpy" , "generic" )
879+ or _issubclass_fast (cls , "cupy" , "ndarray" )
880+ or _issubclass_fast (cls , "torch" , "Tensor" )
881+ or _issubclass_fast (cls , "sparse" , "SparseArray" )
882+ ):
855883 return False
856- return is_array_api_obj (x )
884+ if (
885+ _issubclass_fast (cls , "jax" , "Array" )
886+ or _issubclass_fast (cls , "dask.array" , "Array" )
887+ or _issubclass_fast (cls , "ndonnx" , "Array" )
888+ ):
889+ return True
890+ return None
857891
858892
859893def is_lazy_array (x : object ) -> bool :
@@ -869,14 +903,6 @@ def is_lazy_array(x: object) -> bool:
869903 This function errs on the side of caution for array types that may or may not be
870904 lazy, e.g. JAX arrays, by always returning True for them.
871905 """
872- if (
873- is_numpy_array (x )
874- or is_cupy_array (x )
875- or is_torch_array (x )
876- or is_pydata_sparse_array (x )
877- ):
878- return False
879-
880906 # **JAX note:** while it is possible to determine if you're inside or outside
881907 # jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
882908 # as we do below for unknown arrays, this is not recommended by JAX best practices.
@@ -886,10 +912,13 @@ def is_lazy_array(x: object) -> bool:
886912 # compatibility, is highly detrimental to performance as the whole graph will end
887913 # up being computed multiple times.
888914
889- if is_jax_array (x ) or is_dask_array (x ) or is_ndonnx_array (x ):
890- return True
915+ # Note: skipping reclassification of JAX zero gradient arrays, as one will
916+ # exclusively get them once they leave a jax.grad JIT context.
917+ res = _is_lazy_cls (type (x ))
918+ if res is not None :
919+ return res
891920
892- if not is_array_api_obj ( x ):
921+ if not hasattr ( x , "__array_namespace__" ):
893922 return False
894923
895924 # Unknown Array API compatible object. Note that this test may have dire consequences
0 commit comments