Skip to content

Commit 585cda6

Browse files
authored
Revert "[MLIR][ROCDL] Add math.clampf -> rocdl.fmed3 conversion (#163259)"
This reverts commit 1e6df64.
1 parent 1e6df64 commit 585cda6

File tree

5 files changed

+12
-166
lines changed

5 files changed

+12
-166
lines changed

mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
1010

1111
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12-
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
1312
#include "mlir/IR/PatternMatch.h"
1413
#include <memory>
1514

@@ -21,8 +20,7 @@ class Pass;
2120

2221
/// Populate the given list with patterns that convert from Math to ROCDL calls.
2322
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
24-
RewritePatternSet &patterns,
25-
amdgpu::Chipset chipset);
23+
RewritePatternSet &patterns);
2624
} // namespace mlir
2725

2826
#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_

mlir/include/mlir/Conversion/Passes.td

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -778,20 +778,13 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
778778
let summary = "Convert Math dialect to ROCDL library calls";
779779
let description = [{
780780
This pass converts supported Math ops to ROCDL library calls.
781-
782-
The chipset option specifies the target AMDGPU architecture. If the chipset
783-
is empty, none of the chipset-dependent patterns are added and the pass
784-
will not attempt to parse the chipset.
785781
}];
786782
let dependentDialects = [
787783
"arith::ArithDialect",
788784
"func::FuncDialect",
789785
"ROCDL::ROCDLDialect",
790786
"vector::VectorDialect",
791787
];
792-
let options = [Option<"chipset", "chipset", "std::string",
793-
/*default=*/"\"gfx000\"",
794-
"Chipset that these operations will run on">];
795788
}
796789

797790
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,5 +484,5 @@ void mlir::populateGpuToROCDLConversionPatterns(
484484
GPUSubgroupBroadcastOpToROCDL>(converter);
485485
patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);
486486

487-
populateMathToROCDLConversionPatterns(converter, patterns, chipset);
487+
populateMathToROCDLConversionPatterns(converter, patterns);
488488
}

mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp

Lines changed: 9 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
1111
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
1212
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13-
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
14-
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
1513
#include "mlir/Dialect/Func/IR/FuncOps.h"
1614
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1715
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -44,65 +42,8 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
4442
f32ApproxFunc, f16Func);
4543
}
4644

47-
struct ClampFOpConversion final
48-
: public ConvertOpToLLVMPattern<math::ClampFOp> {
49-
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
50-
ClampFOpConversion(const LLVMTypeConverter &converter,
51-
amdgpu::Chipset chipset)
52-
: ConvertOpToLLVMPattern<math::ClampFOp>(converter), chipset(chipset) {}
53-
54-
LogicalResult
55-
matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
56-
ConversionPatternRewriter &rewriter) const override {
57-
// Only f16 and f32 types are supported by fmed3
58-
Type opTy = op.getType();
59-
auto resultType = getTypeConverter()->convertType(opTy);
60-
61-
if (auto vectorType = dyn_cast<VectorType>(opTy)) {
62-
opTy = vectorType.getElementType();
63-
}
64-
65-
if (!isa<Float16Type, Float32Type>(opTy)) {
66-
return rewriter.notifyMatchFailure(
67-
op, "fmed3 only supports f16 and f32 types");
68-
}
69-
70-
// Handle multi-dimensional vectors (converted to LLVM arrays)
71-
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType)) {
72-
// Handle multi-dimensional vectors (converted to LLVM arrays)
73-
return LLVM::detail::handleMultidimensionalVectors(
74-
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
75-
[&](Type llvm1DVectorTy, ValueRange operands) -> Value {
76-
typename math::ClampFOp::Adaptor adaptor(operands);
77-
return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy,
78-
adaptor.getValue(), adaptor.getMin(),
79-
adaptor.getMax());
80-
},
81-
rewriter);
82-
}
83-
84-
// Handle 1D vectors and scalars directly
85-
rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(),
86-
op.getMin(), op.getMax());
87-
return success();
88-
}
89-
90-
amdgpu::Chipset chipset;
91-
};
92-
93-
static void addChipsetDependentPatterns(const LLVMTypeConverter &converter,
94-
RewritePatternSet &patterns,
95-
amdgpu::Chipset chipset) {
96-
97-
// V_MED3_F16/F32 only exists in gfx9+ architectures
98-
if (chipset.majorVersion >= 9) {
99-
patterns.add<ClampFOpConversion>(converter, chipset);
100-
}
101-
}
102-
10345
void mlir::populateMathToROCDLConversionPatterns(
104-
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
105-
amdgpu::Chipset chipset) {
46+
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
10647
// Handled by mathToLLVM: math::AbsIOp
10748
// Handled by mathToLLVM: math::AbsFOp
10849
// Handled by mathToLLVM: math::CopySignOp
@@ -177,17 +118,15 @@ void mlir::populateMathToROCDLConversionPatterns(
177118
// worth creating a separate pass for it.
178119
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
179120
"__ocml_fmod_f64", "__ocml_fmod_f16");
180-
181-
addChipsetDependentPatterns(converter, patterns, chipset);
182121
}
183122

184-
struct ConvertMathToROCDLPass final
185-
: impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
186-
using impl::ConvertMathToROCDLBase<
187-
ConvertMathToROCDLPass>::ConvertMathToROCDLBase;
188-
123+
namespace {
124+
struct ConvertMathToROCDLPass
125+
: public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
126+
ConvertMathToROCDLPass() = default;
189127
void runOnOperation() override;
190128
};
129+
} // namespace
191130

