Skip to content

Commit cee29ab

Browse files
Automerge: [MLIR][ROCDL] Add math.clampf -> rocdl.fmed3 conversion (#163520)
Added Pattern for lowering `Math::ClampFOp` to `ROCDL::FMED3`. Also added `chipet` option to `MathToRocdl` pass to check for arch support ISA instructions Solves [#15072](llvm/llvm-project#157052) Reapplies llvm/llvm-project#160100 Un-reverts the merged llvm/llvm-project#163259, and fixes the error. --------- Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
2 parents cc51c1e + fbbffc1 commit cee29ab

File tree

6 files changed

+157
-13
lines changed

6 files changed

+157
-13
lines changed

mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
1010

1111
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12+
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
1213
#include "mlir/IR/PatternMatch.h"
1314
#include <memory>
1415

@@ -19,8 +20,11 @@ class Pass;
1920
#include "mlir/Conversion/Passes.h.inc"
2021

2122
/// Populate the given list with patterns that convert from Math to ROCDL calls.
22-
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
23-
RewritePatternSet &patterns);
23+
// `chipset` specifies the AMDGPU chipset to target. If `std::nullopt`,
24+
// none of the chipset dependent patterns are added.
25+
void populateMathToROCDLConversionPatterns(
26+
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
27+
std::optional<amdgpu::Chipset> chipset);
2428
} // namespace mlir
2529

2630
#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_

mlir/include/mlir/Conversion/Passes.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,13 +778,20 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
778778
let summary = "Convert Math dialect to ROCDL library calls";
779779
let description = [{
780780
This pass converts supported Math ops to ROCDL library calls.
781+
782+
The chipset option specifies the target AMDGPU architecture. If the chipset
783+
is empty, none of the chipset-dependent patterns are added, and the pass
784+
will not attempt to parse the chipset.
781785
}];
782786
let dependentDialects = [
783787
"arith::ArithDialect",
784788
"func::FuncDialect",
785789
"ROCDL::ROCDLDialect",
786790
"vector::VectorDialect",
787791
];
792+
let options = [Option<"chipset", "chipset", "std::string",
793+
/*default=*/"\"\"",
794+
"Chipset that these operations will run on">];
788795
}
789796

790797
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,5 +484,5 @@ void mlir::populateGpuToROCDLConversionPatterns(
484484
GPUSubgroupBroadcastOpToROCDL>(converter);
485485
patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);
486486

487-
populateMathToROCDLConversionPatterns(converter, patterns);
487+
populateMathToROCDLConversionPatterns(converter, patterns, chipset);
488488
}

mlir/lib/Conversion/MathToROCDL/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMathToROCDL
1111
Core
1212

1313
LINK_LIBS PUBLIC
14+
MLIRAMDGPUUtils
1415
MLIRDialectUtils
1516
MLIRFuncDialect
1617
MLIRGPUToGPURuntimeTransforms

mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
1111
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
1212
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13+
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
14+
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
1315
#include "mlir/Dialect/Func/IR/FuncOps.h"
1416
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1517
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -19,6 +21,7 @@
1921
#include "mlir/IR/PatternMatch.h"
2022
#include "mlir/Pass/Pass.h"
2123
#include "mlir/Transforms/DialectConversion.h"
24+
#include "llvm/Support/DebugLog.h"
2225

2326
#include "../GPUCommon/GPUOpsLowering.h"
2427
#include "../GPUCommon/OpToFuncCallLowering.h"
@@ -42,8 +45,46 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
4245
f32ApproxFunc, f16Func);
4346
}
4447

48+
struct ClampFOpConversion final
49+
: public ConvertOpToLLVMPattern<math::ClampFOp> {
50+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
51+
52+
LogicalResult
53+
matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
54+
ConversionPatternRewriter &rewriter) const override {
55+
// Only f16 and f32 types are supported by fmed3
56+
Type opTy = op.getType();
57+
Type resultType = getTypeConverter()->convertType(opTy);
58+
59+
if (auto vectorType = dyn_cast<VectorType>(opTy))
60+
opTy = vectorType.getElementType();
61+
62+
if (!isa<Float16Type, Float32Type>(opTy))
63+
return rewriter.notifyMatchFailure(
64+
op, "fmed3 only supports f16 and f32 types");
65+
66+
// Handle multi-dimensional vectors (converted to LLVM arrays)
67+
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType))
68+
return LLVM::detail::handleMultidimensionalVectors(
69+
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
70+
[&](Type llvm1DVectorTy, ValueRange operands) -> Value {
71+
typename math::ClampFOp::Adaptor adaptor(operands);
72+
return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy,
73+
adaptor.getValue(), adaptor.getMin(),
74+
adaptor.getMax());
75+
},
76+
rewriter);
77+
78+
// Handle 1D vectors and scalars directly
79+
rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(),
80+
op.getMin(), op.getMax());
81+
return success();
82+
}
83+
};
84+
4585
void mlir::populateMathToROCDLConversionPatterns(
46-
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
86+
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
87+
std::optional<amdgpu::Chipset> chipset) {
4788
// Handled by mathToLLVM: math::AbsIOp
4889
// Handled by mathToLLVM: math::AbsFOp
4990
// Handled by mathToLLVM: math::CopySignOp
@@ -118,15 +159,21 @@ void mlir::populateMathToROCDLConversionPatterns(
118159
// worth creating a separate pass for it.
119160
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
120161
"__ocml_fmod_f64", "__ocml_fmod_f16");
162+
163+
if (chipset.has_value() && chipset->majorVersion >= 9) {
164+
patterns.add<ClampFOpConversion>(converter);
165+
} else {
166+
LDBG() << "Chipset dependent patterns were not added";
167+
}
121168
}
122169

