@@ -788,6 +788,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
788788 return x
789789 return x .to_device (device , stream = stream )
790790
791+
791792def size (x ):
792793 """
793794 Return the total number of elements of x.
@@ -802,6 +803,7 @@ def size(x):
802803 return None
803804 return math .prod (x .shape )
804805
806+
805807def is_writeable_array (x ):
806808 """
807809 Return False if x.__setitem__ is expected to raise; True otherwise
@@ -812,6 +814,7 @@ def is_writeable_array(x):
812814 return False
813815 return True
814816
817+
815818def _parse_copy_param (x , copy : bool | None | Literal ["_force_false" ]) -> bool :
816819 """Preprocess and validate a copy parameter, in line with the same
817820 parameter in np.asarray(), np.astype(), etc.
@@ -827,8 +830,10 @@ def _parse_copy_param(x, copy: bool | None | Literal["_force_false"]) -> bool:
827830 raise ValueError (f"Invalid value for copy: { copy !r} " )
828831 return copy
829832
833+
830834_undef = object ()
831835
836+
832837class at :
833838 """
834839 Update operations for read-only arrays.
@@ -897,6 +902,7 @@ class at:
897902 --------
898903 https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
899904 """
905+
900906 __slots__ = ("x" , "idx" )
901907
902908 def __init__ (self , x , idx = _undef ):
@@ -945,7 +951,7 @@ def _common(
945951 if copy and is_jax_array (self .x ):
946952 # Use JAX's at[]
947953 at_ = self .x .at [self .idx ]
948- args = (y , ) if y is not _undef else ()
954+ args = (y ,) if y is not _undef else ()
949955 return getattr (at_ , at_op )(* args , ** kwargs ), None
950956
951957 # Emulate at[] behaviour for non-JAX arrays
@@ -958,12 +964,9 @@ def get(self, copy: bool | None = True, **kwargs):
958964 # Special case when xp=numpy and idx is a fancy index
959965 # If copy is not False, avoid an unnecessary double copy.
960966 # if copy is forced to False, raise.
961- if (
962- is_numpy_array (self .x )
963- and (
964- isinstance (self .idx , (list , tuple ))
965- or (is_numpy_array (self .idx ) and self .idx .dtype .kind in "biu" )
966- )
967+ if is_numpy_array (self .x ) and (
968+ isinstance (self .idx , (list , tuple ))
969+ or (is_numpy_array (self .idx ) and self .idx .dtype .kind in "biu" )
967970 ):
968971 if copy is False :
969972 raise ValueError (
@@ -994,12 +997,14 @@ def apply(self, ufunc, /, **kwargs):
994997 ufunc .at (x , self .idx )
995998 return x
996999
997- def _iop (self , at_op : str , elwise_op : Callable [[Array , Array ], Array ], y : Array , ** kwargs ):
1000+ def _iop (
1001+ self , at_op : str , elwise_op : Callable [[Array , Array ], Array ], y : Array , ** kwargs
1002+ ):
9981003 """x[idx] += y or equivalent in-place operation on a subset of x
9991004
10001005 which is the same as saying
10011006 x[idx] = x[idx] + y
1002- Note that this is not the same as
1007+ Note that this is not the same as
10031008 operator.iadd(x[idx], y)
10041009 Consider for example when x is a numpy array and idx is a fancy index, which
10051010 triggers a deep copy on __getitem__.
@@ -1017,11 +1022,11 @@ def add(self, y, /, **kwargs):
10171022 def subtract (self , y , / , ** kwargs ):
10181023 """x[idx] -= y"""
10191024 return self ._iop ("subtract" , operator .sub , y , ** kwargs )
1020-
1025+
10211026 def multiply (self , y , / , ** kwargs ):
10221027 """x[idx] *= y"""
10231028 return self ._iop ("multiply" , operator .mul , y , ** kwargs )
1024-
1029+
10251030 def divide (self , y , / , ** kwargs ):
10261031 """x[idx] /= y"""
10271032 return self ._iop ("divide" , operator .truediv , y , ** kwargs )
@@ -1040,9 +1045,10 @@ def max(self, y, /, **kwargs):
10401045 xp = array_namespace (self .x )
10411046 return self ._iop ("max" , xp .maximum , y , ** kwargs )
10421047
1048+
10431049def where (condition , x = None , y = None , / , copy : bool | None = True ):
10441050 """Return elements from x when condition is True and from y when
1045- it is False.
1051+ it is False.
10461052
10471053 This is a wrapper around xp.where that adds the copy parameter:
10481054
0 commit comments