@@ -189,30 +189,39 @@ def allocate_block_size(
         return idx
 
     def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInfo:
-        # Check if this size is already a registered block size
-        if isinstance(size, torch.SymInt):
-            from .host_function import HostFunction
-
-            expr = size._sympy_()
-            origin_info = HostFunction.current().expr_to_origin.get(expr)
-            if origin_info and isinstance(origin_info.origin, BlockSizeOrigin):
-                block_idx = origin_info.origin.block_id
-                # Return the existing block size if it's a reduction dimension
-                if self.block_sizes[block_idx].reduction:
-                    return self.block_sizes[block_idx]
+        # Quick return for existing reduction with same size
+        for info in self.block_sizes:
+            if info.reduction and info.size == size:
+                return info
 
-        # Check for existing reduction dimensions with the same size
-        for rdim in self.block_sizes:
-            if rdim.reduction and rdim.size == size:
-                return rdim
+        # For SymInt, check if we can reuse existing block for symbolic equality
+        if isinstance(size, torch.SymInt):
+            sym = size._sympy_()
+            block_id = self.get_block_id(size)
+
+            # Return existing reduction by block_id
+            if block_id is not None and self.block_sizes[block_id].reduction:
+                return self.block_sizes[block_id]
+
+            # Clone non-reduction block as reduction for symbolic equality
+            for idx, info in enumerate(self.block_sizes):
+                if not info.reduction and (idx == block_id or sym == info.symbol()):
+                    reduction_loop_index = sum(int(b.reduction) for b in self.block_sizes)
+                    rdim_idx = self.allocate_block_size(
+                        size,
+                        reduction=True,
+                        source=ReductionLoopBlockSizeSource(reduction_loop_index),
+                        hint=next_power_of_2(self.size_hint(size)),
+                    )
+                    self.block_sizes[rdim_idx].var = info.var
+                    return self.block_sizes[rdim_idx]
 
-        # Allocate a new reduction dimension
+        # Allocate a new reduction dimension
+        reduction_loop_index = sum(int(info.reduction) for info in self.block_sizes)
         rdim_idx = self.allocate_block_size(
             size,
             reduction=True,
-            source=ReductionLoopBlockSizeSource(
-                sum([int(bs.reduction) for bs in self.block_sizes])
-            ),
+            source=ReductionLoopBlockSizeSource(reduction_loop_index),
             hint=next_power_of_2(self.size_hint(size)),
         )
         return self.block_sizes[rdim_idx]
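A minimal usage sketch of the reuse path above, assuming an active CompileEnvironment `env` and a host-side `torch.SymInt` size `n` that already names a non-reduction (tile) block; the names `env`, `n`, and `tile_id` are illustrative, not part of the patch:

# Hedged sketch: `env`, `n`, and `tile_id` are assumed to exist as described above.
tile_id = env.get_block_id(n)

rdim = env.allocate_reduction_dimension(n)  # clones the tile block as a reduction dim
assert rdim.reduction
if tile_id is not None:
    assert rdim.var is env.block_sizes[tile_id].var  # shares the tile block's variable

# A second request for the same size hits the quick-return loop.
assert env.allocate_reduction_dimension(n) is rdim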
@@ -269,6 +278,73 @@ def cached_create_unbacked_symint(
         self._symint_cache[key] = result
         return result
 
+
+    def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None:
+        """Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance."""
+        tensor._tile_index_block_id = block_id  # type: ignore[attr-defined]
+
+    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
+        """Return the originating ``tile.index`` block id if present."""
+        return getattr(tensor, "_tile_index_block_id", None)
+
+    def get_indexer_output_dims(
+        self,
+        indexer_tensor: torch.Tensor,
+        base_dim_size: int | torch.SymInt | None,
+    ) -> list[int | torch.SymInt]:
+        """Map a tensor indexer's shape to the output dimensions for advanced indexing."""
+        dims = list(indexer_tensor.size())
+        non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]
+
+        # Multi-dimensional indexer - return full shape
+        if len(non_broadcast_dims) > 1:
+            return dims
+
+        # Try to find block_id from various sources
+        block_id = (
+            self.get_tile_index_tensor_block_id(indexer_tensor)
+            or (self.get_block_id(base_dim_size) if base_dim_size else None)
+            or (self.get_block_id(non_broadcast_dims[0]) if non_broadcast_dims else None)
+        )
+
+        return [self.block_sizes[block_id].var] if block_id else (non_broadcast_dims or [1])
+
+    def tensor_indexer_broadcast_shape(
+        self, tensors: typing.Sequence[torch.Tensor]
+    ) -> list[int | torch.SymInt] | None:
+        """Compute a shared broadcast shape for tensor indexers when needed."""
+        tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)]
+        if not tensor_list or all(self.get_tile_index_tensor_block_id(t) for t in tensor_list):
+            return None
+
+        shapes = [list(t.size()) for t in tensor_list]
+        return compute_broadcast_shape_for_tensor_indexers(shapes, self)
+
+    def resolve_tile_index_shape(
+        self, input_tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> tuple[list[int | torch.SymInt], int | None]:
+        """Resolve the symbolic shape for tensors derived from ``tile.index``."""
+        block_id = self.get_tile_index_tensor_block_id(input_tensor)
+        if not block_id:
+            return list(output_shape), None
+
+        resolved = list(output_shape)
+        non_broadcast = [i for i, s in enumerate(resolved) if self.size_hint(s) != 1]
+        if len(non_broadcast) == 1:
+            resolved[non_broadcast[0]] = self.block_sizes[block_id].var
+            return resolved, block_id
+        return resolved, block_id if len(non_broadcast) == 0 else None
+
+    def new_index_result(
+        self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> torch.Tensor:
+        """Create a new tensor for indexing/view ops while preserving tile index provenance."""
+        resolved_shape, block_id = self.resolve_tile_index_shape(tensor, output_shape)
+        result = tensor.new_empty(resolved_shape)
+        if block_id is not None:
+            self.register_tile_index_tensor_block_id(result, block_id)
+        return result
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if obj is None:
             return None
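A minimal sketch of how the tile.index provenance helpers above are intended to compose, assuming an active CompileEnvironment `env`, a fake tensor `idx` produced for `tile.index`, and block id 2 (all three are illustrative assumptions):

# Hedged sketch: `env`, `idx`, and the block id are assumptions for illustration.
env.register_tile_index_tensor_block_id(idx, block_id=2)
assert env.get_tile_index_tensor_block_id(idx) == 2

# new_index_result keeps the provenance and swaps the single non-broadcast
# dimension for the block's symbolic size.
reshaped = env.new_index_result(idx, (1, idx.size(0)))
assert env.get_tile_index_tensor_block_id(reshaped) == 2

# Indexers that all carry tile provenance skip the shared-broadcast path.
assert env.tensor_indexer_broadcast_shape([idx]) is None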
@@ -351,6 +427,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
             self.fake_mode, tensor, shape_env=self.shape_env, source=source
         )
         self.input_sources[result] = source
+        if hasattr(tensor, "_tile_index_block_id"):
+            self.register_tile_index_tensor_block_id(
+                result, typing.cast(int, getattr(tensor, "_tile_index_block_id"))
+            )
         if isinstance(source, LocalSource):
             for i, s in enumerate(result.size()):
                 if isinstance(s, torch.SymInt) and isinstance(
@@ -643,6 +723,31 @@ def _has_unbacked(expr: sympy.Expr) -> bool:
643723 return any (n .name .startswith ("u" ) for n in expr .free_symbols ) # pyright: ignore[reportAttributeAccessIssue]
644724
645725
726+ def compute_broadcast_shape_for_tensor_indexers (
727+ shapes : list [list [int | torch .SymInt ]],
728+ env : "CompileEnvironment"
729+ ) -> list [int | torch .SymInt ]:
730+ """Compute broadcast shape for multiple tensor indexers using right-aligned broadcasting.
731+
732+ For multiple 1D tensors, this should return a shape that represents their Cartesian product.
733+ For example, two tensors of shape [8] and [8] should broadcast to shape [8, 8].
734+ """
735+ if not shapes :
736+ return []
737+
738+ # Special case: multiple 1D tensors form a Cartesian product
739+ if all (len (s ) == 1 for s in shapes ) and len (shapes ) > 1 :
740+ return [s [0 ] for s in shapes ]
741+
742+ # General broadcasting case
743+ max_ndim = max (len (s ) for s in shapes )
744+ padded = [([1 ] * (max_ndim - len (s )) + s ) for s in shapes ]
745+ return [
746+ next ((d for d in dims if env .size_hint (d ) != 1 ), 1 )
747+ for dims in zip (* padded , strict = True )
748+ ]
749+
750+
646751def format_shape (shape : tuple [object , ...]) -> str :
647752 def _format_dim (dim : object ) -> str :
648753 if isinstance (dim , torch .SymInt ):
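A small worked example of the broadcasting rules implemented above, assuming a CompileEnvironment instance `env` whose size_hint maps these concrete ints to themselves (`env` is an assumption for illustration):

# Hedged sketch: `env` is an assumed CompileEnvironment instance.
# Multiple 1D indexers -> Cartesian product of their lengths.
assert compute_broadcast_shape_for_tensor_indexers([[8], [8]], env) == [8, 8]

# Otherwise right-aligned broadcasting: shapes are left-padded with 1s and,
# per axis, the first non-size-1 entry wins (falling back to 1).
assert compute_broadcast_shape_for_tensor_indexers([[4, 1], [8]], env) == [4, 8]
assert compute_broadcast_shape_for_tensor_indexers([], env) == []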