|
18 | 18 | from .host_function import HostFunction |
19 | 19 | from .tile_strategy import DeviceLoopState |
20 | 20 | from .utils import compute_slice_size |
| 21 | +from .utils import get_slice_start |
21 | 22 | from .variable_origin import BlockSizeOrigin |
22 | 23 |
|
23 | 24 | if TYPE_CHECKING: |
@@ -126,6 +127,30 @@ def _handle_remaining_index_dimensions( |
126 | 127 | return output_idx |
127 | 128 |
|
128 | 129 |
|
| 130 | +def _generate_slice_index( |
| 131 | + start: int | torch.SymInt, |
| 132 | + index_var: str, |
| 133 | + expand: str, |
| 134 | + step: int | None = None, |
| 135 | +) -> str: |
| 136 | + """Generate slice index expression with optional step.""" |
| 137 | + if step is not None: |
| 138 | + # Strided index: start + index * step |
| 139 | + return f"({start} + ({index_var}) * {step}){expand}" |
| 140 | + if start != 0: |
| 141 | + # Index with offset: start + index |
| 142 | + return f"({start} + ({index_var})){expand}" |
| 143 | + # Simple index |
| 144 | + return f"({index_var}){expand}" |
| 145 | + |
| 146 | + |
| 147 | +def _generate_offset_expr(start: int | torch.SymInt, offset: str) -> str: |
| 148 | + """Generate offset expression with optional start.""" |
| 149 | + if start != 0: |
| 150 | + return f"({start} + {offset})" |
| 151 | + return offset |
| 152 | + |
| 153 | + |
129 | 154 | class IndexingStrategy: |
130 | 155 | def codegen_load( |
131 | 156 | self, |
@@ -627,7 +652,6 @@ def compute_shape( |
627 | 652 | size = input_size.popleft() |
628 | 653 | # Handle slices with steps |
629 | 654 | slice_size = compute_slice_size(k, size) |
630 | | - |
631 | 655 | if slice_size != 1: |
632 | 656 | rdim = env.allocate_reduction_dimension(slice_size) |
633 | 657 | output_size.append(rdim.var) |
@@ -719,25 +743,29 @@ def create( |
719 | 743 | rdim = env.allocate_reduction_dimension(slice_size) |
720 | 744 | block_idx = rdim.block_id |
721 | 745 | index_var = state.codegen.index_var(block_idx) |
722 | | - # Generate strided index: start + index * step |
723 | 746 | index_values.append( |
724 | | - f"({start} + ({index_var}) * {step}){expand}" |
| 747 | + _generate_slice_index(start, index_var, expand, step) |
725 | 748 | ) |
726 | 749 | if mask := state.codegen.mask_var(block_idx): |
727 | 750 | mask_values.setdefault(f"({mask}){expand}") |
728 | 751 | else: |
729 | 752 | index_values.append(f"{start}{expand}") |
730 | 753 | else: |
731 | | - # Full slice or slice without step |
732 | | - if size != 1: |
733 | | - rdim = env.allocate_reduction_dimension(size) |
| 754 | + # Handle slices with start/stop but no step |
| 755 | + start = get_slice_start(k) |
| 756 | + slice_size = compute_slice_size(k, size) |
| 757 | + |
| 758 | + if slice_size != 1: |
| 759 | + rdim = env.allocate_reduction_dimension(slice_size) |
734 | 760 | block_idx = rdim.block_id |
735 | 761 | index_var = state.codegen.index_var(block_idx) |
736 | | - index_values.append(f"({index_var}){expand}") |
| 762 | + index_values.append( |
| 763 | + _generate_slice_index(start, index_var, expand) |
| 764 | + ) |
737 | 765 | if mask := state.codegen.mask_var(block_idx): |
738 | 766 | mask_values.setdefault(f"({mask}){expand}") |
739 | 767 | else: |
740 | | - index_values.append(f"tl.zeros([1], {dtype}){expand}") |
| 768 | + index_values.append(f"{start}{expand}") |
741 | 769 | output_idx += 1 |
742 | 770 | elif isinstance(k, torch.Tensor) and k.ndim == 1: |
743 | 771 | expand = tile_strategy.expand_str(output_size, output_idx) |
@@ -1025,8 +1053,19 @@ def create( |
1025 | 1053 | res.offsets.append(state.codegen.offset_var(rdim.block_id)) |
1026 | 1054 | res.block_shape.append(rdim.var) |
1027 | 1055 | else: |
1028 | | - res.offsets.append("0") |
1029 | | - res.block_shape.append(1) |
| 1056 | + # Handle slices with start/stop but no step |
| 1057 | + start = get_slice_start(k) |
| 1058 | + slice_size = compute_slice_size(k, size) |
| 1059 | + |
| 1060 | + if slice_size != 1: |
| 1061 | + env = CompileEnvironment.current() |
| 1062 | + rdim = env.allocate_reduction_dimension(slice_size) |
| 1063 | + offset = state.codegen.offset_var(rdim.block_id) |
| 1064 | + res.offsets.append(_generate_offset_expr(start, offset)) |
| 1065 | + res.block_shape.append(rdim.var) |
| 1066 | + else: |
| 1067 | + res.offsets.append(str(start)) |
| 1068 | + res.block_shape.append(1) |
1030 | 1069 | else: |
1031 | 1070 | raise exc.InvalidIndexingType(k) |
1032 | 1071 | res.validate() |
|
0 commit comments