Skip to content

Commit 48dca1e

Browse files
authored
[OpenACC][CIR] Implement 'atomic capture' lowering (#168422)
The 'atomic capture' variant of the `atomic` construct accepts either a single statement, or a compound statement containing two statements. Each of the statements it accepts meet a form of the previous read/write/update forms, or is a combination of two. The IR node for atomic capture takes two separate other acc.atomics, plus a terminator. This patch implements all of the lowering for these. Note: This gets the postfix-increment/decrement wrong, but the effort to do so is enough that I believe we can do that in a followup patch, so I'll be doing so in the next patch.
1 parent 2fc42c7 commit 48dca1e

File tree

5 files changed

+914
-121
lines changed

5 files changed

+914
-121
lines changed

clang/include/clang/AST/StmtOpenACC.h

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -818,14 +818,57 @@ class OpenACCAtomicConstruct final
818818

819819
// A struct to represent a broken-down version of the associated statement,
820820
// providing the information specified in OpenACC3.3 Section 2.12.
821-
struct StmtInfo {
821+
struct SingleStmtInfo {
822+
// Holds the entire expression for this. In the case of a normal
823+
// read/write/update, this should just be the associated statement. in the
824+
// case of an update, this is going to be the sub-expression this
825+
// represents.
826+
const Expr *WholeExpr;
822827
const Expr *V;
823828
const Expr *X;
824829
// Listed as 'expr' in the standard, this is typically a generic expression
825830
// as a component.
826831
const Expr *RefExpr;
827-
// TODO: OpenACC: We should expand this as we're implementing the other
828-
// atomic construct kinds.
832+
static SingleStmtInfo Empty() {
833+
return {nullptr, nullptr, nullptr, nullptr};
834+
}
835+
836+
static SingleStmtInfo createRead(const Expr *WholeExpr, const Expr *V,
837+
const Expr *X) {
838+
return {WholeExpr, V, X, /*RefExpr=*/nullptr};
839+
}
840+
static SingleStmtInfo createWrite(const Expr *WholeExpr, const Expr *X,
841+
const Expr *RefExpr) {
842+
return {WholeExpr, /*V=*/nullptr, X, RefExpr};
843+
}
844+
static SingleStmtInfo createUpdate(const Expr *WholeExpr, const Expr *X) {
845+
return {WholeExpr, /*V=*/nullptr, X, /*RefExpr=*/nullptr};
846+
}
847+
};
848+
849+
struct StmtInfo {
850+
enum class StmtForm {
851+
Read,
852+
Write,
853+
Update,
854+
ReadWrite,
855+
ReadUpdate,
856+
UpdateRead
857+
} Form;
858+
SingleStmtInfo First, Second;
859+
860+
static StmtInfo createUpdateRead(SingleStmtInfo First,
861+
SingleStmtInfo Second) {
862+
return {StmtForm::UpdateRead, First, Second};
863+
}
864+
static StmtInfo createReadWrite(SingleStmtInfo First,
865+
SingleStmtInfo Second) {
866+
return {StmtForm::ReadWrite, First, Second};
867+
}
868+
static StmtInfo createReadUpdate(SingleStmtInfo First,
869+
SingleStmtInfo Second) {
870+
return {StmtForm::ReadUpdate, First, Second};
871+
}
829872
};
830873

831874
const StmtInfo getAssociatedStmtInfo() const;

clang/lib/AST/StmtOpenACC.cpp

Lines changed: 222 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -324,30 +324,220 @@ OpenACCAtomicConstruct *OpenACCAtomicConstruct::Create(
324324
return Inst;
325325
}
326326

327-
static std::pair<const Expr *, const Expr *> getBinaryOpArgs(const Expr *Op) {
327+
static std::optional<std::pair<const Expr *, const Expr *>>
328+
getBinaryAssignOpArgs(const Expr *Op, bool &IsCompoundAssign) {
328329
if (const auto *BO = dyn_cast<BinaryOperator>(Op)) {
329-
assert(BO->isAssignmentOp());
330-
return {BO->getLHS(), BO->getRHS()};
330+
if (!BO->isAssignmentOp())
331+
return std::nullopt;
332+
IsCompoundAssign = BO->isCompoundAssignmentOp();
333+
return std::pair<const Expr *, const Expr *>(BO->getLHS(), BO->getRHS());
331334
}
332335

333-
const auto *OO = cast<CXXOperatorCallExpr>(Op);
334-
assert(OO->isAssignmentOp());
335-
return {OO->getArg(0), OO->getArg(1)};
336+
if (const auto *OO = dyn_cast<CXXOperatorCallExpr>(Op)) {
337+
if (!OO->isAssignmentOp())
338+
return std::nullopt;
339+
IsCompoundAssign = OO->getOperator() != OO_Equal;
340+
return std::pair<const Expr *, const Expr *>(OO->getArg(0), OO->getArg(1));
341+
}
342+
return std::nullopt;
343+
}
344+
static std::optional<std::pair<const Expr *, const Expr *>>
345+
getBinaryAssignOpArgs(const Expr *Op) {
346+
bool IsCompoundAssign;
347+
return getBinaryAssignOpArgs(Op, IsCompoundAssign);
336348
}
337349

338-
static std::pair<bool, const Expr *> getUnaryOpArgs(const Expr *Op) {
350+
static std::optional<const Expr *> getUnaryOpArgs(const Expr *Op) {
339351
if (const auto *UO = dyn_cast<UnaryOperator>(Op))
340-
return {true, UO->getSubExpr()};
352+
return UO->getSubExpr();
341353

342354
if (const auto *OpCall = dyn_cast<CXXOperatorCallExpr>(Op)) {
343355
// Post-inc/dec have a second unused argument to differentiate it, so we
344356
// accept -- or ++ as unary, or any operator call with only 1 arg.
345-
if (OpCall->getNumArgs() == 1 || OpCall->getOperator() != OO_PlusPlus ||
346-
OpCall->getOperator() != OO_MinusMinus)
347-
return {true, OpCall->getArg(0)};
357+
if (OpCall->getNumArgs() == 1 || OpCall->getOperator() == OO_PlusPlus ||
358+
OpCall->getOperator() == OO_MinusMinus)
359+
return {OpCall->getArg(0)};
348360
}
349361

350-
return {false, nullptr};
362+
return std::nullopt;
363+
}
364+
365+
// Read is of the form `v = x;`, where both sides are scalar L-values. This is a
366+
// BinaryOperator or CXXOperatorCallExpr.
367+
static std::optional<OpenACCAtomicConstruct::SingleStmtInfo>
368+
getReadStmtInfo(const Expr *E, bool ForAtomicComputeSingleStmt = false) {
369+
std::optional<std::pair<const Expr *, const Expr *>> BinaryArgs =
370+
getBinaryAssignOpArgs(E);
371+
372+
if (!BinaryArgs)
373+
return std::nullopt;
374+
375+
// We want the L-value for each side, so we ignore implicit casts.
376+
auto Res = OpenACCAtomicConstruct::SingleStmtInfo::createRead(
377+
E, BinaryArgs->first->IgnoreImpCasts(),
378+
BinaryArgs->second->IgnoreImpCasts());
379+
380+
// The atomic compute single-stmt variant has to do a 'fixup' step for the 'X'
381+
// value, since it is dependent on the RHS. So if we're in that version, we
382+
// skip the checks on X.
383+
if ((!ForAtomicComputeSingleStmt &&
384+
(!Res.X->isLValue() || !Res.X->getType()->isScalarType())) ||
385+
!Res.V->isLValue() || !Res.V->getType()->isScalarType())
386+
return std::nullopt;
387+
388+
return Res;
389+
}
390+
391+
// Write supports only the format 'x = expr', where the expression is scalar
392+
// type, and 'x' is a scalar l value. As above, this can come in 2 forms;
393+
// Binary Operator or CXXOperatorCallExpr.
394+
static std::optional<OpenACCAtomicConstruct::SingleStmtInfo>
395+
getWriteStmtInfo(const Expr *E) {
396+
std::optional<std::pair<const Expr *, const Expr *>> BinaryArgs =
397+
getBinaryAssignOpArgs(E);
398+
if (!BinaryArgs)
399+
return std::nullopt;
400+
// We want the L-value for ONLY the X side, so we ignore implicit casts. For
401+
// the right side (the expr), we emit it as an r-value so we need to
402+
// maintain implicit casts.
403+
auto Res = OpenACCAtomicConstruct::SingleStmtInfo::createWrite(
404+
E, BinaryArgs->first->IgnoreImpCasts(), BinaryArgs->second);
405+
406+
if (!Res.X->isLValue() || !Res.X->getType()->isScalarType())
407+
return std::nullopt;
408+
return Res;
409+
}
410+
411+
static std::optional<OpenACCAtomicConstruct::SingleStmtInfo>
412+
getUpdateStmtInfo(const Expr *E) {
413+
std::optional<const Expr *> UnaryArgs = getUnaryOpArgs(E);
414+
if (UnaryArgs) {
415+
auto Res = OpenACCAtomicConstruct::SingleStmtInfo::createUpdate(
416+
E, (*UnaryArgs)->IgnoreImpCasts());
417+
418+
if (!Res.X->isLValue() || !Res.X->getType()->isScalarType())
419+
return std::nullopt;
420+
421+
return Res;
422+
}
423+
424+
bool IsRHSCompoundAssign = false;
425+
std::optional<std::pair<const Expr *, const Expr *>> BinaryArgs =
426+
getBinaryAssignOpArgs(E, IsRHSCompoundAssign);
427+
if (!BinaryArgs)
428+
return std::nullopt;
429+
430+
auto Res = OpenACCAtomicConstruct::SingleStmtInfo::createUpdate(
431+
E, BinaryArgs->first->IgnoreImpCasts());
432+
433+
if (!Res.X->isLValue() || !Res.X->getType()->isScalarType())
434+
return std::nullopt;
435+
436+
// 'update' has to be either a compound-assignment operation, or
437+
// assignment-to-a-binary-op. Return nullopt if these are not the case.
438+
// If we are already compound-assign, we're done!
439+
if (IsRHSCompoundAssign)
440+
return Res;
441+
442+
// else we have to check that we have a binary operator.
443+
const Expr *RHS = BinaryArgs->second->IgnoreImpCasts();
444+
445+
if (isa<BinaryOperator>(RHS)) {
446+
return Res;
447+
} else if (const auto *OO = dyn_cast<CXXOperatorCallExpr>(RHS)) {
448+
if (OO->isInfixBinaryOp())
449+
return Res;
450+
}
451+
452+
return std::nullopt;
453+
}
454+
455+
/// The statement associated with an atomic capture comes in 1 of two forms: A
456+
/// compound statement containing two statements, or a single statement. In
457+
/// either case, the compound/single statement is decomposed into 2 separate
458+
/// operations, eihter a read/write, read/update, or update/read. This function
459+
/// figures out that information in the form listed in the standard (filling in
460+
/// V, X, or Expr) for each of these operations.
461+
static OpenACCAtomicConstruct::StmtInfo
462+
getCaptureStmtInfo(const Stmt *AssocStmt) {
463+
464+
if (const auto *CmpdStmt = dyn_cast<CompoundStmt>(AssocStmt)) {
465+
// We checked during Sema to ensure we only have 2 statements here, and
466+
// that both are expressions, we can look at these to see what the valid
467+
// options are.
468+
const Expr *Stmt1 = cast<Expr>(*CmpdStmt->body().begin())->IgnoreImpCasts();
469+
const Expr *Stmt2 =
470+
cast<Expr>(*(CmpdStmt->body().begin() + 1))->IgnoreImpCasts();
471+
472+
// The compound statement form allows read/write, read/update, or
473+
// update/read. First we get the information for a 'Read' to see if this is
474+
// one of the former two.
475+
std::optional<OpenACCAtomicConstruct::SingleStmtInfo> Read =
476+
getReadStmtInfo(Stmt1);
477+
478+
if (Read) {
479+
// READ : WRITE
480+
// v = x; x = expr
481+
// READ : UPDATE
482+
// v = x; x binop = expr
483+
// v = x; x = x binop expr
484+
// v = x; x = expr binop x
485+
// v = x; x++
486+
// v = x; ++x
487+
// v = x; x--
488+
// v = x; --x
489+
std::optional<OpenACCAtomicConstruct::SingleStmtInfo> Update =
490+
getUpdateStmtInfo(Stmt2);
491+
// Since we already know the first operation is a read, the second is
492+
// either an update, which we check, or a write, which we can assume next.
493+
if (Update)
494+
return OpenACCAtomicConstruct::StmtInfo::createReadUpdate(*Read,
495+
*Update);
496+
497+
std::optional<OpenACCAtomicConstruct::SingleStmtInfo> Write =
498+
getWriteStmtInfo(Stmt2);
499+
return OpenACCAtomicConstruct::StmtInfo::createReadWrite(*Read, *Write);
500+
}
501+
// UPDATE: READ
502+
// x binop = expr; v = x
503+
// x = x binop expr; v = x
504+
// x = expr binop x ; v = x
505+
// ++ x; v = x
506+
// x++; v = x
507+
// --x; v = x
508+
// x--; v = x
509+
// Otherwise, it is one of the above forms for update/read.
510+
std::optional<OpenACCAtomicConstruct::SingleStmtInfo> Update =
511+
getUpdateStmtInfo(Stmt1);
512+
Read = getReadStmtInfo(Stmt2);
513+
514+
return OpenACCAtomicConstruct::StmtInfo::createUpdateRead(*Update, *Read);
515+
} else {
516+
// All of the possible forms (listed below) that are writable as a single
517+
// line are expressed as an update, then as a read. We should be able to
518+
// just run these two in the right order.
519+
// UPDATE: READ
520+
// v = x++;
521+
// v = x--;
522+
// v = ++x;
523+
// v = --x;
524+
// v = x binop=expr
525+
// v = x = x binop expr
526+
// v = x = expr binop x
527+
528+
const Expr *E = cast<const Expr>(AssocStmt);
529+
530+
std::optional<OpenACCAtomicConstruct::SingleStmtInfo> Read =
531+
getReadStmtInfo(E, /*ForAtomicComputeSingleStmt=*/true);
532+
std::optional<OpenACCAtomicConstruct::SingleStmtInfo> Update =
533+
getUpdateStmtInfo(Read->X);
534+
535+
// Fixup this, since the 'X' for the read is the result after write, but is
536+
// the same value as the LHS-most variable of the update(its X).
537+
Read->X = Update->X;
538+
return OpenACCAtomicConstruct::StmtInfo::createUpdateRead(*Update, *Read);
539+
}
540+
return {};
351541
}
352542

353543
const OpenACCAtomicConstruct::StmtInfo
@@ -357,48 +547,28 @@ OpenACCAtomicConstruct::getAssociatedStmtInfo() const {
357547
// asserts to ensure we don't get off into the weeds.
358548
assert(getAssociatedStmt() && "invalid associated stmt?");
359549

360-
const Expr *AssocStmt = cast<const Expr>(getAssociatedStmt());
361550
switch (AtomicKind) {
362-
case OpenACCAtomicKind::Capture:
363-
assert(false && "Only 'read'/'write'/'update' have been implemented here");
364-
return {};
365-
case OpenACCAtomicKind::Read: {
366-
// Read only supports the format 'v = x'; where both sides are a scalar
367-
// expression. This can come in 2 forms; BinaryOperator or
368-
// CXXOperatorCallExpr (rarely).
369-
std::pair<const Expr *, const Expr *> BinaryArgs =
370-
getBinaryOpArgs(AssocStmt);
371-
// We want the L-value for each side, so we ignore implicit casts.
372-
return {BinaryArgs.first->IgnoreImpCasts(),
373-
BinaryArgs.second->IgnoreImpCasts(), /*expr=*/nullptr};
374-
}
375-
case OpenACCAtomicKind::Write: {
376-
// Write supports only the format 'x = expr', where the expression is scalar
377-
// type, and 'x' is a scalar l value. As above, this can come in 2 forms;
378-
// Binary Operator or CXXOperatorCallExpr.
379-
std::pair<const Expr *, const Expr *> BinaryArgs =
380-
getBinaryOpArgs(AssocStmt);
381-
// We want the L-value for ONLY the X side, so we ignore implicit casts. For
382-
// the right side (the expr), we emit it as an r-value so we need to
383-
// maintain implicit casts.
384-
return {/*v=*/nullptr, BinaryArgs.first->IgnoreImpCasts(),
385-
BinaryArgs.second};
386-
}
551+
case OpenACCAtomicKind::Read:
552+
return OpenACCAtomicConstruct::StmtInfo{
553+
OpenACCAtomicConstruct::StmtInfo::StmtForm::Read,
554+
*getReadStmtInfo(cast<const Expr>(getAssociatedStmt())),
555+
OpenACCAtomicConstruct::SingleStmtInfo::Empty()};
556+
557+
case OpenACCAtomicKind::Write:
558+
return OpenACCAtomicConstruct::StmtInfo{
559+
OpenACCAtomicConstruct::StmtInfo::StmtForm::Write,
560+
*getWriteStmtInfo(cast<const Expr>(getAssociatedStmt())),
561+
OpenACCAtomicConstruct::SingleStmtInfo::Empty()};
562+
387563
case OpenACCAtomicKind::None:
388-
case OpenACCAtomicKind::Update: {
389-
std::pair<bool, const Expr *> UnaryArgs = getUnaryOpArgs(AssocStmt);
390-
if (UnaryArgs.first)
391-
return {/*v=*/nullptr, UnaryArgs.second->IgnoreImpCasts(),
392-
/*expr=*/nullptr};
393-
394-
std::pair<const Expr *, const Expr *> BinaryArgs =
395-
getBinaryOpArgs(AssocStmt);
396-
// For binary args, we just store the RHS as an expression (in the
397-
// expression slot), since the codegen just wants the whole thing for a
398-
// recipe.
399-
return {/*v=*/nullptr, BinaryArgs.first->IgnoreImpCasts(),
400-
BinaryArgs.second};
401-
}
564+
case OpenACCAtomicKind::Update:
565+
return OpenACCAtomicConstruct::StmtInfo{
566+
OpenACCAtomicConstruct::StmtInfo::StmtForm::Update,
567+
*getUpdateStmtInfo(cast<const Expr>(getAssociatedStmt())),
568+
OpenACCAtomicConstruct::SingleStmtInfo::Empty()};
569+
570+
case OpenACCAtomicKind::Capture:
571+
return getCaptureStmtInfo(getAssociatedStmt());
402572
}
403573

404574
llvm_unreachable("unknown OpenACC atomic kind");

0 commit comments

Comments
 (0)