diff --git a/include/dxc/DXIL/DxilFunctionProps.h b/include/dxc/DXIL/DxilFunctionProps.h index 425ec4e391..e0223e9dd5 100644 --- a/include/dxc/DXIL/DxilFunctionProps.h +++ b/include/dxc/DXIL/DxilFunctionProps.h @@ -117,6 +117,7 @@ struct DxilFunctionProps { memset(&Node, 0, sizeof(Node)); Node.LaunchType = DXIL::NodeLaunchType::Invalid; Node.LocalRootArgumentsTableIndex = -1; + groupSharedLimitBytes = 0; } union { // Geometry shader. @@ -174,6 +175,8 @@ struct DxilFunctionProps { // numThreads shared between multiple shader types and node shaders. unsigned numThreads[3]; + unsigned groupSharedLimitBytes; + struct NodeProps { DXIL::NodeLaunchType LaunchType = DXIL::NodeLaunchType::Invalid; bool IsProgramEntry; diff --git a/include/dxc/DXIL/DxilMetadataHelper.h b/include/dxc/DXIL/DxilMetadataHelper.h index e17db016d8..5132cedb44 100644 --- a/include/dxc/DXIL/DxilMetadataHelper.h +++ b/include/dxc/DXIL/DxilMetadataHelper.h @@ -320,6 +320,7 @@ class DxilMDHelper { static const unsigned kDxilNodeOutputsTag = 21; static const unsigned kDxilNodeMaxDispatchGridTag = 22; static const unsigned kDxilRangedWaveSizeTag = 23; + static const unsigned kDxilGroupSharedLimitTag = 24; // Node Input/Output State. static const unsigned kDxilNodeOutputIDTag = 0; diff --git a/include/dxc/DXIL/DxilModule.h b/include/dxc/DXIL/DxilModule.h index 3f1ba12f86..732ce27c54 100644 --- a/include/dxc/DXIL/DxilModule.h +++ b/include/dxc/DXIL/DxilModule.h @@ -254,6 +254,8 @@ class DxilModule { void SetNumThreads(unsigned x, unsigned y, unsigned z); unsigned GetNumThreads(unsigned idx) const; + unsigned GetGroupSharedLimit() const; + // Compute shader DxilWaveSize &GetWaveSize(); const DxilWaveSize &GetWaveSize() const; diff --git a/include/dxc/DxilContainer/DxilPipelineStateValidation.h b/include/dxc/DxilContainer/DxilPipelineStateValidation.h index 83d0dae6e9..2cf13f3a84 100644 --- a/include/dxc/DxilContainer/DxilPipelineStateValidation.h +++ b/include/dxc/DxilContainer/DxilPipelineStateValidation.h @@ -175,6 +175,10 @@ struct PSVRuntimeInfo3 : public PSVRuntimeInfo2 { uint32_t EntryFunctionName; }; +struct PSVRuntimeInfo4 : public PSVRuntimeInfo3 { + uint32_t GroupSharedLimit; +}; + enum class PSVResourceType { Invalid = 0, @@ -474,7 +478,7 @@ class PSVSignatureElement { const uint32_t *SemanticIndexes) const; }; -#define MAX_PSV_VERSION 3 +#define MAX_PSV_VERSION 4 struct PSVInitInfo { PSVInitInfo(uint32_t psvVersion) : PSVVersion(psvVersion) {} @@ -491,7 +495,7 @@ struct PSVInitInfo { uint8_t SigPatchConstOrPrimVectors = 0; uint8_t SigOutputVectors[PSV_GS_MAX_STREAMS] = {0, 0, 0, 0}; - static_assert(MAX_PSV_VERSION == 3, "otherwise this needs updating."); + static_assert(MAX_PSV_VERSION == 4, "otherwise this needs updating."); uint32_t RuntimeInfoSize() const { switch (PSVVersion) { case 0: @@ -500,10 +504,12 @@ struct PSVInitInfo { return sizeof(PSVRuntimeInfo1); case 2: return sizeof(PSVRuntimeInfo2); + case 3: + return sizeof(PSVRuntimeInfo3); default: break; } - return sizeof(PSVRuntimeInfo3); + return sizeof(PSVRuntimeInfo4); } uint32_t ResourceBindInfoSize() const { if (PSVVersion < 2) @@ -519,6 +525,7 @@ class DxilPipelineStateValidation { PSVRuntimeInfo1 *m_pPSVRuntimeInfo1 = nullptr; PSVRuntimeInfo2 *m_pPSVRuntimeInfo2 = nullptr; PSVRuntimeInfo3 *m_pPSVRuntimeInfo3 = nullptr; + PSVRuntimeInfo4 *m_pPSVRuntimeInfo4 = nullptr; uint32_t m_uResourceCount = 0; uint32_t m_uPSVResourceBindInfoSize = 0; void *m_pPSVResourceBindInfo = nullptr; @@ -634,6 +641,8 @@ class DxilPipelineStateValidation { PSVRuntimeInfo3 *GetPSVRuntimeInfo3() const { return m_pPSVRuntimeInfo3; } + PSVRuntimeInfo4 *GetPSVRuntimeInfo4() const { return m_pPSVRuntimeInfo4; } + uint32_t GetBindCount() const { return m_uResourceCount; } template @@ -949,6 +958,8 @@ DxilPipelineStateValidation::ReadOrWrite(const void *pBits, uint32_t *pSize, m_uPSVRuntimeInfoSize); // failure ok AssignDerived(&m_pPSVRuntimeInfo3, m_pPSVRuntimeInfo0, m_uPSVRuntimeInfoSize); // failure ok + AssignDerived(&m_pPSVRuntimeInfo4, m_pPSVRuntimeInfo0, + m_uPSVRuntimeInfoSize); // failure ok // In RWMode::CalcSize, use temp runtime info to hold needed values from // initInfo @@ -1137,11 +1148,13 @@ void SetupPSVInitInfo(PSVInitInfo &InitInfo, const DxilModule &DM); void SetShaderProps(PSVRuntimeInfo0 *pInfo, const DxilModule &DM); void SetShaderProps(PSVRuntimeInfo1 *pInfo1, const DxilModule &DM); void SetShaderProps(PSVRuntimeInfo2 *pInfo2, const DxilModule &DM); +void SetShaderProps(PSVRuntimeInfo4 *pInfo4, const DxilModule &DM); void PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0, PSVRuntimeInfo1 *pInfo1, PSVRuntimeInfo2 *pInfo2, - PSVRuntimeInfo3 *pInfo3, uint8_t ShaderKind, - const char *EntryName, const char *Comment); + PSVRuntimeInfo3 *pInfo3, PSVRuntimeInfo4 *pInfo4, + uint8_t ShaderKind, const char *EntryName, + const char *Comment); } // namespace hlsl diff --git a/lib/DXIL/DxilMetadataHelper.cpp b/lib/DXIL/DxilMetadataHelper.cpp index c1282a980a..a50f38802e 100644 --- a/lib/DXIL/DxilMetadataHelper.cpp +++ b/lib/DXIL/DxilMetadataHelper.cpp @@ -1624,6 +1624,10 @@ MDTuple *DxilMDHelper::EmitDxilEntryProperties(uint64_t rawShaderFlag, } MDVals.emplace_back(MDNode::get(m_Ctx, WaveSizeVal)); } + + MDVals.emplace_back( + Uint32ToConstMD(DxilMDHelper::kDxilGroupSharedLimitTag)); + MDVals.emplace_back(Uint32ToConstMD(props.groupSharedLimitBytes)); } break; // Geometry shader. case DXIL::ShaderKind::Geometry: { @@ -1773,6 +1777,11 @@ void DxilMDHelper::LoadDxilEntryProperties(const MDOperand &MDO, props.numThreads[2] = ConstMDToUint32(pNode->getOperand(2)); } break; + case DxilMDHelper::kDxilGroupSharedLimitTag: { + DXASSERT(props.IsCS(), "else invalid shader kind"); + props.groupSharedLimitBytes = ConstMDToUint32(MDO); + } break; + case DxilMDHelper::kDxilGSStateTag: { DXASSERT(props.IsGS(), "else invalid shader kind"); auto &GS = props.ShaderProps.GS; diff --git a/lib/DXIL/DxilModule.cpp b/lib/DXIL/DxilModule.cpp index f4abdd15aa..15b4079b4d 100644 --- a/lib/DXIL/DxilModule.cpp +++ b/lib/DXIL/DxilModule.cpp @@ -412,6 +412,15 @@ unsigned DxilModule::GetNumThreads(unsigned idx) const { return props.numThreads[idx]; } +unsigned DxilModule::GetGroupSharedLimit() const { + DXASSERT(m_DxilEntryPropsMap.size() == 1 && + (m_pSM->IsCS() || m_pSM->IsMS() || m_pSM->IsAS()), + "only works for CS/MS/AS profiles"); + const DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props; + DXASSERT_NOMSG(m_pSM->GetKind() == props.shaderKind); + return props.groupSharedLimitBytes; +} + DxilWaveSize &DxilModule::GetWaveSize() { return const_cast( static_cast(this)->GetWaveSize()); diff --git a/lib/DxilContainer/DxilContainerAssembler.cpp b/lib/DxilContainer/DxilContainerAssembler.cpp index 48d8872733..736940b325 100644 --- a/lib/DxilContainer/DxilContainerAssembler.cpp +++ b/lib/DxilContainer/DxilContainerAssembler.cpp @@ -798,6 +798,8 @@ class DxilPSVWriter : public DxilPartWriter { PSVRuntimeInfo1 *pInfo1 = m_PSV.GetPSVRuntimeInfo1(); PSVRuntimeInfo2 *pInfo2 = m_PSV.GetPSVRuntimeInfo2(); PSVRuntimeInfo3 *pInfo3 = m_PSV.GetPSVRuntimeInfo3(); + PSVRuntimeInfo4 *pInfo4 = m_PSV.GetPSVRuntimeInfo4(); + if (pInfo) hlsl::SetShaderProps(pInfo, m_Module); if (pInfo1) @@ -806,6 +808,8 @@ class DxilPSVWriter : public DxilPartWriter { hlsl::SetShaderProps(pInfo2, m_Module); if (pInfo3) pInfo3->EntryFunctionName = EntryFunctionName; + if (pInfo4) + hlsl::SetShaderProps(pInfo4, m_Module); // Set resource binding information UINT uResIndex = 0; diff --git a/lib/DxilContainer/DxilPipelineStateValidation.cpp b/lib/DxilContainer/DxilPipelineStateValidation.cpp index 66186549f2..06a4782c57 100644 --- a/lib/DxilContainer/DxilPipelineStateValidation.cpp +++ b/lib/DxilContainer/DxilPipelineStateValidation.cpp @@ -305,6 +305,20 @@ void hlsl::SetShaderProps(PSVRuntimeInfo2 *pInfo2, const DxilModule &DM) { } } +void hlsl::SetShaderProps(PSVRuntimeInfo4 *pInfo4, const DxilModule &DM) { + assert(pInfo4); + const ShaderModel *SM = DM.GetShaderModel(); + switch (SM->GetKind()) { + case ShaderModel::Kind::Compute: + case ShaderModel::Kind::Mesh: + case ShaderModel::Kind::Amplification: + pInfo4->GroupSharedLimit = DM.GetGroupSharedLimit(); + break; + default: + break; + } +} + void PSVResourceBindInfo0::Print(raw_ostream &OS) const { OS << "PSVResourceBindInfo:\n"; OS << " Space: " << Space << "\n"; @@ -584,8 +598,9 @@ void PSVDependencyTable::Print(raw_ostream &OS, const char *InputSetName, void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0, PSVRuntimeInfo1 *pInfo1, PSVRuntimeInfo2 *pInfo2, - PSVRuntimeInfo3 *pInfo3, uint8_t ShaderKind, - const char *EntryName, const char *Comment) { + PSVRuntimeInfo3 *pInfo3, PSVRuntimeInfo4 *pInfo4, + uint8_t ShaderKind, const char *EntryName, + const char *Comment) { if (pInfo1 && pInfo1->ShaderStage != ShaderKind) ShaderKind = pInfo1->ShaderStage; OS << Comment << "PSVRuntimeInfo:\n"; @@ -808,6 +823,9 @@ void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0, OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << "," << pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n"; } + if (pInfo4) { + OS << Comment << " GroupSharedLimit=" << pInfo4->GroupSharedLimit << "\n"; + } break; case PSVShaderKind::Amplification: OS << Comment << " Amplification Shader\n"; @@ -815,6 +833,9 @@ void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0, OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << "," << pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n"; } + if (pInfo4) { + OS << Comment << " GroupSharedLimit=" << pInfo4->GroupSharedLimit << "\n"; + } break; case PSVShaderKind::Mesh: OS << Comment << " Mesh Shader\n"; @@ -841,6 +862,9 @@ void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0, OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << "," << pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n"; } + if (pInfo4) { + OS << Comment << " GroupSharedLimit=" << pInfo4->GroupSharedLimit << "\n"; + } break; case PSVShaderKind::Library: case PSVShaderKind::Invalid: @@ -887,9 +911,10 @@ void DxilPipelineStateValidation::PrintPSVRuntimeInfo( PSVRuntimeInfo1 *pInfo1 = m_pPSVRuntimeInfo1; PSVRuntimeInfo2 *pInfo2 = m_pPSVRuntimeInfo2; PSVRuntimeInfo3 *pInfo3 = m_pPSVRuntimeInfo3; + PSVRuntimeInfo4 *pInfo4 = m_pPSVRuntimeInfo4; hlsl::PrintPSVRuntimeInfo( - OS, pInfo0, pInfo1, pInfo2, pInfo3, ShaderKind, + OS, pInfo0, pInfo1, pInfo2, pInfo3, pInfo4, ShaderKind, m_pPSVRuntimeInfo3 ? m_StringTable.Get(pInfo3->EntryFunctionName) : "", Comment); } diff --git a/lib/DxilValidation/DxilContainerValidation.cpp b/lib/DxilValidation/DxilContainerValidation.cpp index 89e23767fe..ec5e69aa0a 100644 --- a/lib/DxilValidation/DxilContainerValidation.cpp +++ b/lib/DxilValidation/DxilContainerValidation.cpp @@ -413,12 +413,13 @@ void PSVContentVerifier::VerifyEntryProperties(const ShaderModel *SM, PSVRuntimeInfo0 *PSV0, PSVRuntimeInfo1 *PSV1, PSVRuntimeInfo2 *PSV2) { - PSVRuntimeInfo3 DMPSV; - memset(&DMPSV, 0, sizeof(PSVRuntimeInfo3)); + PSVRuntimeInfo4 DMPSV; + memset(&DMPSV, 0, sizeof(PSVRuntimeInfo4)); hlsl::SetShaderProps((PSVRuntimeInfo0 *)&DMPSV, DM); hlsl::SetShaderProps((PSVRuntimeInfo1 *)&DMPSV, DM); hlsl::SetShaderProps((PSVRuntimeInfo2 *)&DMPSV, DM); + hlsl::SetShaderProps((PSVRuntimeInfo4 *)&DMPSV, DM); if (PSV1) { // Init things not set in InitPSVRuntimeInfo. DMPSV.ShaderStage = static_cast(SM->GetKind()); @@ -447,7 +448,7 @@ void PSVContentVerifier::VerifyEntryProperties(const ShaderModel *SM, if (Mismatched) { std::string Str; raw_string_ostream OS(Str); - hlsl::PrintPSVRuntimeInfo(OS, &DMPSV, &DMPSV, &DMPSV, &DMPSV, + hlsl::PrintPSVRuntimeInfo(OS, &DMPSV, &DMPSV, &DMPSV, &DMPSV, &DMPSV, static_cast(SM->GetKind()), DM.GetEntryFunctionName().c_str(), ""); OS.flush(); diff --git a/lib/DxilValidation/DxilValidation.cpp b/lib/DxilValidation/DxilValidation.cpp index 2ea6701581..3dc3ddc4c1 100644 --- a/lib/DxilValidation/DxilValidation.cpp +++ b/lib/DxilValidation/DxilValidation.cpp @@ -3921,6 +3921,18 @@ static void ValidateGlobalVariables(ValidationContext &ValCtx) { Rule = ValidationRule::SmMaxMSSMSize; MaxSize = DXIL::kMaxMSSMSize; } + + // Check if the entry function has attribute to override TGSM size. + if (M.HasDxilEntryProps(M.GetEntryFunction())) { + DxilEntryProps &EntryProps = M.GetDxilEntryProps(M.GetEntryFunction()); + if (EntryProps.props.IsCS()) { + unsigned SpecifiedTGSMSize = EntryProps.props.groupSharedLimitBytes; + if (SpecifiedTGSMSize > 0) { + MaxSize = SpecifiedTGSMSize; + } + } + } + if (TGSMSize > MaxSize) { Module::global_iterator GI = M.GetModule()->global_end(); GlobalVariable *GV = &*GI; diff --git a/tools/clang/include/clang/Basic/Attr.td b/tools/clang/include/clang/Basic/Attr.td index 83137dbc3a..33f1594cac 100644 --- a/tools/clang/include/clang/Basic/Attr.td +++ b/tools/clang/include/clang/Basic/Attr.td @@ -671,6 +671,11 @@ def HLSLNumThreads: InheritableAttr { let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">]; let Documentation = [Undocumented]; } +def HLSLGroupSharedLimit: InheritableAttr { + let Spellings = [CXX11<"", "GroupSharedLimit", 2017>]; + let Args = [IntArgument<"Limit">]; + let Documentation = [Undocumented]; +} def HLSLRootSignature: InheritableAttr { let Spellings = [CXX11<"", "RootSignature", 2015>]; let Args = [StringArgument<"SignatureName">]; diff --git a/tools/clang/lib/CodeGen/CGHLSLMS.cpp b/tools/clang/lib/CodeGen/CGHLSLMS.cpp index 6c68381a20..ef456fe1c1 100644 --- a/tools/clang/lib/CodeGen/CGHLSLMS.cpp +++ b/tools/clang/lib/CodeGen/CGHLSLMS.cpp @@ -1646,6 +1646,36 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { } } + if (const HLSLGroupSharedLimitAttr *Attr = + FD->getAttr()) { + if (isEntry && !SM->IsCS() && !SM->IsMS() && !SM->IsAS()) { + unsigned DiagID = Diags.getCustomDiagID( + DiagnosticsEngine::Error, + "attribute GroupSharedLimit only valid for CS/MS/AS."); + Diags.Report(Attr->getLocation(), DiagID); + return; + } + + // Only valid for SM6.10+ + if (!SM->IsSM610Plus()) { + unsigned DiagID = Diags.getCustomDiagID( + DiagnosticsEngine::Error, "attribute GroupSharedLimit only valid for " + "Shader Model 6.10 and above."); + Diags.Report(Attr->getLocation(), DiagID); + return; + } + + funcProps->groupSharedLimitBytes = Attr->getLimit(); + } else { + if (SM->IsMS()) { // Fallback to default limits + funcProps->groupSharedLimitBytes = DXIL::kMaxMSSMSize; // 28k For MS + } else if (SM->IsAS() || SM->IsCS()) { + funcProps->groupSharedLimitBytes = DXIL::kMaxTGSMSize; // 32k For AS/CS + } else { + funcProps->groupSharedLimitBytes = 0; + } + } + // Hull shader. if (const HLSLPatchConstantFuncAttr *Attr = FD->getAttr()) { diff --git a/tools/clang/lib/Parse/ParseDecl.cpp b/tools/clang/lib/Parse/ParseDecl.cpp index 59be41a484..e8bea488a9 100644 --- a/tools/clang/lib/Parse/ParseDecl.cpp +++ b/tools/clang/lib/Parse/ParseDecl.cpp @@ -833,6 +833,7 @@ void Parser::ParseGNUAttributeArgs(IdentifierInfo *AttrName, case AttributeList::AT_HLSLMaxVertexCount: case AttributeList::AT_HLSLUnroll: case AttributeList::AT_HLSLWaveSize: + case AttributeList::AT_HLSLGroupSharedLimit: case AttributeList::AT_NoInline: // The following are not accepted in [attribute(param)] syntax: // case AttributeList::AT_HLSLCentroid: diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index e9c8c90a2d..0a0234d392 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -14656,6 +14656,11 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A, S.Context.getAddrSpaceQualType(VD->getType(), DXIL::kTGSMAddrSpace)); } break; + case AttributeList::AT_HLSLGroupSharedLimit: + declAttr = ::new (S.Context) HLSLGroupSharedLimitAttr( + A.getRange(), S.Context, ValidateAttributeIntArg(S, A), + A.getAttributeSpellingListIndex()); + break; case AttributeList::AT_HLSLUniform: declAttr = ::new (S.Context) HLSLUniformAttr( A.getRange(), S.Context, A.getAttributeSpellingListIndex()); diff --git a/tools/clang/test/HLSLFileCheck/hlsl/entry/attributes/GroupSharedLimit.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/entry/attributes/GroupSharedLimit.hlsl new file mode 100644 index 0000000000..cc91801263 --- /dev/null +++ b/tools/clang/test/HLSLFileCheck/hlsl/entry/attributes/GroupSharedLimit.hlsl @@ -0,0 +1,85 @@ +// RUN: %dxc -E MainPass -T cs_6_10 %s | FileCheck %s + +#define NUM_BYTES_OF_SHARED_MEM (32*1024) +#define NUM_DWORDS_SHARED_MEM (NUM_BYTES_OF_SHARED_MEM / 4) +#define THREAD_GROUP_SIZE_X 1024 + +groupshared uint g_testBufferPASS[NUM_DWORDS_SHARED_MEM]; + +RWStructuredBuffer g_output : register(u0); + +// CHECK: @MainPass + +[GroupSharedLimit(NUM_BYTES_OF_SHARED_MEM)] +[numthreads(THREAD_GROUP_SIZE_X, 1, 1)] +void MainPass( uint3 DTid : SV_DispatchThreadID ) +{ + uint iterations = NUM_DWORDS_SHARED_MEM / THREAD_GROUP_SIZE_X; + + for (uint i = 0; i < iterations; i++) + { + uint index = DTid.x + i * THREAD_GROUP_SIZE_X; + g_testBufferPASS[index] = index; + } + + // synchronize all threads in the group + GroupMemoryBarrierWithGroupSync(); + + // write the shared data to the output buffer + for (uint i = 0; i < iterations; i++) + { + uint index = DTid.x + i * THREAD_GROUP_SIZE_X; + g_output[index] = g_testBufferPASS[index]; + } +} + +// RUN: not %dxc -E MainFail -T cs_6_10 %s 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR +// CHECK-ERROR: Total Thread Group Shared Memory storage is 32772, exceeded 32768. + +groupshared uint g_testBufferFAIL[NUM_DWORDS_SHARED_MEM + 1]; + +[GroupSharedLimit(NUM_BYTES_OF_SHARED_MEM)] +[numthreads(THREAD_GROUP_SIZE_X, 1, 1)] +void MainFail( uint3 DTid : SV_DispatchThreadID ) +{ + uint iterations = NUM_DWORDS_SHARED_MEM / THREAD_GROUP_SIZE_X; + + for (uint i = 0; i < iterations; i++) + { + uint index = DTid.x + i * THREAD_GROUP_SIZE_X; + g_testBufferFAIL[index] = index; + } + + // synchronize all threads in the group + GroupMemoryBarrierWithGroupSync(); + + // write the shared data to the output buffer + for (uint i = 0; i < iterations; i++) + { + uint index = DTid.x + i * THREAD_GROUP_SIZE_X; + g_output[index] = g_testBufferFAIL[index]; + } +} + +// RUN: not %dxc -E MainFail2 -T cs_6_10 %s 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR +[numthreads(THREAD_GROUP_SIZE_X, 1, 1)] +void MainFail2( uint3 DTid : SV_DispatchThreadID ) +{ + uint iterations = NUM_DWORDS_SHARED_MEM / THREAD_GROUP_SIZE_X; + + for (uint i = 0; i < iterations; i++) + { + uint index = DTid.x + i * THREAD_GROUP_SIZE_X; + g_testBufferFAIL[index] = index; + } + + // synchronize all threads in the group + GroupMemoryBarrierWithGroupSync(); + + // write the shared data to the output buffer + for (uint i = 0; i < iterations; i++) + { + uint index = DTid.x + i * THREAD_GROUP_SIZE_X; + g_output[index] = g_testBufferFAIL[index]; + } +} \ No newline at end of file