123-
namespace {
124-
struct ConvertMathToROCDLPass
125-
: public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
126-
ConvertMathToROCDLPass() = default;
170+
struct ConvertMathToROCDLPass final
171+
: impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
172+
using impl::ConvertMathToROCDLBase<
173+
ConvertMathToROCDLPass>::ConvertMathToROCDLBase;
174+
127175
void runOnOperation() override;
128176
};
129-
} // namespace
130177

131178
void ConvertMathToROCDLPass::runOnOperation() {
132179
auto m = getOperation();
@@ -135,10 +182,21 @@ void ConvertMathToROCDLPass::runOnOperation() {
135182
RewritePatternSet patterns(&getContext());
136183
LowerToLLVMOptions options(ctx, DataLayout(m));
137184
LLVMTypeConverter converter(ctx, options);
138-
populateMathToROCDLConversionPatterns(converter, patterns);
185+
186+
FailureOr<amdgpu::Chipset> maybeChipset;
187+
if (!chipset.empty()) {
188+
maybeChipset = amdgpu::Chipset::parse(chipset);
189+
if (failed(maybeChipset))
190+
return signalPassFailure();
191+
}
192+
populateMathToROCDLConversionPatterns(
193+
converter, patterns,
194+
succeeded(maybeChipset) ? std::optional(*maybeChipset) : std::nullopt);
195+
139196
ConversionTarget target(getContext());
140-
target.addLegalDialect<BuiltinDialect, func::FuncDialect,
141-
vector::VectorDialect, LLVM::LLVMDialect>();
197+
target
198+
.addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
199+
LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
142200
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
143201
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
144202
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,

mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
// RUN: mlir-opt %s -convert-math-to-rocdl -allow-unregistered-dialect -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx803})' | FileCheck %s --check-prefix=PRE9
2+
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx942})' | FileCheck %s --check-prefix=POST9
23

34
module @test_module {
45
// CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16
@@ -596,3 +597,76 @@ module @test_module {
596597
func.return %result : vector<2x2xf16>
597598
}
598599
}
600+
601+
// -----
602+
603+
// f16 clamp → rocdl.fmed3 on gfx9+
604+
// CHECK-LABEL: func.func @clampf_f16
605+
func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 {
606+
%r = math.clampf %x to [%lo, %hi] : f16
607+
return %r : f16
608+
// POST9: rocdl.fmed3 {{.*}} : f16
609+
// PRE9-NOT: rocdl.fmed3
610+
// PRE9: math.clampf {{.*}} : f16
611+
}
612+
613+
// f32 clamp → rocdl.fmed3 on gfx9+
614+
// CHECK-LABEL: func.func @clampf_f32
615+
func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 {
616+
%r = math.clampf %x to [%lo, %hi] : f32
617+
return %r : f32
618+
// POST9: rocdl.fmed3 {{.*}} : f32
619+
// PRE9-NOT: rocdl.fmed3
620+
// PRE9: math.clampf {{.*}} : f32
621+
}
622+
623+
// -----
624+
625+
// Vector f16 clamp → rocdl.fmed3 on gfx9+
626+
// CHECK-LABEL: func.func @clampf_vector_f16
627+
func.func @clampf_vector_f16(%x: vector<2xf16>, %lo: vector<2xf16>, %hi: vector<2xf16>) -> vector<2xf16> {
628+
%r = math.clampf %x to [%lo, %hi] : vector<2xf16>
629+
return %r : vector<2xf16>
630+
// POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
631+
// PRE9-NOT: rocdl.fmed3
632+
// PRE9: math.clampf {{.*}} : vector<2xf16>
633+
}
634+
635+
// -----
636+
637+
// Vector f32 clamp → rocdl.fmed3 on gfx9+
638+
// CHECK-LABEL: func.func @clampf_vector_f32
639+
func.func @clampf_vector_f32(%x: vector<2xf32>, %lo: vector<2xf32>, %hi: vector<2xf32>) -> vector<2xf32> {
640+
%r = math.clampf %x to [%lo, %hi] : vector<2xf32>
641+
return %r : vector<2xf32>
642+
// POST9: rocdl.fmed3 {{.*}} : vector<2xf32>
643+
// PRE9-NOT: rocdl.fmed3
644+
// PRE9: math.clampf {{.*}} : vector<2xf32>
645+
}
646+
647+
// -----
648+
649+
// Multi-dimensional vector f16 clamp → rocdl.fmed3 on gfx9+ (unrolled to 1D vectors)
650+
// CHECK-LABEL: func.func @clampf_vector_2d_f16
651+
func.func @clampf_vector_2d_f16(%x: vector<2x2xf16>, %lo: vector<2x2xf16>, %hi: vector<2x2xf16>) -> vector<2x2xf16> {
652+
%r = math.clampf %x to [%lo, %hi] : vector<2x2xf16>
653+
return %r : vector<2x2xf16>
654+
// POST9: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
655+
// POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
656+
// POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
657+
// POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
658+
// POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
659+
// POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
660+
// POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
661+
// PRE9-NOT: rocdl.fmed3
662+
// PRE9: math.clampf {{.*}} : vector<2x2xf16>
663+
}
664+
665+
// -----
666+
// CHECK-LABEL: func.func @clampf_bf16
667+
func.func @clampf_bf16(%x: bf16, %lo: bf16, %hi: bf16) -> bf16 {
668+
%r = math.clampf %x to [%lo, %hi] : bf16
669+
return %r : bf16
670+
// CHECK: math.clampf {{.*}} : bf16
671+
// CHECK-NOT: rocdl.fmed3
672+
}

0 commit comments

Comments
 (0)