77"""
88from __future__ import annotations
99
10+ import operator
1011from typing import TYPE_CHECKING
1112
1213if TYPE_CHECKING :
13- from typing import Optional , Union , Any
14+ from typing import Callable , Literal , Optional , Union , Any
1415 from ._typing import Array , Device
1516
1617import sys
@@ -91,7 +92,7 @@ def is_cupy_array(x):
9192 import cupy as cp
9293
9394 # TODO: Should we reject ndarray subclasses?
94- return isinstance (x , ( cp .ndarray , cp . generic ) )
95+ return isinstance (x , cp .ndarray )
9596
9697def is_torch_array (x ):
9798 """
@@ -787,6 +788,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
787788 return x
788789 return x .to_device (device , stream = stream )
789790
791+
790792def size (x ):
791793 """
792794 Return the total number of elements of x.
@@ -801,6 +803,253 @@ def size(x):
801803 return None
802804 return math .prod (x .shape )
803805
806+
807+ def is_writeable_array (x ) -> bool :
808+ """
809+ Return False if ``x.__setitem__`` is expected to raise; True otherwise
810+ """
811+ if is_numpy_array (x ):
812+ return x .flags .writeable
813+ if is_jax_array (x ) or is_pydata_sparse_array (x ):
814+ return False
815+ return True
816+
817+
818+ def _is_fancy_index (idx ) -> bool :
819+ if not isinstance (idx , tuple ):
820+ idx = (idx ,)
821+ return any (
822+ isinstance (i , (list , tuple )) or is_array_api_obj (i )
823+ for i in idx
824+ )
825+
826+
827+ _undef = object ()
828+
829+
830+ class at :
831+ """
832+ Update operations for read-only arrays.
833+
834+ This implements ``jax.numpy.ndarray.at`` for all backends.
835+
836+ Keyword arguments are passed verbatim to backends that support the `ndarray.at`
837+ method; e.g. you may pass ``indices_are_sorted=True`` to JAX; they are quietly
838+ ignored for backends that don't support them.
839+
840+ Additionally, this introduces support for the `copy` keyword for all backends:
841+
842+ None
843+ The array parameter *may* be modified in place if it is possible and beneficial
844+ for performance. You should not reuse it after calling this function.
845+ True
846+ Ensure that the inputs are not modified. This is the default.
847+ False
848+ Raise ValueError if a copy cannot be avoided.
849+
850+ Examples
851+ --------
852+ Given either of these equivalent expressions::
853+
854+ x = at(x)[1].add(2, copy=None)
855+ x = at(x, 1).add(2, copy=None)
856+
857+ If x is a JAX array, they are the same as::
858+
859+ x = x.at[1].add(2)
860+
861+ If x is a read-only numpy array, they are the same as::
862+
863+ x = x.copy()
864+ x[1] += 2
865+
866+ Otherwise, they are the same as::
867+
868+ x[1] += 2
869+
870+ Warning
871+ -------
872+ When you use copy=None, you should always immediately overwrite
873+ the parameter array::
874+
875+ x = at(x, 0).set(2, copy=None)
876+
877+ The anti-pattern below must be avoided, as it will result in different behaviour
878+ on read-only versus writeable arrays::
879+
880+ x = xp.asarray([0, 0, 0])
881+ y = at(x, 0).set(2, copy=None)
882+ z = at(x, 1).set(3, copy=None)
883+
884+ In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
885+ when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable!
886+
887+ Warning
888+ -------
889+ The behaviour of update methods when the index is an array of integers which
890+ contains multiple occurrences of the same index is undefined;
891+ e.g. ``at(x, [0, 0]).set(2)``
892+
893+ Note
894+ ----
895+ `sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet.
896+
897+ See Also
898+ --------
899+ `jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_
900+ """
901+
902+ __slots__ = ("x" , "idx" )
903+
904+ def __init__ (self , x , idx = _undef , / ):
905+ self .x = x
906+ self .idx = idx
907+
908+ def __getitem__ (self , idx ):
909+ """
910+ Allow for the alternate syntax ``at(x)[start:stop:step]``,
911+ which looks prettier than ``at(x, slice(start, stop, step))``
912+ and feels more intuitive coming from the JAX documentation.
913+ """
914+ if self .idx is not _undef :
915+ raise ValueError ("Index has already been set" )
916+ self .idx = idx
917+ return self
918+
919+ def _common (
920+ self ,
921+ at_op : str ,
922+ y = _undef ,
923+ copy : bool | None | Literal ["_force_false" ] = True ,
924+ ** kwargs ,
925+ ):
926+ """Perform common prepocessing.
927+
928+ Returns
929+ -------
930+ If the operation can be resolved by at[], (return value, None)
931+ Otherwise, (None, preprocessed x)
932+ """
933+ if self .idx is _undef :
934+ raise TypeError (
935+ "Index has not been set.\n "
936+ "Usage: either\n "
937+ " at(x, idx).set(value)\n "
938+ "or\n "
939+ " at(x)[idx].set(value)\n "
940+ "(same for all other methods)."
941+ )
942+
943+ x = self .x
944+
945+ if copy is False :
946+ if not is_writeable_array (x ) or is_dask_array (x ):
947+ raise ValueError ("Cannot modify parameter in place" )
948+ elif copy is None :
949+ copy = not is_writeable_array (x )
950+ elif copy == "_force_false" :
951+ copy = False
952+ elif copy is not True :
953+ raise ValueError (f"Invalid value for copy: { copy !r} " )
954+
955+ if is_jax_array (x ):
956+ # Use JAX's at[]
957+ at_ = x .at [self .idx ]
958+ args = (y ,) if y is not _undef else ()
959+ return getattr (at_ , at_op )(* args , ** kwargs ), None
960+
961+ # Emulate at[] behaviour for non-JAX arrays
962+ if copy :
963+ # FIXME We blindly expect the output of x.copy() to be always writeable.
964+ # This holds true for read-only numpy arrays, but not necessarily for
965+ # other backends.
966+ xp = array_namespace (x )
967+ x = xp .asarray (x , copy = True )
968+
969+ return None , x
970+
971+ def get (self , ** kwargs ):
972+ """
973+ Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
974+ that the output is either a copy or a view; it also allows passing
975+ keyword arguments to the backend.
976+ """
977+ # __getitem__ with a fancy index always returns a copy.
978+ # Avoid an unnecessary double copy.
979+ # If copy is forced to False, raise.
980+ if _is_fancy_index (self .idx ):
981+ if kwargs .get ("copy" , True ) is False :
982+ raise TypeError (
983+ "Indexing a numpy array with a fancy index always "
984+ "results in a copy"
985+ )
986+ # Skip copy inside _common, even if array is not writeable
987+ kwargs ["copy" ] = "_force_false" # type: ignore
988+
989+ res , x = self ._common ("get" , ** kwargs )
990+ if res is not None :
991+ return res
992+ return x [self .idx ]
993+
994+ def set (self , y , / , ** kwargs ):
995+ """Apply ``x[idx] = y`` and return the update array"""
996+ res , x = self ._common ("set" , y , ** kwargs )
997+ if res is not None :
998+ return res
999+ x [self .idx ] = y
1000+ return x
1001+
1002+ def _iop (
1003+ self , at_op : str , elwise_op : Callable [[Array , Array ], Array ], y : Array , ** kwargs
1004+ ):
1005+ """x[idx] += y or equivalent in-place operation on a subset of x
1006+
1007+ which is the same as saying
1008+ x[idx] = x[idx] + y
1009+ Note that this is not the same as
1010+ operator.iadd(x[idx], y)
1011+ Consider for example when x is a numpy array and idx is a fancy index, which
1012+ triggers a deep copy on __getitem__.
1013+ """
1014+ res , x = self ._common (at_op , y , ** kwargs )
1015+ if res is not None :
1016+ return res
1017+ x [self .idx ] = elwise_op (x [self .idx ], y )
1018+ return x
1019+
1020+ def add (self , y , / , ** kwargs ):
1021+ """Apply ``x[idx] += y`` and return the updated array"""
1022+ return self ._iop ("add" , operator .add , y , ** kwargs )
1023+
1024+ def subtract (self , y , / , ** kwargs ):
1025+ """Apply ``x[idx] -= y`` and return the updated array"""
1026+ return self ._iop ("subtract" , operator .sub , y , ** kwargs )
1027+
1028+ def multiply (self , y , / , ** kwargs ):
1029+ """Apply ``x[idx] *= y`` and return the updated array"""
1030+ return self ._iop ("multiply" , operator .mul , y , ** kwargs )
1031+
1032+ def divide (self , y , / , ** kwargs ):
1033+ """Apply ``x[idx] /= y`` and return the updated array"""
1034+ return self ._iop ("divide" , operator .truediv , y , ** kwargs )
1035+
1036+ def power (self , y , / , ** kwargs ):
1037+ """Apply ``x[idx] **= y`` and return the updated array"""
1038+ return self ._iop ("power" , operator .pow , y , ** kwargs )
1039+
1040+ def min (self , y , / , ** kwargs ):
1041+ """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
1042+ xp = array_namespace (self .x )
1043+ y = xp .asarray (y )
1044+ return self ._iop ("min" , xp .minimum , y , ** kwargs )
1045+
1046+ def max (self , y , / , ** kwargs ):
1047+ """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
1048+ xp = array_namespace (self .x )
1049+ y = xp .asarray (y )
1050+ return self ._iop ("max" , xp .maximum , y , ** kwargs )
1051+
1052+
8041053__all__ = [
8051054 "array_namespace" ,
8061055 "device" ,
@@ -821,8 +1070,10 @@ def size(x):
8211070 "is_ndonnx_namespace" ,
8221071 "is_pydata_sparse_array" ,
8231072 "is_pydata_sparse_namespace" ,
1073+ "is_writeable_array" ,
8241074 "size" ,
8251075 "to_device" ,
1076+ "at" ,
8261077]
8271078
828- _all_ignore = ['sys ' , 'math' , 'inspect ' , 'warnings' ]
1079+ _all_ignore = ['inspect ' , 'math' , 'operator ' , 'warnings' , 'sys ' ]
0 commit comments