1010#include " mlir/Conversion/GPUCommon/GPUCommonPass.h"
1111#include " mlir/Conversion/LLVMCommon/LoweringOptions.h"
1212#include " mlir/Conversion/LLVMCommon/TypeConverter.h"
13- #include " mlir/Conversion/LLVMCommon/VectorPattern.h"
14- #include " mlir/Dialect/AMDGPU/Utils/Chipset.h"
1513#include " mlir/Dialect/Func/IR/FuncOps.h"
1614#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
1715#include " mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -44,65 +42,8 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
4442 f32ApproxFunc, f16Func);
4543}
4644
47- struct ClampFOpConversion final
48- : public ConvertOpToLLVMPattern<math::ClampFOp> {
49- using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
50- ClampFOpConversion (const LLVMTypeConverter &converter,
51- amdgpu::Chipset chipset)
52- : ConvertOpToLLVMPattern<math::ClampFOp>(converter), chipset(chipset) {}
53-
54- LogicalResult
55- matchAndRewrite (math::ClampFOp op, OpAdaptor adaptor,
56- ConversionPatternRewriter &rewriter) const override {
57- // Only f16 and f32 types are supported by fmed3
58- Type opTy = op.getType ();
59- auto resultType = getTypeConverter ()->convertType (opTy);
60-
61- if (auto vectorType = dyn_cast<VectorType>(opTy)) {
62- opTy = vectorType.getElementType ();
63- }
64-
65- if (!isa<Float16Type, Float32Type>(opTy)) {
66- return rewriter.notifyMatchFailure (
67- op, " fmed3 only supports f16 and f32 types" );
68- }
69-
70- // Handle multi-dimensional vectors (converted to LLVM arrays)
71- if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType)) {
72- // Handle multi-dimensional vectors (converted to LLVM arrays)
73- return LLVM::detail::handleMultidimensionalVectors (
74- op.getOperation (), adaptor.getOperands (), *getTypeConverter (),
75- [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
76- typename math::ClampFOp::Adaptor adaptor (operands);
77- return ROCDL::FMed3Op::create (rewriter, op.getLoc (), llvm1DVectorTy,
78- adaptor.getValue (), adaptor.getMin (),
79- adaptor.getMax ());
80- },
81- rewriter);
82- }
83-
84- // Handle 1D vectors and scalars directly
85- rewriter.replaceOpWithNewOp <ROCDL::FMed3Op>(op, op.getType (), op.getValue (),
86- op.getMin (), op.getMax ());
87- return success ();
88- }
89-
90- amdgpu::Chipset chipset;
91- };
92-
93- static void addChipsetDependentPatterns (const LLVMTypeConverter &converter,
94- RewritePatternSet &patterns,
95- amdgpu::Chipset chipset) {
96-
97- // V_MED3_F16/F32 only exists in gfx9+ architectures
98- if (chipset.majorVersion >= 9 ) {
99- patterns.add <ClampFOpConversion>(converter, chipset);
100- }
101- }
102-
10345void mlir::populateMathToROCDLConversionPatterns (
104- const LLVMTypeConverter &converter, RewritePatternSet &patterns,
105- amdgpu::Chipset chipset) {
46+ const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
10647 // Handled by mathToLLVM: math::AbsIOp
10748 // Handled by mathToLLVM: math::AbsFOp
10849 // Handled by mathToLLVM: math::CopySignOp
@@ -177,17 +118,15 @@ void mlir::populateMathToROCDLConversionPatterns(
177118 // worth creating a separate pass for it.
178119 populateOpPatterns<arith::RemFOp>(converter, patterns, " __ocml_fmod_f32" ,
179120 " __ocml_fmod_f64" , " __ocml_fmod_f16" );
180-
181- addChipsetDependentPatterns (converter, patterns, chipset);
182121}
183122
184- struct ConvertMathToROCDLPass final
185- : impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
186- using impl::ConvertMathToROCDLBase<
187- ConvertMathToROCDLPass>::ConvertMathToROCDLBase;
188-
123+ namespace {
124+ struct ConvertMathToROCDLPass
125+ : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
126+ ConvertMathToROCDLPass () = default ;
189127 void runOnOperation () override ;
190128};
129+ } // namespace
191130
192131void ConvertMathToROCDLPass::runOnOperation () {
193132 auto m = getOperation ();
@@ -196,20 +135,10 @@ void ConvertMathToROCDLPass::runOnOperation() {
196135 RewritePatternSet patterns (&getContext ());
197136 LowerToLLVMOptions options (ctx, DataLayout (m));
198137 LLVMTypeConverter converter (ctx, options);
199-
200- // Only populate chipset-dependent patterns if chipset is specified
201- if (!chipset.empty ()) {
202- FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse (chipset);
203- if (failed (maybeChipset)) {
204- return signalPassFailure ();
205- }
206- populateMathToROCDLConversionPatterns (converter, patterns, *maybeChipset);
207- }
208-
138+ populateMathToROCDLConversionPatterns (converter, patterns);
209139 ConversionTarget target (getContext ());
210- target
211- .addLegalDialect <BuiltinDialect, func::FuncDialect, vector::VectorDialect,
212- LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
140+ target.addLegalDialect <BuiltinDialect, func::FuncDialect,
141+ vector::VectorDialect, LLVM::LLVMDialect>();
213142 target.addIllegalOp <LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
214143 LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
215144 LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
0 commit comments