1010#include " mlir/Conversion/GPUCommon/GPUCommonPass.h"
1111#include " mlir/Conversion/LLVMCommon/LoweringOptions.h"
1212#include " mlir/Conversion/LLVMCommon/TypeConverter.h"
13+ #include " mlir/Conversion/LLVMCommon/VectorPattern.h"
14+ #include " mlir/Dialect/AMDGPU/Utils/Chipset.h"
1315#include " mlir/Dialect/Func/IR/FuncOps.h"
1416#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
1517#include " mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -42,8 +44,65 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
4244 f32ApproxFunc, f16Func);
4345}
4446
47+ struct ClampFOpConversion final
48+ : public ConvertOpToLLVMPattern<math::ClampFOp> {
49+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
50+ ClampFOpConversion (const LLVMTypeConverter &converter,
51+ amdgpu::Chipset chipset)
52+ : ConvertOpToLLVMPattern<math::ClampFOp>(converter), chipset(chipset) {}
53+
54+ LogicalResult
55+ matchAndRewrite (math::ClampFOp op, OpAdaptor adaptor,
56+ ConversionPatternRewriter &rewriter) const override {
57+ // Only f16 and f32 types are supported by fmed3
58+ Type opTy = op.getType ();
59+ auto resultType = getTypeConverter ()->convertType (opTy);
60+
61+ if (auto vectorType = dyn_cast<VectorType>(opTy)) {
62+ opTy = vectorType.getElementType ();
63+ }
64+
65+ if (!isa<Float16Type, Float32Type>(opTy)) {
66+ return rewriter.notifyMatchFailure (
67+ op, " fmed3 only supports f16 and f32 types" );
68+ }
69+
70+ // Handle multi-dimensional vectors (converted to LLVM arrays)
71+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType)) {
72+ // Handle multi-dimensional vectors (converted to LLVM arrays)
73+ return LLVM::detail::handleMultidimensionalVectors (
74+ op.getOperation (), adaptor.getOperands (), *getTypeConverter (),
75+ [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
76+ typename math::ClampFOp::Adaptor adaptor (operands);
77+ return ROCDL::FMed3Op::create (rewriter, op.getLoc (), llvm1DVectorTy,
78+ adaptor.getValue (), adaptor.getMin (),
79+ adaptor.getMax ());
80+ },
81+ rewriter);
82+ }
83+
84+ // Handle 1D vectors and scalars directly
85+ rewriter.replaceOpWithNewOp <ROCDL::FMed3Op>(op, op.getType (), op.getValue (),
86+ op.getMin (), op.getMax ());
87+ return success ();
88+ }
89+
90+ amdgpu::Chipset chipset;
91+ };
92+
93+ static void addChipsetDependentPatterns (const LLVMTypeConverter &converter,
94+ RewritePatternSet &patterns,
95+ amdgpu::Chipset chipset) {
96+
97+ // V_MED3_F16/F32 only exists in gfx9+ architectures
98+ if (chipset.majorVersion >= 9 ) {
99+ patterns.add <ClampFOpConversion>(converter, chipset);
100+ }
101+ }
102+
45103void mlir::populateMathToROCDLConversionPatterns (
46- const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
104+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
105+ amdgpu::Chipset chipset) {
47106 // Handled by mathToLLVM: math::AbsIOp
48107 // Handled by mathToLLVM: math::AbsFOp
49108 // Handled by mathToLLVM: math::CopySignOp
@@ -118,15 +177,17 @@ void mlir::populateMathToROCDLConversionPatterns(
118177 // worth creating a separate pass for it.
119178 populateOpPatterns<arith::RemFOp>(converter, patterns, " __ocml_fmod_f32" ,
120179 " __ocml_fmod_f64" , " __ocml_fmod_f16" );
180+
181+ addChipsetDependentPatterns (converter, patterns, chipset);
121182}
122183
123- namespace {
124- struct ConvertMathToROCDLPass
125- : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
126- ConvertMathToROCDLPass () = default ;
184+ struct ConvertMathToROCDLPass final
185+ : impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
186+ using impl::ConvertMathToROCDLBase<
187+ ConvertMathToROCDLPass>::ConvertMathToROCDLBase;
188+
127189 void runOnOperation () override ;
128190};
129- } // namespace
130191
131192void ConvertMathToROCDLPass::runOnOperation () {
132193 auto m = getOperation ();
@@ -135,10 +196,20 @@ void ConvertMathToROCDLPass::runOnOperation() {
135196 RewritePatternSet patterns (&getContext ());
136197 LowerToLLVMOptions options (ctx, DataLayout (m));
137198 LLVMTypeConverter converter (ctx, options);
138- populateMathToROCDLConversionPatterns (converter, patterns);
199+
200+ // Only populate chipset-dependent patterns if chipset is specified
201+ if (!chipset.empty ()) {
202+ FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse (chipset);
203+ if (failed (maybeChipset)) {
204+ return signalPassFailure ();
205+ }
206+ populateMathToROCDLConversionPatterns (converter, patterns, *maybeChipset);
207+ }
208+
139209 ConversionTarget target (getContext ());
140- target.addLegalDialect <BuiltinDialect, func::FuncDialect,
141- vector::VectorDialect, LLVM::LLVMDialect>();
210+ target
211+ .addLegalDialect <BuiltinDialect, func::FuncDialect, vector::VectorDialect,
212+ LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
142213 target.addIllegalOp <LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
143214 LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
144215 LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
0 commit comments