192131
void ConvertMathToROCDLPass::runOnOperation() {
193132
auto m = getOperation();
@@ -196,20 +135,10 @@ void ConvertMathToROCDLPass::runOnOperation() {
196135
RewritePatternSet patterns(&getContext());
197136
LowerToLLVMOptions options(ctx, DataLayout(m));
198137
LLVMTypeConverter converter(ctx, options);
199-
200-
// Only populate chipset-dependent patterns if chipset is specified
201-
if (!chipset.empty()) {
202-
FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
203-
if (failed(maybeChipset)) {
204-
return signalPassFailure();
205-
}
206-
populateMathToROCDLConversionPatterns(converter, patterns, *maybeChipset);
207-
}
208-
138+
populateMathToROCDLConversionPatterns(converter, patterns);
209139
ConversionTarget target(getContext());
210-
target
211-
.addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
212-
LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
140+
target.addLegalDialect<BuiltinDialect, func::FuncDialect,
141+
vector::VectorDialect, LLVM::LLVMDialect>();
213142
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
214143
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
215144
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,

mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir

Lines changed: 1 addition & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx803})' | FileCheck %s --check-prefix=PRE9
2-
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx942})' | FileCheck %s --check-prefix=POST9
1+
// RUN: mlir-opt %s -convert-math-to-rocdl -allow-unregistered-dialect -split-input-file | FileCheck %s
32

43
module @test_module {
54
// CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16
@@ -597,76 +596,3 @@ module @test_module {
597596
func.return %result : vector<2x2xf16>
598597
}
599598
}
600-
601-
// -----
602-
603-
// f16 clamp → rocdl.fmed3 on gfx9+
604-
// CHECK-LABEL: func.func @clampf_f16
605-
func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 {
606-
%r = math.clampf %x to [%lo, %hi] : f16
607-
return %r : f16
608-
// POST9: rocdl.fmed3 {{.*}} : f16
609-
// PRE9-NOT: rocdl.fmed3
610-
// PRE9: math.clampf {{.*}} : f16
611-
}
612-
613-
// f32 clamp → rocdl.fmed3 on gfx9+
614-
// CHECK-LABEL: func.func @clampf_f32
615-
func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 {
616-
%r = math.clampf %x to [%lo, %hi] : f32
617-
return %r : f32
618-
// POST9: rocdl.fmed3 {{.*}} : f32
619-
// PRE9-NOT: rocdl.fmed3
620-
// PRE9: math.clampf {{.*}} : f32
621-
}
622-
623-
// -----
624-
625-
// Vector f16 clamp → rocdl.fmed3 on gfx9+
626-
// CHECK-LABEL: func.func @clampf_vector_f16
627-
func.func @clampf_vector_f16(%x: vector<2xf16>, %lo: vector<2xf16>, %hi: vector<2xf16>) -> vector<2xf16> {
628-
%r = math.clampf %x to [%lo, %hi] : vector<2xf16>
629-
return %r : vector<2xf16>
630-
// POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
631-
// PRE9-NOT: rocdl.fmed3
632-
// PRE9: math.clampf {{.*}} : vector<2xf16>
633-
}
634-
635-
// -----
636-
637-
// Vector f32 clamp → rocdl.fmed3 on gfx9+
638-
// CHECK-LABEL: func.func @clampf_vector_f32
639-
func.func @clampf_vector_f32(%x: vector<2xf32>, %lo: vector<2xf32>, %hi: vector<2xf32>) -> vector<2xf32> {
640-
%r = math.clampf %x to [%lo, %hi] : vector<2xf32>
641-
return %r : vector<2xf32>
642-
// POST9: rocdl.fmed3 {{.*}} : vector<2xf32>
643-
// PRE9-NOT: rocdl.fmed3
644-
// PRE9: math.clampf {{.*}} : vector<2xf32>
645-
}
646-
647-
// -----
648-
649-
// Multi-dimensional vector f16 clamp → rocdl.fmed3 on gfx9+ (unrolled to 1D vectors)
650-
// CHECK-LABEL: func.func @clampf_vector_2d_f16
651-
func.func @clampf_vector_2d_f16(%x: vector<2x2xf16>, %lo: vector<2x2xf16>, %hi: vector<2x2xf16>) -> vector<2x2xf16> {
652-
%r = math.clampf %x to [%lo, %hi] : vector<2x2xf16>
653-
return %r : vector<2x2xf16>
654-
// POST9: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
655-
// POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
656-
// POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
657-
// POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
658-
// POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
659-
// POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
660-
// POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
661-
// PRE9-NOT: rocdl.fmed3
662-
// PRE9: math.clampf {{.*}} : vector<2x2xf16>
663-
}
664-
665-
// -----
666-
// CHECK-LABEL: func.func @clampf_bf16
667-
func.func @clampf_bf16(%x: bf16, %lo: bf16, %hi: bf16) -> bf16 {
668-
%r = math.clampf %x to [%lo, %hi] : bf16
669-
return %r : bf16
670-
// CHECK: math.clampf {{.*}} : bf16
671-
// CHECK-NOT: rocdl.fmed3
672-
}

0 commit comments

Comments
 (0)