From ea4017692fd6eb82aa8985b1d279deee5e764023 Mon Sep 17 00:00:00 2001
From: Tianren Gao
Date: Thu, 16 Oct 2025 17:29:38 -0700
Subject: [PATCH] fix partitionK

Fix the partition-K GEMM kernels and make the partition count configurable:

- _matmul_partition_k: compute each program's starting K offset from
  pid_pk * PK_SIZE (one K slice per program), advance the A/B pointers by
  BLOCK_SIZE_K per iteration rather than PK_SIZE, and mask loads against K
  instead of loading unconditionally.
- _reduce: derive pid_m from num_pid_n, not num_pid_m.
- _matmul_partition_k_impl: accept an optional partition_k argument, default
  to 32 partitions (8 for small K), and shrink partitionK until it divides K.
---
 tritonbench/operators/gemm/partition_k.py | 45 ++++++++++++++++-------
 1 file changed, 31 insertions(+), 14 deletions(-)

diff --git a/tritonbench/operators/gemm/partition_k.py b/tritonbench/operators/gemm/partition_k.py
index dccf2071b..893692d58 100644
--- a/tritonbench/operators/gemm/partition_k.py
+++ b/tritonbench/operators/gemm/partition_k.py
@@ -144,7 +144,7 @@ def _matmul_partition_k(
     # See above `Pointer Arithmetic` section for details
     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
-    offs_k = (pid_pk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K
+    offs_k = (pid_pk * PK_SIZE + tl.arange(0, BLOCK_SIZE_K)) % K
     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
 
@@ -157,13 +157,12 @@
     for k in range(0, tl.cdiv(PK_SIZE, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the K dimension.
         # If it is out of bounds, set it to 0.
-        # a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
-        # b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
-        a = tl.load(a_ptrs)
-        b = tl.load(b_ptrs)
+        k_mask = (pid_pk * PK_SIZE + k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) < K
+        a = tl.load(a_ptrs, mask=k_mask[None, :], other=0.0)
+        b = tl.load(b_ptrs, mask=k_mask[:, None], other=0.0)
         accumulator += tl.dot(a, b)
-        a_ptrs += PK_SIZE * stride_ak
-        b_ptrs += PK_SIZE * stride_bk
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
 
     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -195,7 +194,7 @@ def _reduce(
     pid = tl.program_id(0)
     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
-    pid_m = pid // num_pid_m
+    pid_m = pid // num_pid_n
     pid_n = pid % num_pid_n
 
     offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
@@ -220,17 +219,26 @@ def torch_reduction(c_buf, a):
 compiled_reduction = torch.compile(torch_reduction)
 
 
-def _matmul_partition_k_impl(a, b, triton_reduce=False):
+def _matmul_partition_k_impl(a, b, triton_reduce=False, partition_k=None):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
     assert a.is_contiguous(), "Matrix A must be contiguous"
     assert b.is_contiguous(), "Matrix B must be contiguous"
 
-    # TODO: Tune on this parameter, currently 32 is best performing
-    partitionK = 32
-
     M, K = a.shape
     K, N = b.shape
+
+    # Choose partition size
+    if partition_k is not None:
+        partitionK = partition_k
+    else:
+        # Use 32 partitions by default; only reduce for small K to maintain accuracy
+        partitionK = 32 if K >= 1024 else 8
+
+    # Ensure K is divisible by partitionK
+    while K % partitionK != 0 and partitionK > 1:
+        partitionK -= 1
+
     # Allocates output.
     partitionK_SIZE = K // partitionK
 
@@ -312,5 +320,14 @@ def backward(ctx, grad_output):
         return grad_a, grad_b, None
 
 
-def matmul_partition_k(a, b, triton_reduce=False):
-    return _PartitionKMatmul.apply(a, b, triton_reduce)
+def matmul_partition_k(a, b, triton_reduce=False, partition_k=None):
+    """Matrix multiplication with partition-K parallelization.
+
+    Args:
+        a: Left input tensor (M, K)
+        b: Right input tensor (K, N)
+        triton_reduce: If True, use the Triton reduction kernel; otherwise reduce with PyTorch
+        partition_k: Number of partitions to split the K dimension into.
+            If None, chosen automatically based on K.
+    """
+    return _matmul_partition_k_impl(a, b, triton_reduce, partition_k)
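
Reviewer sketch (not part of the patch): the kernel-side fixes implement a fixed split of the K axis. Each program pid_pk owns the slice [pid_pk * PK_SIZE, (pid_pk + 1) * PK_SIZE), so the initial offset must use PK_SIZE while each loop iteration advances by BLOCK_SIZE_K, with the load mask zeroing lanes past K. Below is a minimal pure-PyTorch model of that indexing; the sizes (M, K, N, partitionK, BLOCK_SIZE_K) are hypothetical, not taken from the benchmark configs.

    import torch

    M, K, N = 64, 256, 64        # illustrative shapes; K % partitionK == 0
    partitionK = 4
    PK_SIZE = K // partitionK    # width of each program's K slice
    BLOCK_SIZE_K = 32            # step size within a slice

    a = torch.randn(M, K, dtype=torch.float64)
    b = torch.randn(K, N, dtype=torch.float64)
    c_buf = torch.zeros(partitionK, M, N, dtype=torch.float64)

    for pid_pk in range(partitionK):
        acc = torch.zeros(M, N, dtype=torch.float64)
        # Start at pid_pk * PK_SIZE (the offs_k fix) and advance by
        # BLOCK_SIZE_K per iteration (the pointer-advance fix); the min()
        # plays the role of the kernel's k_mask at the K boundary.
        for k in range(0, PK_SIZE, BLOCK_SIZE_K):
            lo = pid_pk * PK_SIZE + k
            hi = min(lo + BLOCK_SIZE_K, K)
            acc += a[:, lo:hi] @ b[lo:hi, :]
        c_buf[pid_pk] = acc

    # The separate reduce stage sums the per-partition partial products.
    torch.testing.assert_close(c_buf.sum(dim=0), a @ b)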
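A quick smoke test for the new partition_k argument follows. Assumptions: a CUDA device is available, tritonbench is importable, and fp16 inputs with loose tolerances are appropriate for this benchmark; the import path mirrors the file this diff touches.

    import torch
    from tritonbench.operators.gemm.partition_k import matmul_partition_k

    a = torch.randn(512, 4096, device="cuda", dtype=torch.float16)
    b = torch.randn(4096, 512, device="cuda", dtype=torch.float16)

    ref = a @ b
    # Default: partitionK picked automatically (32 here, since K >= 1024).
    torch.testing.assert_close(matmul_partition_k(a, b), ref, rtol=2e-2, atol=2e-2)
    # Explicit count via the new keyword; the impl shrinks it to divide K.
    torch.testing.assert_close(
        matmul_partition_k(a, b, partition_k=16), ref, rtol=2e-2, atol=2e-2
    )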