@@ -58,15 +58,13 @@ def _scaled_grouped_mm(
     """
     # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging.
     if scaling_type == MoEScalingType.FP8_ROWWISE:
-        logger.debug("Using fp8 rowwise for _scaled_grouped_mm")
         return _Float8GroupedMM.apply(
             A,
             B_t,
             offs,
             out_dtype,
         )
     elif scaling_type == MoEScalingType.MXFP8:
-        logger.debug("Using mxfp8 for _scaled_grouped_mm")
         block_size = 32  # TODO: should we make this configurable? plumb it through in a config somehow?
         return _MXFP8GroupedMM.apply(
             A,
@@ -358,13 +356,17 @@ def backward(ctx, grad_out: torch.Tensor):

         # B_data shape: (E, K, N)
         # B_scale shape: (E, K, N//block_size)
-        B_scales, B_data = to_mx(
+        B_scales_ref, B_data_ref = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             B_t.contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )

+        # Experiment with the CUDA dim1 cast kernel (to_mx results above kept as a reference).
+        B = B_t.transpose(-2, -1)
+        B_scales, B_data = _to_mxfp8_dim1_3d(B, block_size=block_size)
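+        # B_data has shape (E, N, K) and B_scales has shape (E, K, N//block_size),
+        # so B_data is passed to the grad_A grouped mm below without a transpose.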
+
         # Convert scales to blocked format for 2d-3d grouped mm
         grad_out_scales_blocked = triton_mx_block_rearrange_2d_M_groups(
             grad_out_scale,
@@ -376,21 +378,26 @@ def backward(ctx, grad_out: torch.Tensor):
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
         grad_A = torch._scaled_grouped_mm(
             grad_out_data,
-            B_data.transpose(-2, -1),
+            B_data,
             grad_out_scales_blocked,
             B_scales_blocked,
             offs=offs,
             out_dtype=out_dtype,
         )

-        # grad_out_t_data shape: (N, M)
+        # grad_out_t_data shape: (M, N)
         # grad_out_t_scales shape: (N, M//block_size)
-        grad_out_t_scales, grad_out_t_data = to_mx(
-            # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
-            grad_out.transpose(-2, -1).contiguous(),
+        grad_out_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+            grad_out,
+            block_size,
             elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
+            hp_dtype=grad_out.dtype,
+            gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
+            cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+            scale_calculation_mode=ScaleCalculationMode.FLOOR,
         )
+        grad_out_t_data = grad_out_t_mx.qdata
+        grad_out_t_scales = grad_out_t_mx._scale_e8m0
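+        # Scale layout matches the to_mx path being replaced: one e8m0 scale per
+        # block_size elements along M, for each of the N rows.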
 
         # Transpose A so we can scale along the M dimension, then un-transpose.
         # A_t_data shape: (K, M)
@@ -412,7 +419,6 @@ def backward(ctx, grad_out: torch.Tensor):
         _, blocked_scale_group_offsets = compute_blocked_scale_offsets_for_K_groups(
             scale_group_offsets
         )
-
         grad_out_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
             grad_out_t_scales,
             scale_group_offsets,
@@ -438,6 +444,40 @@ def backward(ctx, grad_out: torch.Tensor):
         return grad_A, grad_B_t, None, None, None


+def _to_mxfp8_dim1_3d(
+    B: torch.Tensor,
+    block_size: int = 32,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Convert a 3D tensor to MXFP8 format with (block_size, 1) scaling granularity.
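+    Returns (B_scales, B_data), where B_scales has shape (E, K, N//block_size) and
+    B_data has shape (E, N, K) with a column-major layout in its last two dims.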
+    """
+    E, N, K = B.shape
+    B_reshaped = B.reshape(E * N, K)
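+    # Flatten the expert dim so a single 2D dim1 cast covers all E expert (N, K) matrices.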
+    B_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+        B_reshaped,
+        block_size,
+        elem_dtype=torch.float8_e4m3fn,
+        hp_dtype=B_reshaped.dtype,
+        gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
+        cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+        scale_calculation_mode=ScaleCalculationMode.FLOOR,
+    )
+    B_data = B_t_mx.qdata.t()  # (K, E*N) -> (E*N, K)
+    B_data = B_data.reshape(E, N, K)  # (E*N, K) -> (E, N, K)
+    B_scales = B_t_mx._scale_e8m0.view(torch.uint8)  # (K, E*N//block_size)
+    B_scales = B_scales.reshape(
+        K, E, N // block_size
+    )  # (K, E*N//block_size) -> (K, E, N//block_size)
+    B_scales = B_scales.permute(
+        1, 0, 2
+    )  # (K, E, N//block_size) -> (E, K, N//block_size)
+    B_scales = B_scales.view(torch.float8_e8m0fnu)
+
+    # TODO: Update cutlass grouped gemm to accept NT/TN/NN/TT layouts so we can avoid this conversion to column major
+    B_data = B_data.transpose(-2, -1).contiguous().transpose(-2, -1)
+    return B_scales, B_data
+
+
 def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
     A_data: torch.Tensor,
     A_scale: torch.Tensor,
@@ -606,3 +646,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
     # Perform bf16 grouped GEMM using dequantized A and B.
     out = torch._grouped_mm(A, B, offs=offs, out_dtype=out_dtype)
     return out
+
+
+def round_up(x, y):
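+    # Round x up to the nearest multiple of y, e.g. round_up(10, 32) == 32.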
+    return ((x + y - 1) // y) * y