1010#include " mlir/Conversion/GPUCommon/GPUCommonPass.h"
1111#include " mlir/Conversion/LLVMCommon/LoweringOptions.h"
1212#include " mlir/Conversion/LLVMCommon/TypeConverter.h"
13+ #include " mlir/Conversion/LLVMCommon/VectorPattern.h"
14+ #include " mlir/Dialect/AMDGPU/Utils/Chipset.h"
1315#include " mlir/Dialect/Func/IR/FuncOps.h"
1416#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
1517#include " mlir/Dialect/LLVMIR/ROCDLDialect.h"
1921#include " mlir/IR/PatternMatch.h"
2022#include " mlir/Pass/Pass.h"
2123#include " mlir/Transforms/DialectConversion.h"
24+ #include " llvm/Support/DebugLog.h"
2225
2326#include " ../GPUCommon/GPUOpsLowering.h"
2427#include " ../GPUCommon/OpToFuncCallLowering.h"
@@ -42,8 +45,46 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
4245 f32ApproxFunc, f16Func);
4346}
4447
48+ struct ClampFOpConversion final
49+ : public ConvertOpToLLVMPattern<math::ClampFOp> {
50+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
51+
52+ LogicalResult
53+ matchAndRewrite (math::ClampFOp op, OpAdaptor adaptor,
54+ ConversionPatternRewriter &rewriter) const override {
55+ // Only f16 and f32 types are supported by fmed3
56+ Type opTy = op.getType ();
57+ Type resultType = getTypeConverter ()->convertType (opTy);
58+
59+ if (auto vectorType = dyn_cast<VectorType>(opTy))
60+ opTy = vectorType.getElementType ();
61+
62+ if (!isa<Float16Type, Float32Type>(opTy))
63+ return rewriter.notifyMatchFailure (
64+ op, " fmed3 only supports f16 and f32 types" );
65+
66+ // Handle multi-dimensional vectors (converted to LLVM arrays)
67+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType))
68+ return LLVM::detail::handleMultidimensionalVectors (
69+ op.getOperation (), adaptor.getOperands (), *getTypeConverter (),
70+ [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
71+ typename math::ClampFOp::Adaptor adaptor (operands);
72+ return ROCDL::FMed3Op::create (rewriter, op.getLoc (), llvm1DVectorTy,
73+ adaptor.getValue (), adaptor.getMin (),
74+ adaptor.getMax ());
75+ },
76+ rewriter);
77+
78+ // Handle 1D vectors and scalars directly
79+ rewriter.replaceOpWithNewOp <ROCDL::FMed3Op>(op, op.getType (), op.getValue (),
80+ op.getMin (), op.getMax ());
81+ return success ();
82+ }
83+ };
84+
4585void mlir::populateMathToROCDLConversionPatterns (
46- const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
86+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
87+ std::optional<amdgpu::Chipset> chipset) {
4788 // Handled by mathToLLVM: math::AbsIOp
4889 // Handled by mathToLLVM: math::AbsFOp
4990 // Handled by mathToLLVM: math::CopySignOp
@@ -118,15 +159,21 @@ void mlir::populateMathToROCDLConversionPatterns(
118159 // worth creating a separate pass for it.
119160 populateOpPatterns<arith::RemFOp>(converter, patterns, " __ocml_fmod_f32" ,
120161 " __ocml_fmod_f64" , " __ocml_fmod_f16" );
162+
163+ if (chipset.has_value () && chipset->majorVersion >= 9 ) {
164+ patterns.add <ClampFOpConversion>(converter);
165+ } else {
166+ LDBG () << " Chipset dependent patterns were not added" ;
167+ }
121168}
122169
123- namespace {
124- struct ConvertMathToROCDLPass
125- : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
126- ConvertMathToROCDLPass () = default ;
170+ struct ConvertMathToROCDLPass final
171+ : impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
172+ using impl::ConvertMathToROCDLBase<
173+ ConvertMathToROCDLPass>::ConvertMathToROCDLBase;
174+
127175 void runOnOperation () override ;
128176};
129- } // namespace
130177
131178void ConvertMathToROCDLPass::runOnOperation () {
132179 auto m = getOperation ();
@@ -135,10 +182,21 @@ void ConvertMathToROCDLPass::runOnOperation() {
135182 RewritePatternSet patterns (&getContext ());
136183 LowerToLLVMOptions options (ctx, DataLayout (m));
137184 LLVMTypeConverter converter (ctx, options);
138- populateMathToROCDLConversionPatterns (converter, patterns);
185+
186+ FailureOr<amdgpu::Chipset> maybeChipset;
187+ if (!chipset.empty ()) {
188+ maybeChipset = amdgpu::Chipset::parse (chipset);
189+ if (failed (maybeChipset))
190+ return signalPassFailure ();
191+ }
192+ populateMathToROCDLConversionPatterns (
193+ converter, patterns,
194+ succeeded (maybeChipset) ? std::optional (*maybeChipset) : std::nullopt );
195+
139196 ConversionTarget target (getContext ());
140- target.addLegalDialect <BuiltinDialect, func::FuncDialect,
141- vector::VectorDialect, LLVM::LLVMDialect>();
197+ target
198+ .addLegalDialect <BuiltinDialect, func::FuncDialect, vector::VectorDialect,
199+ LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
142200 target.addIllegalOp <LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
143201 LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
144202 LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
0 commit comments