1919#include " flang/Optimizer/HLFIR/HLFIROps.h"
2020#include " flang/Optimizer/HLFIR/Passes.h"
2121#include " mlir/Dialect/Arith/IR/Arith.h"
22- #include " mlir/Dialect/Func/IR/FuncOps.h"
23- #include " mlir/IR/BuiltinDialect.h"
2422#include " mlir/IR/Location.h"
2523#include " mlir/Pass/Pass.h"
26- #include " mlir/Transforms/DialectConversion .h"
24+ #include " mlir/Transforms/GreedyPatternRewriteDriver .h"
2725
2826namespace hlfir {
2927#define GEN_PASS_DEF_SIMPLIFYHLFIRINTRINSICS
@@ -45,9 +43,15 @@ class TransposeAsElementalConversion
4543 llvm::LogicalResult
4644 matchAndRewrite (hlfir::TransposeOp transpose,
4745 mlir::PatternRewriter &rewriter) const override {
46+ hlfir::ExprType expr = transpose.getType ();
47+ // TODO: hlfir.elemental supports polymorphic data types now,
48+ // so this can be supported.
49+ if (expr.isPolymorphic ())
50+ return rewriter.notifyMatchFailure (transpose,
51+ " TRANSPOSE of polymorphic type" );
52+
4853 mlir::Location loc = transpose.getLoc ();
4954 fir::FirOpBuilder builder{rewriter, transpose.getOperation ()};
50- hlfir::ExprType expr = transpose.getType ();
5155 mlir::Type elementType = expr.getElementType ();
5256 hlfir::Entity array = hlfir::Entity{transpose.getArray ()};
5357 mlir::Value resultShape = genResultShape (loc, builder, array);
@@ -105,15 +109,32 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
105109 llvm::LogicalResult
106110 matchAndRewrite (hlfir::SumOp sum,
107111 mlir::PatternRewriter &rewriter) const override {
112+ if (!simplifySum)
113+ return rewriter.notifyMatchFailure (sum, " SUM simplification is disabled" );
114+
115+ hlfir::Entity array = hlfir::Entity{sum.getArray ()};
116+ bool isTotalReduction = hlfir::Entity{sum}.getRank () == 0 ;
117+ mlir::Value dim = sum.getDim ();
118+ int64_t dimVal = 0 ;
119+ if (!isTotalReduction) {
120+ // In case of partial reduction we should ignore the operations
121+ // with invalid DIM values. They may appear in dead code
122+ // after constant propagation.
123+ auto constDim = fir::getIntIfConstant (dim);
124+ if (!constDim)
125+ return rewriter.notifyMatchFailure (sum, " Nonconstant DIM for SUM" );
126+ dimVal = *constDim;
127+
128+ if ((dimVal <= 0 || dimVal > array.getRank ()))
129+ return rewriter.notifyMatchFailure (
130+ sum, " Invalid DIM for partial SUM reduction" );
131+ }
132+
108133 mlir::Location loc = sum.getLoc ();
109134 fir::FirOpBuilder builder{rewriter, sum.getOperation ()};
110135 mlir::Type elementType = hlfir::getFortranElementType (sum.getType ());
111- hlfir::Entity array = hlfir::Entity{sum.getArray ()};
112136 mlir::Value mask = sum.getMask ();
113- mlir::Value dim = sum.getDim ();
114- bool isTotalReduction = hlfir::Entity{sum}.getRank () == 0 ;
115- int64_t dimVal =
116- isTotalReduction ? 0 : fir::getIntIfConstant (dim).value_or (0 );
137+
117138 mlir::Value resultShape, dimExtent;
118139 llvm::SmallVector<mlir::Value> arrayExtents;
119140 if (isTotalReduction)
@@ -360,27 +381,38 @@ class CShiftAsElementalConversion
360381public:
361382 using mlir::OpRewritePattern<hlfir::CShiftOp>::OpRewritePattern;
362383
363- explicit CShiftAsElementalConversion (mlir::MLIRContext *ctx)
364- : OpRewritePattern(ctx) {
365- setHasBoundedRewriteRecursion ();
366- }
367-
368384 llvm::LogicalResult
369385 matchAndRewrite (hlfir::CShiftOp cshift,
370386 mlir::PatternRewriter &rewriter) const override {
371387 using Fortran::common::maxRank;
372388
373- mlir::Location loc = cshift.getLoc ();
374- fir::FirOpBuilder builder{rewriter, cshift.getOperation ()};
375389 hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(cshift.getType ());
376390 assert (expr &&
377391 " expected an expression type for the result of hlfir.cshift" );
392+ unsigned arrayRank = expr.getRank ();
393+ // When it is a 1D CSHIFT, we may assume that the DIM argument
394+ // (whether it is present or absent) is equal to 1, otherwise,
395+ // the program is illegal.
396+ int64_t dimVal = 1 ;
397+ if (arrayRank != 1 )
398+ if (mlir::Value dim = cshift.getDim ()) {
399+ auto constDim = fir::getIntIfConstant (dim);
400+ if (!constDim)
401+ return rewriter.notifyMatchFailure (cshift,
402+ " Nonconstant DIM for CSHIFT" );
403+ dimVal = *constDim;
404+ }
405+
406+ if (dimVal <= 0 || dimVal > arrayRank)
407+ return rewriter.notifyMatchFailure (cshift, " Invalid DIM for CSHIFT" );
408+
409+ mlir::Location loc = cshift.getLoc ();
410+ fir::FirOpBuilder builder{rewriter, cshift.getOperation ()};
378411 mlir::Type elementType = expr.getElementType ();
379412 hlfir::Entity array = hlfir::Entity{cshift.getArray ()};
380413 mlir::Value arrayShape = hlfir::genShape (loc, builder, array);
381414 llvm::SmallVector<mlir::Value> arrayExtents =
382415 hlfir::getExplicitExtentsFromShape (arrayShape, builder);
383- unsigned arrayRank = expr.getRank ();
384416 llvm::SmallVector<mlir::Value, 1 > typeParams;
385417 hlfir::genLengthParameters (loc, builder, array, typeParams);
386418 hlfir::Entity shift = hlfir::Entity{cshift.getShift ()};
@@ -395,20 +427,6 @@ class CShiftAsElementalConversion
395427 shiftVal = builder.createConvert (loc, calcType, shiftVal);
396428 }
397429
398- int64_t dimVal = 1 ;
399- if (arrayRank == 1 ) {
400- // When it is a 1D CSHIFT, we may assume that the DIM argument
401- // (whether it is present or absent) is equal to 1, otherwise,
402- // the program is illegal.
403- assert (shiftVal && " SHIFT must be scalar" );
404- } else {
405- if (mlir::Value dim = cshift.getDim ())
406- dimVal = fir::getIntIfConstant (dim).value_or (0 );
407- assert (dimVal > 0 && dimVal <= arrayRank &&
408- " DIM must be present and a positive constant not exceeding "
409- " the array's rank" );
410- }
411-
412430 auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
413431 mlir::ValueRange inputIndices) -> hlfir::Entity {
414432 llvm::SmallVector<mlir::Value, maxRank> indices{inputIndices};
@@ -462,68 +480,19 @@ class SimplifyHLFIRIntrinsics
462480public:
463481 void runOnOperation () override {
464482 mlir::MLIRContext *context = &getContext ();
483+
484+ mlir::GreedyRewriteConfig config;
485+ // Prevent the pattern driver from merging blocks
486+ config.enableRegionSimplification =
487+ mlir::GreedySimplifyRegionLevel::Disabled;
488+
465489 mlir::RewritePatternSet patterns (context);
466490 patterns.insert <TransposeAsElementalConversion>(context);
467491 patterns.insert <SumAsElementalConversion>(context);
468492 patterns.insert <CShiftAsElementalConversion>(context);
469- mlir::ConversionTarget target (*context);
470- // don't transform transpose of polymorphic arrays (not currently supported
471- // by hlfir.elemental)
472- target.addDynamicallyLegalOp <hlfir::TransposeOp>(
473- [](hlfir::TransposeOp transpose) {
474- return mlir::cast<hlfir::ExprType>(transpose.getType ())
475- .isPolymorphic ();
476- });
477- // Handle only SUM(DIM=CONSTANT) case for now.
478- // It may be beneficial to expand the non-DIM case as well.
479- // E.g. when the input array is an elemental array expression,
480- // expanding the SUM into a total reduction loop nest
481- // would avoid creating a temporary for the elemental array expression.
482- target.addDynamicallyLegalOp <hlfir::SumOp>([](hlfir::SumOp sum) {
483- if (!simplifySum)
484- return true ;
485-
486- // Always inline total reductions.
487- if (hlfir::Entity{sum}.getRank () == 0 )
488- return false ;
489- mlir::Value dim = sum.getDim ();
490- if (!dim)
491- return false ;
492-
493- if (auto dimVal = fir::getIntIfConstant (dim)) {
494- fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
495- hlfir::getFortranElementOrSequenceType (sum.getArray ().getType ()));
496- if (*dimVal > 0 && *dimVal <= arrayTy.getDimension ()) {
497- // Ignore SUMs with illegal DIM values.
498- // They may appear in dead code,
499- // and they do not have to be converted.
500- return false ;
501- }
502- }
503- return true ;
504- });
505- target.addDynamicallyLegalOp <hlfir::CShiftOp>([](hlfir::CShiftOp cshift) {
506- unsigned resultRank = hlfir::Entity{cshift}.getRank ();
507- if (resultRank == 1 )
508- return false ;
509-
510- mlir::Value dim = cshift.getDim ();
511- if (!dim)
512- return false ;
513-
514- // If DIM is present, then it must be constant to please
515- // the conversion. In addition, ignore cases with
516- // illegal DIM values.
517- if (auto dimVal = fir::getIntIfConstant (dim))
518- if (*dimVal > 0 && *dimVal <= resultRank)
519- return false ;
520-
521- return true ;
522- });
523- target.markUnknownOpDynamicallyLegal (
524- [](mlir::Operation *) { return true ; });
525- if (mlir::failed (mlir::applyFullConversion (getOperation (), target,
526- std::move (patterns)))) {
493+
494+ if (mlir::failed (mlir::applyPatternsAndFoldGreedily (
495+ getOperation (), std::move (patterns), config))) {
527496 mlir::emitError (getOperation ()->getLoc (),
528497 " failure in HLFIR intrinsic simplification" );
529498 signalPassFailure ();
0 commit comments