@@ -1740,41 +1740,30 @@ def make_node(self, x, y, *inputs):
17401740 def decl_view (self ):
17411741 return "PyArrayObject * zview = NULL;"
17421742
1743- def perform (self , node , inputs , out_ ):
1744- (out ,) = out_
1745- x , y = inputs [:2 ]
1746- indices = list (reversed (inputs [2 :]))
1747-
1748- def _convert (entry ):
1749- if isinstance (entry , Type ):
1750- return indices .pop ()
1751- elif isinstance (entry , slice ):
1752- return slice (
1753- _convert (entry .start ), _convert (entry .stop ), _convert (entry .step )
1743+ def perform (self , node , inputs , output_storage ):
1744+ x , y , * flat_indices = inputs
1745+
1746+ flat_indices_iterator = iter (flat_indices )
1747+ indices = tuple (
1748+ (
1749+ next (flat_indices_iterator )
1750+ if isinstance (entry , Type )
1751+ else slice (
1752+ None if entry .start is None else next (flat_indices_iterator ),
1753+ None if entry .stop is None else next (flat_indices_iterator ),
1754+ None if entry .step is None else next (flat_indices_iterator ),
17541755 )
1755- else :
1756- return entry
1756+ )
1757+ for entry in self .idx_list
1758+ )
17571759
1758- cdata = tuple (map (_convert , self .idx_list ))
1759- if len (cdata ) == 1 :
1760- cdata = cdata [0 ]
17611760 if not self .inplace :
17621761 x = x .copy ()
1763- sub_x = x .__getitem__ (cdata )
1764- if sub_x .shape :
1765- # we've sliced out an N-D tensor with N > 0
1766- if not self .set_instead_of_inc :
1767- sub_x += y
1768- else :
1769- # sub_x += -sub_x + y
1770- x .__setitem__ (cdata , y )
1762+ if self .set_instead_of_inc :
1763+ x [indices ] = y
17711764 else :
1772- # scalar case
1773- if not self .set_instead_of_inc :
1774- x .__setitem__ (cdata , sub_x + y )
1775- else :
1776- x .__setitem__ (cdata , y )
1777- out [0 ] = x
1765+ x [indices ] += y
1766+ output_storage [0 ][0 ] = x
17781767
17791768 def c_code (self , node , name , inputs , outputs , sub ):
17801769 # This method delegates much of the work to helper
0 commit comments