77"""
88from __future__ import annotations
99
10- import operator
1110from typing import TYPE_CHECKING
1211
1312if TYPE_CHECKING :
14- from typing import Callable , Literal , Optional , Union , Any
13+ from typing import Optional , Union , Any
1514 from ._typing import Array , Device
1615
1716import sys
@@ -92,7 +91,7 @@ def is_cupy_array(x):
9291 import cupy as cp
9392
9493 # TODO: Should we reject ndarray subclasses?
95- return isinstance (x , cp .ndarray )
94+ return isinstance (x , ( cp .ndarray , cp . generic ) )
9695
9796def is_torch_array (x ):
9897 """
@@ -788,7 +787,6 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
788787 return x
789788 return x .to_device (device , stream = stream )
790789
791-
792790def size (x ):
793791 """
794792 Return the total number of elements of x.
@@ -803,253 +801,6 @@ def size(x):
803801 return None
804802 return math .prod (x .shape )
805803
806-
807- def is_writeable_array (x ) -> bool :
808- """
809- Return False if ``x.__setitem__`` is expected to raise; True otherwise
810- """
811- if is_numpy_array (x ):
812- return x .flags .writeable
813- if is_jax_array (x ) or is_pydata_sparse_array (x ):
814- return False
815- return True
816-
817-
818- def _is_fancy_index (idx ) -> bool :
819- if not isinstance (idx , tuple ):
820- idx = (idx ,)
821- return any (
822- isinstance (i , (list , tuple )) or is_array_api_obj (i )
823- for i in idx
824- )
825-
826-
827- _undef = object ()
828-
829-
830- class at :
831- """
832- Update operations for read-only arrays.
833-
834- This implements ``jax.numpy.ndarray.at`` for all backends.
835-
836- Keyword arguments are passed verbatim to backends that support the `ndarray.at`
837- method; e.g. you may pass ``indices_are_sorted=True`` to JAX; they are quietly
838- ignored for backends that don't support them.
839-
840- Additionally, this introduces support for the `copy` keyword for all backends:
841-
842- None
843- The array parameter *may* be modified in place if it is possible and beneficial
844- for performance. You should not reuse it after calling this function.
845- True
846- Ensure that the inputs are not modified. This is the default.
847- False
848- Raise ValueError if a copy cannot be avoided.
849-
850- Examples
851- --------
852- Given either of these equivalent expressions::
853-
854- x = at(x)[1].add(2, copy=None)
855- x = at(x, 1).add(2, copy=None)
856-
857- If x is a JAX array, they are the same as::
858-
859- x = x.at[1].add(2)
860-
861- If x is a read-only numpy array, they are the same as::
862-
863- x = x.copy()
864- x[1] += 2
865-
866- Otherwise, they are the same as::
867-
868- x[1] += 2
869-
870- Warning
871- -------
872- When you use copy=None, you should always immediately overwrite
873- the parameter array::
874-
875- x = at(x, 0).set(2, copy=None)
876-
877- The anti-pattern below must be avoided, as it will result in different behaviour
878- on read-only versus writeable arrays::
879-
880- x = xp.asarray([0, 0, 0])
881- y = at(x, 0).set(2, copy=None)
882- z = at(x, 1).set(3, copy=None)
883-
884- In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
885- when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable!
886-
887- Warning
888- -------
889- The behaviour of update methods when the index is an array of integers which
890- contains multiple occurrences of the same index is undefined;
891- e.g. ``at(x, [0, 0]).set(2)``
892-
893- Note
894- ----
895- `sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet.
896-
897- See Also
898- --------
899- `jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_
900- """
901-
902- __slots__ = ("x" , "idx" )
903-
904- def __init__ (self , x , idx = _undef , / ):
905- self .x = x
906- self .idx = idx
907-
908- def __getitem__ (self , idx ):
909- """
910- Allow for the alternate syntax ``at(x)[start:stop:step]``,
911- which looks prettier than ``at(x, slice(start, stop, step))``
912- and feels more intuitive coming from the JAX documentation.
913- """
914- if self .idx is not _undef :
915- raise ValueError ("Index has already been set" )
916- self .idx = idx
917- return self
918-
919- def _common (
920- self ,
921- at_op : str ,
922- y = _undef ,
923- copy : bool | None | Literal ["_force_false" ] = True ,
924- ** kwargs ,
925- ):
926- """Perform common prepocessing.
927-
928- Returns
929- -------
930- If the operation can be resolved by at[], (return value, None)
931- Otherwise, (None, preprocessed x)
932- """
933- if self .idx is _undef :
934- raise TypeError (
935- "Index has not been set.\n "
936- "Usage: either\n "
937- " at(x, idx).set(value)\n "
938- "or\n "
939- " at(x)[idx].set(value)\n "
940- "(same for all other methods)."
941- )
942-
943- x = self .x
944-
945- if copy is False :
946- if not is_writeable_array (x ) or is_dask_array (x ):
947- raise ValueError ("Cannot modify parameter in place" )
948- elif copy is None :
949- copy = not is_writeable_array (x )
950- elif copy == "_force_false" :
951- copy = False
952- elif copy is not True :
953- raise ValueError (f"Invalid value for copy: { copy !r} " )
954-
955- if is_jax_array (x ):
956- # Use JAX's at[]
957- at_ = x .at [self .idx ]
958- args = (y ,) if y is not _undef else ()
959- return getattr (at_ , at_op )(* args , ** kwargs ), None
960-
961- # Emulate at[] behaviour for non-JAX arrays
962- if copy :
963- # FIXME We blindly expect the output of x.copy() to be always writeable.
964- # This holds true for read-only numpy arrays, but not necessarily for
965- # other backends.
966- xp = array_namespace (x )
967- x = xp .asarray (x , copy = True )
968-
969- return None , x
970-
971- def get (self , ** kwargs ):
972- """
973- Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
974- that the output is either a copy or a view; it also allows passing
975- keyword arguments to the backend.
976- """
977- # __getitem__ with a fancy index always returns a copy.
978- # Avoid an unnecessary double copy.
979- # If copy is forced to False, raise.
980- if _is_fancy_index (self .idx ):
981- if kwargs .get ("copy" , True ) is False :
982- raise TypeError (
983- "Indexing a numpy array with a fancy index always "
984- "results in a copy"
985- )
986- # Skip copy inside _common, even if array is not writeable
987- kwargs ["copy" ] = "_force_false"
988-
989- res , x = self ._common ("get" , ** kwargs )
990- if res is not None :
991- return res
992- return x [self .idx ]
993-
994- def set (self , y , / , ** kwargs ):
995- """Apply ``x[idx] = y`` and return the update array"""
996- res , x = self ._common ("set" , y , ** kwargs )
997- if res is not None :
998- return res
999- x [self .idx ] = y
1000- return x
1001-
1002- def _iop (
1003- self , at_op : str , elwise_op : Callable [[Array , Array ], Array ], y : Array , ** kwargs
1004- ):
1005- """x[idx] += y or equivalent in-place operation on a subset of x
1006-
1007- which is the same as saying
1008- x[idx] = x[idx] + y
1009- Note that this is not the same as
1010- operator.iadd(x[idx], y)
1011- Consider for example when x is a numpy array and idx is a fancy index, which
1012- triggers a deep copy on __getitem__.
1013- """
1014- res , x = self ._common (at_op , y , ** kwargs )
1015- if res is not None :
1016- return res
1017- x [self .idx ] = elwise_op (x [self .idx ], y )
1018- return x
1019-
1020- def add (self , y , / , ** kwargs ):
1021- """Apply ``x[idx] += y`` and return the updated array"""
1022- return self ._iop ("add" , operator .add , y , ** kwargs )
1023-
1024- def subtract (self , y , / , ** kwargs ):
1025- """Apply ``x[idx] -= y`` and return the updated array"""
1026- return self ._iop ("subtract" , operator .sub , y , ** kwargs )
1027-
1028- def multiply (self , y , / , ** kwargs ):
1029- """Apply ``x[idx] *= y`` and return the updated array"""
1030- return self ._iop ("multiply" , operator .mul , y , ** kwargs )
1031-
1032- def divide (self , y , / , ** kwargs ):
1033- """Apply ``x[idx] /= y`` and return the updated array"""
1034- return self ._iop ("divide" , operator .truediv , y , ** kwargs )
1035-
1036- def power (self , y , / , ** kwargs ):
1037- """Apply ``x[idx] **= y`` and return the updated array"""
1038- return self ._iop ("power" , operator .pow , y , ** kwargs )
1039-
1040- def min (self , y , / , ** kwargs ):
1041- """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
1042- xp = array_namespace (self .x )
1043- y = xp .asarray (y )
1044- return self ._iop ("min" , xp .minimum , y , ** kwargs )
1045-
1046- def max (self , y , / , ** kwargs ):
1047- """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
1048- xp = array_namespace (self .x )
1049- y = xp .asarray (y )
1050- return self ._iop ("max" , xp .maximum , y , ** kwargs )
1051-
1052-
1053804__all__ = [
1054805 "array_namespace" ,
1055806 "device" ,
@@ -1070,10 +821,8 @@ def max(self, y, /, **kwargs):
1070821 "is_ndonnx_namespace" ,
1071822 "is_pydata_sparse_array" ,
1072823 "is_pydata_sparse_namespace" ,
1073- "is_writeable_array" ,
1074824 "size" ,
1075825 "to_device" ,
1076- "at" ,
1077826]
1078827
1079- _all_ignore = ['inspect ' , 'math' , 'operator ' , 'warnings' , 'sys ' ]
828+ _all_ignore = ['sys ' , 'math' , 'inspect ' , 'warnings' ]
0 commit comments