@@ -189,30 +189,39 @@ def allocate_block_size(
         return idx
 
     def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInfo:
-        # Check if this size is already a registered block size
-        if isinstance(size, torch.SymInt):
-            from .host_function import HostFunction
-
-            expr = size._sympy_()
-            origin_info = HostFunction.current().expr_to_origin.get(expr)
-            if origin_info and isinstance(origin_info.origin, BlockSizeOrigin):
-                block_idx = origin_info.origin.block_id
-                # Return the existing block size if it's a reduction dimension
-                if self.block_sizes[block_idx].reduction:
-                    return self.block_sizes[block_idx]
+        # Quick return for existing reduction with same size
+        for info in self.block_sizes:
+            if info.reduction and info.size == size:
+                return info
 
-        # Check for existing reduction dimensions with the same size
-        for rdim in self.block_sizes:
-            if rdim.reduction and rdim.size == size:
-                return rdim
+        # For SymInt, check if we can reuse existing block for symbolic equality
+        if isinstance(size, torch.SymInt):
+            sym = size._sympy_()
+            block_id = self.get_block_id(size)
+
+            # Return existing reduction by block_id
+            if block_id is not None and self.block_sizes[block_id].reduction:
+                return self.block_sizes[block_id]
+
+            # Clone non-reduction block as reduction for symbolic equality
+            for idx, info in enumerate(self.block_sizes):
+                if not info.reduction and (idx == block_id or sym == info.symbol()):
+                    reduction_loop_index = sum(int(b.reduction) for b in self.block_sizes)
+                    rdim_idx = self.allocate_block_size(
+                        size,
+                        reduction=True,
+                        source=ReductionLoopBlockSizeSource(reduction_loop_index),
+                        hint=next_power_of_2(self.size_hint(size)),
+                    )
+                    self.block_sizes[rdim_idx].var = info.var
+                    return self.block_sizes[rdim_idx]
 
-        # Allocate a new reduction dimension
+        # Allocate new reduction dimension
+        reduction_loop_index = sum(int(info.reduction) for info in self.block_sizes)
         rdim_idx = self.allocate_block_size(
             size,
             reduction=True,
-            source=ReductionLoopBlockSizeSource(
-                sum([int(bs.reduction) for bs in self.block_sizes])
-            ),
+            source=ReductionLoopBlockSizeSource(reduction_loop_index),
             hint=next_power_of_2(self.size_hint(size)),
         )
         return self.block_sizes[rdim_idx]
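In the rewritten `allocate_reduction_dimension`, reuse is attempted before anything new is created: first an exact match against an existing reduction block, then (for `torch.SymInt` sizes) a match by block id or symbolic equality, where a non-reduction block is cloned into a reduction block that shares its `var`; only when all of that fails is a fresh reduction dimension allocated. Below is a minimal standalone sketch of that reuse-before-allocate ordering; `Block` and `allocate_reduction` are simplified stand-ins invented for illustration, not Helion's actual `BlockSizeInfo` or allocation API, and the symbolic-reuse step is omitted.

from dataclasses import dataclass

@dataclass
class Block:
    size: int
    reduction: bool
    var: str

def allocate_reduction(blocks: list[Block], size: int) -> Block:
    # Reuse an existing reduction block with the same size before allocating.
    for info in blocks:
        if info.reduction and info.size == size:
            return info
    # Fall through: allocate a fresh reduction block.
    block = Block(size=size, reduction=True, var=f"rdim{sum(b.reduction for b in blocks)}")
    blocks.append(block)
    return block

blocks = [Block(64, False, "block0")]
first = allocate_reduction(blocks, 128)
second = allocate_reduction(blocks, 128)
assert first is second   # the second request reuses the first reduction block
assert len(blocks) == 2  # only one reduction block was ever allocated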
@@ -269,6 +278,84 @@ def cached_create_unbacked_symint(
             self._symint_cache[key] = result
         return result
 
+
+    def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None:
+        """Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance."""
+        tensor._tile_index_block_id = block_id  # type: ignore[attr-defined]
+
+    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
+        """Return the originating ``tile.index`` block id if present."""
+        return getattr(tensor, "_tile_index_block_id", None)
+
+    def get_indexer_output_dims(
+        self,
+        indexer_tensor: torch.Tensor,
+        base_dim_size: int | torch.SymInt | None,
+    ) -> list[int | torch.SymInt]:
+        """Map a tensor indexer's shape to the output dimensions for advanced indexing."""
+        dims = list(indexer_tensor.size())
+        non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]
+
+        # Multi-dimensional indexer - return full shape
+        if len(non_broadcast_dims) > 1:
+            return dims
+
+        # Try to find block_id from various sources
+        block_id = (
+            self.get_tile_index_tensor_block_id(indexer_tensor)
+            or (self.get_block_id(base_dim_size) if base_dim_size else None)
+            or (self.get_block_id(non_broadcast_dims[0]) if non_broadcast_dims else None)
+        )
+
+        return [self.block_sizes[block_id].var] if block_id else (non_broadcast_dims or [1])
+
+    def tensor_indexer_broadcast_shape(
+        self, tensors: typing.Sequence[torch.Tensor]
+    ) -> list[int | torch.SymInt] | None:
+        """Compute a shared broadcast shape for tensor indexers when needed."""
+        tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)]
+        if not tensor_list or all(self.get_tile_index_tensor_block_id(t) for t in tensor_list):
+            return None
+
+        shapes = [list(t.size()) for t in tensor_list]
+
+        # Inline compute_broadcast_shape_for_tensor_indexers logic
+        if not shapes:
+            return []
+
+        # Special case: multiple 1D tensors form a Cartesian product
+        if all(len(s) == 1 for s in shapes) and len(shapes) > 1:
+            return [s[0] for s in shapes]
+
+        # General broadcasting case
+        max_ndim = max(len(s) for s in shapes)
+        padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
+        return [
+            next((d for d in dims if self.size_hint(d) != 1), 1)
+            for dims in zip(*padded, strict=True)
+        ]
+
+    def new_index_result(
+        self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> torch.Tensor:
+        """Create a new tensor for indexing/view ops while preserving tile index provenance."""
+        # Inline resolve_tile_index_shape logic
+        block_id = self.get_tile_index_tensor_block_id(tensor)
+        if not block_id:
+            resolved_shape = list(output_shape)
+        else:
+            resolved_shape = list(output_shape)
+            non_broadcast = [i for i, s in enumerate(resolved_shape) if self.size_hint(s) != 1]
+            if len(non_broadcast) == 1:
+                resolved_shape[non_broadcast[0]] = self.block_sizes[block_id].var
+            elif len(non_broadcast) > 1:
+                block_id = None
+
+        result = tensor.new_empty(resolved_shape)
+        if block_id is not None:
+            self.register_tile_index_tensor_block_id(result, block_id)
+        return result
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if obj is None:
             return None
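The helpers added above carry `tile.index` provenance as a plain attribute: `register_tile_index_tensor_block_id` stores the block id on the tensor under `_tile_index_block_id`, and `get_tile_index_tensor_block_id` reads it back with `getattr`, so untagged tensors simply report `None`. `get_indexer_output_dims` and `new_index_result` then use that id to swap a single non-broadcast dimension for the block's symbolic extent, while `tensor_indexer_broadcast_shape` treats several 1-D indexers as a Cartesian product and otherwise falls back to ordinary right-aligned broadcasting. A stripped-down sketch of the tagging pattern, using ordinary `torch.Tensor` objects rather than Helion's fake tensors, with `tag_tile_index` and `tile_index_block_id` as illustrative stand-ins for the two new methods:

import torch

def tag_tile_index(tensor: torch.Tensor, block_id: int) -> None:
    # Provenance travels as an ad-hoc attribute on the tensor object itself.
    tensor._tile_index_block_id = block_id

def tile_index_block_id(tensor: torch.Tensor) -> int | None:
    # Tensors that were never tagged fall back to None.
    return getattr(tensor, "_tile_index_block_id", None)

index = torch.arange(8)
tag_tile_index(index, 2)
assert tile_index_block_id(index) == 2
assert tile_index_block_id(torch.zeros(4)) is None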
@@ -351,6 +438,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
             self.fake_mode, tensor, shape_env=self.shape_env, source=source
         )
         self.input_sources[result] = source
+        if hasattr(tensor, "_tile_index_block_id"):
+            self.register_tile_index_tensor_block_id(
+                result, typing.cast(int, getattr(tensor, "_tile_index_block_id"))
+            )
         if isinstance(source, LocalSource):
             for i, s in enumerate(result.size()):
                 if isinstance(s, torch.SymInt) and isinstance(
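With the `_to_fake_tensor` change, the same attribute survives the conversion from a real input tensor to its fake counterpart: if the incoming tensor carries `_tile_index_block_id`, the id is re-registered on the fake result. A rough sketch of that conditional hand-off, assuming plain tensors in place of the fake-tensor machinery and a hypothetical `propagate_tile_index` helper:

import torch

def propagate_tile_index(src: torch.Tensor, dst: torch.Tensor) -> None:
    # Copy the provenance tag only when the source tensor actually has one.
    if hasattr(src, "_tile_index_block_id"):
        dst._tile_index_block_id = src._tile_index_block_id

real = torch.arange(4)
real._tile_index_block_id = 0
fake = torch.empty_like(real)  # stand-in for the FakeTensor produced during tracing
propagate_tile_index(real, fake)
assert getattr(fake, "_tile_index_block_id", None) == 0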