@@ -227,6 +227,9 @@ def valid_block_size(
227227 for i , k in enumerate (subscript ):
228228 if k is None :
229229 continue
230+ if k is Ellipsis :
231+ # Ellipsis is not supported in tensor descriptor mode
232+ return False
230233 size , stride = size_stride .popleft ()
231234 if isinstance (k , slice ):
232235 # Slices with steps are not supported in tensor descriptor mode
@@ -447,6 +450,14 @@ def codegen_store(
447450 )
448451
449452
453+ def _calculate_ellipsis_dims (
454+ index : list [object ], current_index : int , total_dims : int
455+ ) -> int :
456+ """Calculate how many dimensions an ellipsis should expand to."""
457+ remaining_indices = len (index ) - current_index - 1
458+ return total_dims - current_index - remaining_indices
459+
460+
450461class SubscriptIndexing (NamedTuple ):
451462 index_expr : ast .AST
452463 mask_expr : ast .AST
@@ -465,9 +476,19 @@ def compute_shape(
465476 input_size = collections .deque (tensor .size ())
466477 output_size = []
467478 env = CompileEnvironment .current ()
468- for k in index :
479+ for i , k in enumerate ( index ) :
469480 if k is None :
470481 output_size .append (1 )
482+ elif k is Ellipsis :
483+ # Ellipsis expands to consume all remaining dims except those after it
484+ ellipsis_dims = _calculate_ellipsis_dims (index , i , len (tensor .size ()))
485+ for _ in range (ellipsis_dims ):
486+ size = input_size .popleft ()
487+ if size != 1 :
488+ rdim = env .allocate_reduction_dimension (size )
489+ output_size .append (rdim .var )
490+ else :
491+ output_size .append (1 )
471492 elif isinstance (k , int ):
472493 input_size .popleft ()
473494 elif isinstance (k , torch .SymInt ):
@@ -517,6 +538,22 @@ def create(
517538 for n , k in enumerate (index ):
518539 if k is None :
519540 output_idx += 1
541+ elif k is Ellipsis :
542+ # Ellipsis expands to handle remaining dimensions
543+ ellipsis_dims = _calculate_ellipsis_dims (index , n , fake_value .ndim )
544+ for _ in range (ellipsis_dims ):
545+ expand = tile_strategy .expand_str (output_size , output_idx )
546+ size = fake_value .size (len (index_values ))
547+ if size != 1 :
548+ rdim = env .allocate_reduction_dimension (size )
549+ block_idx = rdim .block_id
550+ index_var = state .codegen .index_var (block_idx )
551+ index_values .append (f"({ index_var } ){ expand } " )
552+ if mask := state .codegen .mask_var (block_idx ):
553+ mask_values .setdefault (f"({ mask } ){ expand } " )
554+ else :
555+ index_values .append (f"tl.zeros([1], { dtype } ){ expand } " )
556+ output_idx += 1
520557 elif isinstance (k , int ):
521558 index_values .append (repr (k ))
522559 elif isinstance (k , torch .SymInt ):
@@ -729,8 +766,17 @@ def is_supported(
729766 # TODO(jansel): support block_ptr with extra_mask
730767 return False
731768 input_sizes = collections .deque (fake_tensor .size ())
732- for k in index :
733- input_size = 1 if k is None else input_sizes .popleft ()
769+ for n , k in enumerate (index ):
770+ if k is None :
771+ input_size = 1
772+ elif k is Ellipsis :
773+ # Skip appropriate number of dimensions for ellipsis
774+ ellipsis_dims = _calculate_ellipsis_dims (index , n , fake_tensor .ndim )
775+ for _ in range (ellipsis_dims ):
776+ input_sizes .popleft ()
777+ continue
778+ else :
779+ input_size = input_sizes .popleft ()
734780 if isinstance (k , torch .SymInt ):
735781 symbol = k ._sympy_ ()
736782 origin = None
@@ -780,9 +826,22 @@ def create(
780826 fake_value ,
781827 reshaped_size = SubscriptIndexing .compute_shape (fake_value , index ),
782828 )
783- for k in index :
829+ for n , k in enumerate ( index ) :
784830 if k is None :
785831 pass # handled by reshaped_size
832+ elif k is Ellipsis :
833+ # Ellipsis expands to handle remaining dimensions
834+ ellipsis_dims = _calculate_ellipsis_dims (index , n , fake_value .ndim )
835+ env = CompileEnvironment .current ()
836+ for _ in range (ellipsis_dims ):
837+ size = fake_value .size (len (res .offsets ))
838+ if size != 1 :
839+ rdim = env .allocate_reduction_dimension (size )
840+ res .offsets .append (state .codegen .offset_var (rdim .block_id ))
841+ res .block_shape .append (rdim .var )
842+ else :
843+ res .offsets .append ("0" )
844+ res .block_shape .append (1 )
786845 elif isinstance (k , int ):
787846 res .offsets .append (repr (k ))
788847 res .block_shape .append (1 )
0 commit comments