@@ -0,0 +1,315 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import List, Tuple

import torch
from tabulate import tabulate
from tqdm import tqdm

# Assuming these imports based on the kernel location
from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.blockwise_fp8_training.kernels import (
triton_fp8_blockwise_act_quant_transposed_lhs,
)

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
input_shape: Tuple[int, int] # (M, K)
block_size: int


@dataclass(frozen=True)
class ExperimentResult:
# time
naive_us: float
triton_us: float
# mem bw
naive_gbps: float
triton_gbps: float


@dataclass(frozen=True)
class Experiment:
config: ExperimentConfig
result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
"""
Test configurations for typical transformer activation shapes.
Format: (batch_size * seq_len, hidden_dim)

Note: For transposed_lhs, M must be divisible by block_size
"""
# Llama-style shapes: various batch*seq_len sizes with typical hidden dims
# Ensuring M is divisible by block_size (128)
input_shapes = [
(512, 4096),
(1024, 4096),
(2048, 4096),
(4096, 4096),
        (8192, 4096),
    ]

configs = []
block_sizes = [128] # Standard block size for FP8

for shape in input_shapes:
for block_size in block_sizes:
# Verify M is divisible by block_size
if shape[0] % block_size == 0:
configs.append(
ExperimentConfig(
input_shape=shape,
block_size=block_size,
)
)
return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
M, K = config.input_shape
block_size = config.block_size

def naive_fp8_blockwise_quant_transposed(
x: torch.Tensor, block_size: int = 128
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Naive PyTorch reference implementation for blockwise FP8 quantization with transpose.

This version:
1. Computes column-wise scales (along dimension 0)
2. Outputs transposed quantized tensor (K, M) in row-major format
3. Outputs scales in shape (K, M//block_size)

Args:
x: Input tensor of shape (M, K)
block_size: Number of elements per block along M dimension

Returns:
y: Transposed quantized tensor in FP8, shape (K, M) in row-major
s: Reciprocal scales in column-major format (K, M//block_size)
"""
assert x.is_contiguous(), "Input must be contiguous"
assert x.size(0) % block_size == 0, "M must be divisible by block_size"

M, K = x.size()
num_blocks = M // block_size

# FP8 E4M3 constants
max_fp8_e4m3 = 448.0
min_fp8_e4m3 = -448.0
eps = 1e-12
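        # 448.0 is the largest finite value representable in torch.float8_e4m3fn;
        # eps keeps the scale finite when a block is entirely zero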

# Reshape to (num_blocks, block_size, K) for block-wise operations along M
x_reshaped = x.view(num_blocks, block_size, K)

# Compute max absolute value per block along dimension 1 (block_size)
# Result shape: (num_blocks, K)
        amax = x_reshaped.abs().amax(dim=1).to(torch.float64).clamp(min=eps)

# Compute scales (num_blocks, K) -> (num_blocks, 1, K) for broadcasting
scale = (max_fp8_e4m3 / amax).to(torch.float32).unsqueeze(1)
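        # With this scale, each block's max magnitude maps exactly onto the FP8
        # limit (amax * scale == 448)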

# Quantize
y_reshaped = x_reshaped * scale
        y_reshaped = torch.clamp(y_reshaped, min=min_fp8_e4m3, max=max_fp8_e4m3)

# Reshape back to (M, K) then transpose to (K, M)
y = y_reshaped.view(M, K).t().contiguous().to(torch.float8_e4m3fn)

# Compute reciprocal scales - explicitly cast to float32
reciprocal_scale = (1.0 / scale.squeeze(1)).to(torch.float32)
# reciprocal_scale is (num_blocks, K), need to transpose to (K, num_blocks)
reciprocal_scale = reciprocal_scale.t().contiguous()

# Convert to column-major using as_strided (matching Triton kernel output)
s = x.new_empty(K, num_blocks, dtype=torch.float32).as_strided(
(K, num_blocks),
(1, K), # Column-major strides
)
s.copy_(reciprocal_scale)

return y, s

def verify_outputs(
y_naive: torch.Tensor,
s_naive: torch.Tensor,
y_triton: torch.Tensor,
s_triton: torch.Tensor,
input_tensor: torch.Tensor,
block_size: int,
rtol: float = 1e-2,
atol: float = 1e-2,
):
"""Verify that Triton and naive implementations produce similar results."""

# Verify output shapes
M, K = input_tensor.shape
expected_y_shape = (K, M)
expected_s_shape = (K, M // block_size)

assert y_naive.shape == expected_y_shape, f"Naive y shape mismatch: {y_naive.shape} vs {expected_y_shape}"
assert y_triton.shape == expected_y_shape, f"Triton y shape mismatch: {y_triton.shape} vs {expected_y_shape}"
assert s_naive.shape == expected_s_shape, f"Naive s shape mismatch: {s_naive.shape} vs {expected_s_shape}"
assert s_triton.shape == expected_s_shape, f"Triton s shape mismatch: {s_triton.shape} vs {expected_s_shape}"

# Convert FP8 back to float for comparison
y_naive_float = y_naive.to(torch.float32)
y_triton_float = y_triton.to(torch.float32)

# Check quantized values are close
if not torch.allclose(y_naive_float, y_triton_float, rtol=rtol, atol=atol):
max_diff = (y_naive_float - y_triton_float).abs().max().item()
print(f"WARNING: Quantized values differ! Max diff: {max_diff}")
            print(f"  Naive range: [{y_naive_float.min():.3f}, {y_naive_float.max():.3f}]")
            print(f"  Triton range: [{y_triton_float.min():.3f}, {y_triton_float.max():.3f}]")

        # Handle potential dtype mismatches introduced by torch.compile:
        # normalize both scale tensors to float32 before comparing
        if s_naive.dtype != torch.float32:
            print(f"INFO: Converting naive scales from {s_naive.dtype} to float32")
            s_naive = s_naive.to(torch.float32)

        if s_triton.dtype != torch.float32:
            print(f"INFO: Converting Triton scales from {s_triton.dtype} to float32")
            s_triton = s_triton.to(torch.float32)

        # Check scales are close
        # Note: both scale tensors are column-major; reinterpret them with
        # row-major strides so they are compared in the same memory order
        s_naive_rowmajor = s_naive.as_strided(s_naive.shape, (s_naive.shape[1], 1))
        s_triton_rowmajor = s_triton.as_strided(s_triton.shape, (s_triton.shape[1], 1))

        if not torch.allclose(s_naive_rowmajor, s_triton_rowmajor, rtol=rtol, atol=atol):
            max_diff = (s_naive_rowmajor - s_triton_rowmajor).abs().max().item()
            print(f"WARNING: Scales differ! Max diff: {max_diff}")
            print(f"  Naive scale range: [{s_naive_rowmajor.min():.6f}, {s_naive_rowmajor.max():.6f}]")
            print(f"  Triton scale range: [{s_triton_rowmajor.min():.6f}, {s_triton_rowmajor.max():.6f}]")

input_tensor = torch.randn(
M, K,
dtype=torch.bfloat16,
device=device,
)
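    # bfloat16 is used here as the typical high-precision activation dtype these
    # FP8 quantization kernels consume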

# Benchmark naive implementation
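    # torch.compile lets Inductor fuse the reference's elementwise ops, so the
    # baseline is compiled PyTorch rather than eager mode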
naive_impl_c = torch.compile(naive_fp8_blockwise_quant_transposed)

    # First call triggers compilation and serves as warmup
    y_naive, s_naive = naive_impl_c(input_tensor, block_size)
naive_time_us = benchmark_cuda_function_in_microseconds(
naive_impl_c,
input_tensor,
block_size,
)

# Benchmark Triton implementation
    triton_impl_c = torch.compile(triton_fp8_blockwise_act_quant_transposed_lhs)

    # First call triggers compilation and serves as warmup
    y_triton, s_triton = triton_impl_c(input_tensor, block_size)
triton_time_us = benchmark_cuda_function_in_microseconds(
triton_impl_c,
input_tensor,
block_size,
)

# Verify correctness
    verify_outputs(y_naive, s_naive, y_triton, s_triton, input_tensor, block_size)

# Memory bandwidth calculations
bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8
bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
bytes_per_scale_el = 4 # float32

read_bytes = input_tensor.numel() * bytes_per_input_el
write_bytes = (
y_triton.numel() * bytes_per_output_el
+ s_triton.numel() * bytes_per_scale_el
)
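    # Effective bandwidth: both implementations are charged the same ideal byte
    # count (one read of the bf16 input, one write of the fp8 output and scales).
    # bytes / 1e9 -> GB and microseconds / 1e6 -> seconds, so the ratio is GB/s.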

naive_gbps = ((read_bytes + write_bytes) / 1e9) / (naive_time_us / 1e6)
triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6)

return ExperimentResult(
naive_us=naive_time_us,
triton_us=triton_time_us,
naive_gbps=naive_gbps,
triton_gbps=triton_gbps,
)


def print_results(experiments: List[Experiment]):
headers = [
"input_shape (M, K)",
"block_size",
"naive_us",
"triton_us",
"speedup",
"naive_gbps",
"triton_gbps",
]
rows = []
for experiment in experiments:
speedup = experiment.result.naive_us / experiment.result.triton_us
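        # speedup > 1.0 means the Triton kernel is faster than the compiled naive reference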
rows.append(
[
f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}",
experiment.config.block_size,
f"{experiment.result.naive_us:.2f}",
f"{experiment.result.triton_us:.2f}",
f"{speedup:.2f}x",
f"{experiment.result.naive_gbps:.1f}",
f"{experiment.result.triton_gbps:.1f}",
]
)
print(tabulate(rows, headers=headers, tablefmt="grid"))


def main():
torch.random.manual_seed(123)
configs = get_configs()
results = []

print(f"Running {len(configs)} benchmark configurations...\n")

for config in tqdm(configs, desc="Benchmarking"):
result = run_experiment(config)
results.append(Experiment(config=config, result=result))

print("\n" + "="*80)
print("BENCHMARK RESULTS - Transposed LHS Quantization")
print("="*80 + "\n")
print_results(results)


if __name__ == "__main__":
main()