@@ -1725,15 +1725,35 @@ struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> {
17251725 }
17261726};
17271727
1728- static bool isDeviceAllocation (mlir::Value val) {
1728+ static bool isDeviceAllocation (mlir::Value val, mlir::Value adaptorVal ) {
17291729 if (auto loadOp = mlir::dyn_cast_or_null<fir::LoadOp>(val.getDefiningOp ()))
1730- return isDeviceAllocation (loadOp.getMemref ());
1730+ return isDeviceAllocation (loadOp.getMemref (), {} );
17311731 if (auto boxAddrOp =
17321732 mlir::dyn_cast_or_null<fir::BoxAddrOp>(val.getDefiningOp ()))
1733- return isDeviceAllocation (boxAddrOp.getVal ());
1733+ return isDeviceAllocation (boxAddrOp.getVal (), {} );
17341734 if (auto convertOp =
17351735 mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp ()))
1736- return isDeviceAllocation (convertOp.getValue ());
1736+ return isDeviceAllocation (convertOp.getValue (), {});
1737+ if (!val.getDefiningOp () && adaptorVal) {
1738+ if (auto blockArg = llvm::cast<mlir::BlockArgument>(adaptorVal)) {
1739+ if (blockArg.getOwner () && blockArg.getOwner ()->getParentOp () &&
1740+ blockArg.getOwner ()->isEntryBlock ()) {
1741+ if (auto func = mlir::dyn_cast_or_null<mlir::FunctionOpInterface>(
1742+ *blockArg.getOwner ()->getParentOp ())) {
1743+ auto argAttrs = func.getArgAttrs (blockArg.getArgNumber ());
1744+ for (auto attr : argAttrs) {
1745+ if (attr.getName ().getValue ().ends_with (cuf::getDataAttrName ())) {
1746+ auto dataAttr =
1747+ mlir::dyn_cast<cuf::DataAttributeAttr>(attr.getValue ());
1748+ if (dataAttr.getValue () != cuf::DataAttribute::Pinned &&
1749+ dataAttr.getValue () != cuf::DataAttribute::Unified)
1750+ return true ;
1751+ }
1752+ }
1753+ }
1754+ }
1755+ }
1756+ }
17371757 if (auto callOp = mlir::dyn_cast_or_null<fir::CallOp>(val.getDefiningOp ()))
17381758 if (callOp.getCallee () &&
17391759 (callOp.getCallee ().value ().getRootReference ().getValue ().starts_with (
@@ -1928,7 +1948,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
19281948 if (fir::isDerivedTypeWithLenParams (boxTy))
19291949 TODO (loc, " fir.embox codegen of derived with length parameters" );
19301950 mlir::Value result = placeInMemoryIfNotGlobalInit (
1931- rewriter, loc, boxTy, dest, isDeviceAllocation (xbox.getMemref ()));
1951+ rewriter, loc, boxTy, dest,
1952+ isDeviceAllocation (xbox.getMemref (), adaptor.getMemref ()));
19321953 rewriter.replaceOp (xbox, result);
19331954 return mlir::success ();
19341955 }
@@ -2052,9 +2073,9 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
20522073 dest = insertStride (rewriter, loc, dest, dim, std::get<1 >(iter.value ()));
20532074 }
20542075 dest = insertBaseAddress (rewriter, loc, dest, base);
2055- mlir::Value result =
2056- placeInMemoryIfNotGlobalInit ( rewriter, rebox.getLoc (), destBoxTy, dest,
2057- isDeviceAllocation (rebox.getBox ()));
2076+ mlir::Value result = placeInMemoryIfNotGlobalInit (
2077+ rewriter, rebox.getLoc (), destBoxTy, dest,
2078+ isDeviceAllocation (rebox. getBox (), rebox.getBox ()));
20582079 rewriter.replaceOp (rebox, result);
20592080 return mlir::success ();
20602081 }
0 commit comments