
Commit eb7cbac

wip
1 parent 77f08d9 commit eb7cbac

File tree: 7 files changed, +712 -234 lines


helion/_compiler/compile_environment.py

Lines changed: 110 additions & 19 deletions
@@ -189,30 +189,39 @@ def allocate_block_size(
         return idx
 
     def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInfo:
-        # Check if this size is already a registered block size
-        if isinstance(size, torch.SymInt):
-            from .host_function import HostFunction
-
-            expr = size._sympy_()
-            origin_info = HostFunction.current().expr_to_origin.get(expr)
-            if origin_info and isinstance(origin_info.origin, BlockSizeOrigin):
-                block_idx = origin_info.origin.block_id
-                # Return the existing block size if it's a reduction dimension
-                if self.block_sizes[block_idx].reduction:
-                    return self.block_sizes[block_idx]
+        # Quick return for existing reduction with same size
+        for info in self.block_sizes:
+            if info.reduction and info.size == size:
+                return info
 
-        # Check for existing reduction dimensions with the same size
-        for rdim in self.block_sizes:
-            if rdim.reduction and rdim.size == size:
-                return rdim
+        # For SymInt, check if we can reuse existing block for symbolic equality
+        if isinstance(size, torch.SymInt):
+            sym = size._sympy_()
+            block_id = self.get_block_id(size)
+
+            # Return existing reduction by block_id
+            if block_id is not None and self.block_sizes[block_id].reduction:
+                return self.block_sizes[block_id]
+
+            # Clone non-reduction block as reduction for symbolic equality
+            for idx, info in enumerate(self.block_sizes):
+                if not info.reduction and (idx == block_id or sym == info.symbol()):
+                    reduction_loop_index = sum(int(b.reduction) for b in self.block_sizes)
+                    rdim_idx = self.allocate_block_size(
+                        size,
+                        reduction=True,
+                        source=ReductionLoopBlockSizeSource(reduction_loop_index),
+                        hint=next_power_of_2(self.size_hint(size)),
+                    )
+                    self.block_sizes[rdim_idx].var = info.var
+                    return self.block_sizes[rdim_idx]
 
-        # Allocate a new reduction dimension
+        # Allocate new reduction dimension
+        reduction_loop_index = sum(int(info.reduction) for info in self.block_sizes)
         rdim_idx = self.allocate_block_size(
             size,
             reduction=True,
-            source=ReductionLoopBlockSizeSource(
-                sum([int(bs.reduction) for bs in self.block_sizes])
-            ),
+            source=ReductionLoopBlockSizeSource(reduction_loop_index),
             hint=next_power_of_2(self.size_hint(size)),
         )
         return self.block_sizes[rdim_idx]
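Note on the new lookup order in allocate_reduction_dimension, shown as a minimal sketch with a toy dataclass and a hypothetical helper (not the actual Helion types): an exact-size reduction match is reused first, SymInt sizes then fall back to the symbolic-equality path, and only then is a new reduction block allocated.

from dataclasses import dataclass

@dataclass
class ToyBlockInfo:
    size: int
    reduction: bool

def allocate_reduction(block_sizes: list[ToyBlockInfo], size: int) -> ToyBlockInfo:
    # 1. Quick return: an existing reduction block with the same size is reused.
    for info in block_sizes:
        if info.reduction and info.size == size:
            return info
    # 2. Otherwise allocate a fresh reduction block (the SymInt/symbolic-equality
    #    cloning path of the real implementation is omitted in this sketch).
    info = ToyBlockInfo(size=size, reduction=True)
    block_sizes.append(info)
    return info

blocks = [ToyBlockInfo(size=64, reduction=False)]
a = allocate_reduction(blocks, 64)
b = allocate_reduction(blocks, 64)
assert a is b and len(blocks) == 2  # second call reuses the first reduction block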
@@ -269,6 +278,84 @@ def cached_create_unbacked_symint(
         self._symint_cache[key] = result
         return result
 
+
+    def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None:
+        """Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance."""
+        tensor._tile_index_block_id = block_id  # type: ignore[attr-defined]
+
+    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
+        """Return the originating ``tile.index`` block id if present."""
+        return getattr(tensor, "_tile_index_block_id", None)
+
+    def get_indexer_output_dims(
+        self,
+        indexer_tensor: torch.Tensor,
+        base_dim_size: int | torch.SymInt | None,
+    ) -> list[int | torch.SymInt]:
+        """Map a tensor indexer's shape to the output dimensions for advanced indexing."""
+        dims = list(indexer_tensor.size())
+        non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]
+
+        # Multi-dimensional indexer - return full shape
+        if len(non_broadcast_dims) > 1:
+            return dims
+
+        # Try to find block_id from various sources
+        block_id = (
+            self.get_tile_index_tensor_block_id(indexer_tensor)
+            or (self.get_block_id(base_dim_size) if base_dim_size else None)
+            or (self.get_block_id(non_broadcast_dims[0]) if non_broadcast_dims else None)
+        )
+
+        return [self.block_sizes[block_id].var] if block_id else (non_broadcast_dims or [1])
+
+    def tensor_indexer_broadcast_shape(
+        self, tensors: typing.Sequence[torch.Tensor]
+    ) -> list[int | torch.SymInt] | None:
+        """Compute a shared broadcast shape for tensor indexers when needed."""
+        tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)]
+        if not tensor_list or all(self.get_tile_index_tensor_block_id(t) for t in tensor_list):
+            return None
+
+        shapes = [list(t.size()) for t in tensor_list]
+
+        # Inline compute_broadcast_shape_for_tensor_indexers logic
+        if not shapes:
+            return []
+
+        # Special case: multiple 1D tensors form a Cartesian product
+        if all(len(s) == 1 for s in shapes) and len(shapes) > 1:
+            return [s[0] for s in shapes]
+
+        # General broadcasting case
+        max_ndim = max(len(s) for s in shapes)
+        padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
+        return [
+            next((d for d in dims if self.size_hint(d) != 1), 1)
+            for dims in zip(*padded, strict=True)
+        ]
+
+    def new_index_result(
+        self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> torch.Tensor:
+        """Create a new tensor for indexing/view ops while preserving tile index provenance."""
+        # Inline resolve_tile_index_shape logic
+        block_id = self.get_tile_index_tensor_block_id(tensor)
+        if not block_id:
+            resolved_shape = list(output_shape)
+        else:
+            resolved_shape = list(output_shape)
+            non_broadcast = [i for i, s in enumerate(resolved_shape) if self.size_hint(s) != 1]
+            if len(non_broadcast) == 1:
+                resolved_shape[non_broadcast[0]] = self.block_sizes[block_id].var
+            elif len(non_broadcast) > 1:
+                block_id = None
+
+        result = tensor.new_empty(resolved_shape)
+        if block_id is not None:
+            self.register_tile_index_tensor_block_id(result, block_id)
+        return result
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if obj is None:
             return None
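The broadcast rule in tensor_indexer_broadcast_shape can be summarized with a small standalone sketch; it assumes concrete int sizes so that size_hint is effectively the identity, which is not how the symbolic path behaves:

def broadcast_shape(shapes: list[list[int]]) -> list[int]:
    if not shapes:
        return []
    # Multiple 1D indexers: treat as a Cartesian product, one output dim per indexer.
    if len(shapes) > 1 and all(len(s) == 1 for s in shapes):
        return [s[0] for s in shapes]
    # General case: right-align, pad with 1s, take the first non-broadcast size per dim.
    max_ndim = max(len(s) for s in shapes)
    padded = [[1] * (max_ndim - len(s)) + s for s in shapes]
    return [next((d for d in dims if d != 1), 1) for dims in zip(*padded)]

assert broadcast_shape([[8], [16]]) == [8, 16]        # Cartesian product of two 1D indexers
assert broadcast_shape([[1, 16], [8, 1]]) == [8, 16]  # ordinary broadcasting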
@@ -351,6 +438,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
             self.fake_mode, tensor, shape_env=self.shape_env, source=source
         )
         self.input_sources[result] = source
+        if hasattr(tensor, "_tile_index_block_id"):
+            self.register_tile_index_tensor_block_id(
+                result, typing.cast(int, getattr(tensor, "_tile_index_block_id"))
+            )
         if isinstance(source, LocalSource):
             for i, s in enumerate(result.size()):
                 if isinstance(s, torch.SymInt) and isinstance(
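The _to_fake_tensor change simply carries the _tile_index_block_id tag across fakeification. A rough sketch of the hand-off, using plain tensors in place of the fake-tensor machinery:

import torch

src = torch.empty(4)
src._tile_index_block_id = 2   # tag set earlier by register_tile_index_tensor_block_id

fake = torch.empty_like(src)   # stand-in for the fakeified tensor
if hasattr(src, "_tile_index_block_id"):
    # mirror of the new hasattr check added in _to_fake_tensor
    fake._tile_index_block_id = src._tile_index_block_id

assert getattr(fake, "_tile_index_block_id", None) == 2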
