Skip to content

Commit a3cd7ef

Browse files
committed
Add cgroup mem parameters in createTarget
1 parent 03db991 commit a3cd7ef

File tree

2 files changed

+23
-12
lines changed

2 files changed

+23
-12
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2446,7 +2446,7 @@ class OpenMPIRBuilder {
24462446
/// The number of threads.
24472447
ArrayRef<Value *> NumThreads;
24482448
/// The size of the dynamic shared memory.
2449-
Value *DynCGGroupMem = nullptr;
2449+
Value *DynCGroupMem = nullptr;
24502450
/// True if the kernel has 'no wait' clause.
24512451
bool HasNoWait = false;
24522452
/// The fallback mechanism for the shared memory.
@@ -2457,12 +2457,12 @@ class OpenMPIRBuilder {
24572457
TargetKernelArgs() {}
24582458
TargetKernelArgs(unsigned NumTargetItems, TargetDataRTArgs RTArgs,
24592459
Value *NumIterations, ArrayRef<Value *> NumTeams,
2460-
ArrayRef<Value *> NumThreads, Value *DynCGGroupMem,
2460+
ArrayRef<Value *> NumThreads, Value *DynCGroupMem,
24612461
bool HasNoWait,
24622462
omp::OMPDynGroupprivateFallbackType DynCGroupMemFallback)
24632463
: NumTargetItems(NumTargetItems), RTArgs(RTArgs),
24642464
NumIterations(NumIterations), NumTeams(NumTeams),
2465-
NumThreads(NumThreads), DynCGGroupMem(DynCGGroupMem),
2465+
NumThreads(NumThreads), DynCGroupMem(DynCGroupMem),
24662466
HasNoWait(HasNoWait), DynCGroupMemFallback(DynCGroupMemFallback) {}
24672467
};
24682468

@@ -3248,6 +3248,10 @@ class OpenMPIRBuilder {
32483248
/// dependency information as passed in the depend clause
32493249
/// \param HasNowait Whether the target construct has a `nowait` clause or
32503250
/// not.
3251+
/// \param DynCGroupMem The size of the dynamic groupprivate memory for each
3252+
/// cgroup.
3253+
/// \param DynCGroupMem The fallback mechanism to execute if the requested
3254+
/// cgroup memory cannot be provided.
32513255
LLVM_ABI InsertPointOrErrorTy createTarget(
32523256
const LocationDescription &Loc, bool IsOffloadEntry,
32533257
OpenMPIRBuilder::InsertPointTy AllocaIP,
@@ -3259,7 +3263,10 @@ class OpenMPIRBuilder {
32593263
TargetBodyGenCallbackTy BodyGenCB,
32603264
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
32613265
CustomMapperCallbackTy CustomMapperCB,
3262-
const SmallVector<DependData> &Dependencies, bool HasNowait = false);
3266+
const SmallVector<DependData> &Dependencies, bool HasNowait = false,
3267+
Value *DynCGroupMem = nullptr,
3268+
omp::OMPDynGroupprivateFallbackType DynCGroupMemFallback =
3269+
omp::OMPDynGroupprivateFallbackType::Abort);
32633270

32643271
/// Returns __kmpc_for_static_init_* runtime function for the specified
32653272
/// size \a IVSize and sign \a IVSigned. Will create a distribute call

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
565565
Flags,
566566
NumTeams3D,
567567
NumThreads3D,
568-
KernelArgs.DynCGGroupMem};
568+
KernelArgs.DynCGroupMem};
569569
}
570570

571571
void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
@@ -8229,7 +8229,8 @@ static void emitTargetCall(
82298229
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
82308230
OpenMPIRBuilder::CustomMapperCallbackTy CustomMapperCB,
82318231
const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
8232-
bool HasNoWait) {
8232+
bool HasNoWait, Value *DynCGroupMem,
8233+
OMPDynGroupprivateFallbackType DynCGroupMemFallback) {
82338234
// Generate a function call to the host fallback implementation of the target
82348235
// region. This is called by the host when no offload entry was generated for
82358236
// the target region and when the offloading call fails at runtime.
@@ -8365,12 +8366,13 @@ static void emitTargetCall(
83658366
/*isSigned=*/false)
83668367
: Builder.getInt64(0);
83678368

8368-
// TODO: Use correct DynCGGroupMem
8369-
Value *DynCGGroupMem = Builder.getInt32(0);
8369+
// Request zero groupprivate bytes by default.
8370+
if (!DynCGroupMem)
8371+
DynCGroupMem = Builder.getInt32(0);
83708372

83718373
KArgs = OpenMPIRBuilder::TargetKernelArgs(
8372-
NumTargetItems, RTArgs, TripCount, NumTeamsC, NumThreadsC,
8373-
DynCGGroupMem, HasNoWait, OMPDynGroupprivateFallbackType::Abort);
8374+
NumTargetItems, RTArgs, TripCount, NumTeamsC, NumThreadsC, DynCGroupMem,
8375+
HasNoWait, DynCGroupMemFallback);
83748376

83758377
// Assume no error was returned because TaskBodyCB and
83768378
// EmitTargetCallFallbackCB don't produce any.
@@ -8419,7 +8421,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
84198421
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
84208422
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
84218423
CustomMapperCallbackTy CustomMapperCB,
8422-
const SmallVector<DependData> &Dependencies, bool HasNowait) {
8424+
const SmallVector<DependData> &Dependencies, bool HasNowait,
8425+
Value *DynCGroupMem, OMPDynGroupprivateFallbackType DynCGroupMemFallback) {
84238426

84248427
if (!updateToLocation(Loc))
84258428
return InsertPointTy();
@@ -8442,7 +8445,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
84428445
if (!Config.isTargetDevice())
84438446
emitTargetCall(*this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs,
84448447
IfCond, OutlinedFn, OutlinedFnID, Inputs, GenMapInfoCB,
8445-
CustomMapperCB, Dependencies, HasNowait);
8448+
CustomMapperCB, Dependencies, HasNowait, DynCGroupMem,
8449+
DynCGroupMemFallback);
84468450
return Builder.saveIP();
84478451
}
84488452

0 commit comments

Comments
 (0)