Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/dxc/DXIL/DxilFunctionProps.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ struct DxilFunctionProps {
memset(&Node, 0, sizeof(Node));
Node.LaunchType = DXIL::NodeLaunchType::Invalid;
Node.LocalRootArgumentsTableIndex = -1;
groupSharedLimitBytes = 0;
}
union {
// Geometry shader.
Expand Down Expand Up @@ -174,6 +175,8 @@ struct DxilFunctionProps {
// numThreads shared between multiple shader types and node shaders.
unsigned numThreads[3];

unsigned groupSharedLimitBytes;

struct NodeProps {
DXIL::NodeLaunchType LaunchType = DXIL::NodeLaunchType::Invalid;
bool IsProgramEntry;
Expand Down
1 change: 1 addition & 0 deletions include/dxc/DXIL/DxilMetadataHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ class DxilMDHelper {
static const unsigned kDxilNodeOutputsTag = 21;
static const unsigned kDxilNodeMaxDispatchGridTag = 22;
static const unsigned kDxilRangedWaveSizeTag = 23;
static const unsigned kDxilMaxGroupSharedMemTag = 24;

// Node Input/Output State.
static const unsigned kDxilNodeOutputIDTag = 0;
Expand Down
2 changes: 2 additions & 0 deletions include/dxc/DXIL/DxilModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ class DxilModule {
void SetNumThreads(unsigned x, unsigned y, unsigned z);
unsigned GetNumThreads(unsigned idx) const;

unsigned GetGroupSharedLimit() const;

// Compute shader
DxilWaveSize &GetWaveSize();
const DxilWaveSize &GetWaveSize() const;
Expand Down
23 changes: 18 additions & 5 deletions include/dxc/DxilContainer/DxilPipelineStateValidation.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ struct PSVRuntimeInfo3 : public PSVRuntimeInfo2 {
uint32_t EntryFunctionName;
};

// Runtime info record for PSV version 4: everything in PSVRuntimeInfo3 plus
// one additional field.
struct PSVRuntimeInfo4 : public PSVRuntimeInfo3 {
// Group shared memory limit for the entry point. Filled in by
// SetShaderProps(PSVRuntimeInfo4 *, ...) from
// DxilModule::GetGroupSharedLimit() for compute, mesh and amplification
// shaders.
uint32_t GroupSharedMemoryLimit;
};

enum class PSVResourceType {
Invalid = 0,

Expand Down Expand Up @@ -474,7 +478,7 @@ class PSVSignatureElement {
const uint32_t *SemanticIndexes) const;
};

#define MAX_PSV_VERSION 3
#define MAX_PSV_VERSION 4

struct PSVInitInfo {
PSVInitInfo(uint32_t psvVersion) : PSVVersion(psvVersion) {}
Expand All @@ -491,7 +495,7 @@ struct PSVInitInfo {
uint8_t SigPatchConstOrPrimVectors = 0;
uint8_t SigOutputVectors[PSV_GS_MAX_STREAMS] = {0, 0, 0, 0};

static_assert(MAX_PSV_VERSION == 3, "otherwise this needs updating.");
static_assert(MAX_PSV_VERSION == 4, "otherwise this needs updating.");
uint32_t RuntimeInfoSize() const {
switch (PSVVersion) {
case 0:
Expand All @@ -500,10 +504,12 @@ struct PSVInitInfo {
return sizeof(PSVRuntimeInfo1);
case 2:
return sizeof(PSVRuntimeInfo2);
case 3:
return sizeof(PSVRuntimeInfo3);
default:
break;
}
return sizeof(PSVRuntimeInfo3);
return sizeof(PSVRuntimeInfo4);
}
uint32_t ResourceBindInfoSize() const {
if (PSVVersion < 2)
Expand All @@ -519,6 +525,7 @@ class DxilPipelineStateValidation {
PSVRuntimeInfo1 *m_pPSVRuntimeInfo1 = nullptr;
PSVRuntimeInfo2 *m_pPSVRuntimeInfo2 = nullptr;
PSVRuntimeInfo3 *m_pPSVRuntimeInfo3 = nullptr;
PSVRuntimeInfo4 *m_pPSVRuntimeInfo4 = nullptr;
uint32_t m_uResourceCount = 0;
uint32_t m_uPSVResourceBindInfoSize = 0;
void *m_pPSVResourceBindInfo = nullptr;
Expand Down Expand Up @@ -634,6 +641,8 @@ class DxilPipelineStateValidation {

PSVRuntimeInfo3 *GetPSVRuntimeInfo3() const { return m_pPSVRuntimeInfo3; }

PSVRuntimeInfo4 *GetPSVRuntimeInfo4() const { return m_pPSVRuntimeInfo4; }

uint32_t GetBindCount() const { return m_uResourceCount; }

template <typename _T>
Expand Down Expand Up @@ -949,6 +958,8 @@ DxilPipelineStateValidation::ReadOrWrite(const void *pBits, uint32_t *pSize,
m_uPSVRuntimeInfoSize); // failure ok
AssignDerived(&m_pPSVRuntimeInfo3, m_pPSVRuntimeInfo0,
m_uPSVRuntimeInfoSize); // failure ok
AssignDerived(&m_pPSVRuntimeInfo4, m_pPSVRuntimeInfo0,
m_uPSVRuntimeInfoSize); // failure ok

// In RWMode::CalcSize, use temp runtime info to hold needed values from
// initInfo
Expand Down Expand Up @@ -1137,11 +1148,13 @@ void SetupPSVInitInfo(PSVInitInfo &InitInfo, const DxilModule &DM);
void SetShaderProps(PSVRuntimeInfo0 *pInfo, const DxilModule &DM);
void SetShaderProps(PSVRuntimeInfo1 *pInfo1, const DxilModule &DM);
void SetShaderProps(PSVRuntimeInfo2 *pInfo2, const DxilModule &DM);
void SetShaderProps(PSVRuntimeInfo4 *pInfo4, const DxilModule &DM);

void PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0,
PSVRuntimeInfo1 *pInfo1, PSVRuntimeInfo2 *pInfo2,
PSVRuntimeInfo3 *pInfo3, uint8_t ShaderKind,
const char *EntryName, const char *Comment);
PSVRuntimeInfo3 *pInfo3, PSVRuntimeInfo4 *pInfo4,
uint8_t ShaderKind, const char *EntryName,
const char *Comment);

} // namespace hlsl

Expand Down
9 changes: 9 additions & 0 deletions lib/DXIL/DxilMetadataHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1624,6 +1624,10 @@ MDTuple *DxilMDHelper::EmitDxilEntryProperties(uint64_t rawShaderFlag,
}
MDVals.emplace_back(MDNode::get(m_Ctx, WaveSizeVal));
}

MDVals.emplace_back(Uint32ToConstMD(DxilMDHelper::kDxilMaxGroupSharedMemTag));
MDVals.emplace_back(
Uint32ToConstMD(props.groupSharedLimitBytes));
} break;
// Geometry shader.
case DXIL::ShaderKind::Geometry: {
Expand Down Expand Up @@ -1773,6 +1777,11 @@ void DxilMDHelper::LoadDxilEntryProperties(const MDOperand &MDO,
props.numThreads[2] = ConstMDToUint32(pNode->getOperand(2));
} break;

case DxilMDHelper::kDxilMaxGroupSharedMemTag: {
DXASSERT(props.IsCS(), "else invalid shader kind");
props.groupSharedLimitBytes = ConstMDToUint32(MDO);
} break;

case DxilMDHelper::kDxilGSStateTag: {
DXASSERT(props.IsGS(), "else invalid shader kind");
auto &GS = props.ShaderProps.GS;
Expand Down
10 changes: 10 additions & 0 deletions lib/DXIL/DxilModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,16 @@ unsigned DxilModule::GetNumThreads(unsigned idx) const {
return props.numThreads[idx];
}

unsigned DxilModule::GetGroupSharedLimit() const {
DXASSERT(m_DxilEntryPropsMap.size() == 1 &&
(m_pSM->IsCS() || m_pSM->IsMS() || m_pSM->IsAS()),
"only works for CS/MS/AS profiles");
const DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
DXASSERT_NOMSG(m_pSM->GetKind() == props.shaderKind);
return props.groupSharedLimitBytes;
}


DxilWaveSize &DxilModule::GetWaveSize() {
return const_cast<DxilWaveSize &>(
static_cast<const DxilModule *>(this)->GetWaveSize());
Expand Down
4 changes: 4 additions & 0 deletions lib/DxilContainer/DxilContainerAssembler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,8 @@ class DxilPSVWriter : public DxilPartWriter {
PSVRuntimeInfo1 *pInfo1 = m_PSV.GetPSVRuntimeInfo1();
PSVRuntimeInfo2 *pInfo2 = m_PSV.GetPSVRuntimeInfo2();
PSVRuntimeInfo3 *pInfo3 = m_PSV.GetPSVRuntimeInfo3();
PSVRuntimeInfo4 *pInfo4 = m_PSV.GetPSVRuntimeInfo4();

if (pInfo)
hlsl::SetShaderProps(pInfo, m_Module);
if (pInfo1)
Expand All @@ -806,6 +808,8 @@ class DxilPSVWriter : public DxilPartWriter {
hlsl::SetShaderProps(pInfo2, m_Module);
if (pInfo3)
pInfo3->EntryFunctionName = EntryFunctionName;
if (pInfo4)
hlsl::SetShaderProps(pInfo4, m_Module);

// Set resource binding information
UINT uResIndex = 0;
Expand Down
35 changes: 32 additions & 3 deletions lib/DxilContainer/DxilPipelineStateValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,21 @@ void hlsl::SetShaderProps(PSVRuntimeInfo2 *pInfo2, const DxilModule &DM) {
}
}

// Fills the PSVRuntimeInfo4-specific data from the module: for compute, mesh
// and amplification shader models GroupSharedMemoryLimit is set to
// DM.GetGroupSharedLimit(); for every other shader kind the record is left
// untouched. pInfo4 must be non-null.
void hlsl::SetShaderProps(PSVRuntimeInfo4 *pInfo4, const DxilModule &DM) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
void hlsl::SetShaderProps(PSVRuntimeInfo4 *pInfo4, const DxilModule &DM) {
void hlsl::SetShaderProps(PSVRuntimeInfo4 *Info4, const DxilModule &DM) {

For new code we generally follow the LLVM Coding Standards. This includes no-p-prefix.

assert(pInfo4);
const ShaderModel* SM = DM.GetShaderModel();
switch (SM->GetKind())
{
case ShaderModel::Kind::Compute:
case ShaderModel::Kind::Mesh:
case ShaderModel::Kind::Amplification:
pInfo4->GroupSharedMemoryLimit = DM.GetGroupSharedLimit();
break;
default:
// Other shader kinds: GroupSharedMemoryLimit is not written here.
break;
}
}

void PSVResourceBindInfo0::Print(raw_ostream &OS) const {
OS << "PSVResourceBindInfo:\n";
OS << " Space: " << Space << "\n";
Expand Down Expand Up @@ -584,8 +599,9 @@ void PSVDependencyTable::Print(raw_ostream &OS, const char *InputSetName,

void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0,
PSVRuntimeInfo1 *pInfo1, PSVRuntimeInfo2 *pInfo2,
PSVRuntimeInfo3 *pInfo3, uint8_t ShaderKind,
const char *EntryName, const char *Comment) {
PSVRuntimeInfo3 *pInfo3, PSVRuntimeInfo4 *pInfo4,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO in this case following the pInfoN pattern is appropriate.

uint8_t ShaderKind, const char *EntryName,
const char *Comment) {
if (pInfo1 && pInfo1->ShaderStage != ShaderKind)
ShaderKind = pInfo1->ShaderStage;
OS << Comment << "PSVRuntimeInfo:\n";
Expand Down Expand Up @@ -808,13 +824,21 @@ void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0,
OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << ","
<< pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n";
}
if (pInfo4) {
OS << Comment << " GroupSharedMemoryLimit="
<< pInfo4->GroupSharedMemoryLimit << "\n";
}
break;
case PSVShaderKind::Amplification:
OS << Comment << " Amplification Shader\n";
if (pInfo2) {
OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << ","
<< pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n";
}
if (pInfo4) {
OS << Comment << " GroupSharedMemoryLimit="
<< pInfo4->GroupSharedMemoryLimit << "\n";
}
break;
case PSVShaderKind::Mesh:
OS << Comment << " Mesh Shader\n";
Expand All @@ -841,6 +865,10 @@ void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0,
OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << ","
<< pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n";
}
if (pInfo4) {
OS << Comment << " GroupSharedMemoryLimit="
<< pInfo4->GroupSharedMemoryLimit << "\n";
}
break;
case PSVShaderKind::Library:
case PSVShaderKind::Invalid:
Expand Down Expand Up @@ -887,9 +915,10 @@ void DxilPipelineStateValidation::PrintPSVRuntimeInfo(
PSVRuntimeInfo1 *pInfo1 = m_pPSVRuntimeInfo1;
PSVRuntimeInfo2 *pInfo2 = m_pPSVRuntimeInfo2;
PSVRuntimeInfo3 *pInfo3 = m_pPSVRuntimeInfo3;
PSVRuntimeInfo4 *pInfo4 = m_pPSVRuntimeInfo4;

hlsl::PrintPSVRuntimeInfo(
OS, pInfo0, pInfo1, pInfo2, pInfo3, ShaderKind,
OS, pInfo0, pInfo1, pInfo2, pInfo3, pInfo4, ShaderKind,
m_pPSVRuntimeInfo3 ? m_StringTable.Get(pInfo3->EntryFunctionName) : "",
Comment);
}
Expand Down
7 changes: 4 additions & 3 deletions lib/DxilValidation/DxilContainerValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,12 +413,13 @@ void PSVContentVerifier::VerifyEntryProperties(const ShaderModel *SM,
PSVRuntimeInfo0 *PSV0,
PSVRuntimeInfo1 *PSV1,
PSVRuntimeInfo2 *PSV2) {
PSVRuntimeInfo3 DMPSV;
memset(&DMPSV, 0, sizeof(PSVRuntimeInfo3));
PSVRuntimeInfo4 DMPSV;
memset(&DMPSV, 0, sizeof(PSVRuntimeInfo4));

hlsl::SetShaderProps((PSVRuntimeInfo0 *)&DMPSV, DM);
hlsl::SetShaderProps((PSVRuntimeInfo1 *)&DMPSV, DM);
hlsl::SetShaderProps((PSVRuntimeInfo2 *)&DMPSV, DM);
hlsl::SetShaderProps((PSVRuntimeInfo4 *)&DMPSV, DM);
if (PSV1) {
// Init things not set in InitPSVRuntimeInfo.
DMPSV.ShaderStage = static_cast<uint8_t>(SM->GetKind());
Expand Down Expand Up @@ -447,7 +448,7 @@ void PSVContentVerifier::VerifyEntryProperties(const ShaderModel *SM,
if (Mismatched) {
std::string Str;
raw_string_ostream OS(Str);
hlsl::PrintPSVRuntimeInfo(OS, &DMPSV, &DMPSV, &DMPSV, &DMPSV,
hlsl::PrintPSVRuntimeInfo(OS, &DMPSV, &DMPSV, &DMPSV, &DMPSV, &DMPSV,
static_cast<uint8_t>(SM->GetKind()),
DM.GetEntryFunctionName().c_str(), "");
OS.flush();
Expand Down
12 changes: 12 additions & 0 deletions lib/DxilValidation/DxilValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3921,6 +3921,18 @@ static void ValidateGlobalVariables(ValidationContext &ValCtx) {
Rule = ValidationRule::SmMaxMSSMSize;
MaxSize = DXIL::kMaxMSSMSize;
}

// Check whether the entry function has an attribute overriding the TGSM size.
if (M.HasDxilEntryProps(M.GetEntryFunction())) {
DxilEntryProps &EntryProps = M.GetDxilEntryProps(M.GetEntryFunction());
if (EntryProps.props.IsCS()) {
unsigned SpecifiedTGSMSize = EntryProps.props.groupSharedLimitBytes;
if (SpecifiedTGSMSize > 0) {
MaxSize = SpecifiedTGSMSize;
}
Comment on lines 3930 to 3932
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (SpecifiedTGSMSize > 0) {
MaxSize = SpecifiedTGSMSize;
}
if (SpecifiedTGSMSize > 0)
MaxSize = SpecifiedTGSMSize;

LLVM coding standards say to omit braces here.

Something's also up with the formatting. Did the format-checker spot it? Anyway, clang-format should fix this for you.

}
}

if (TGSMSize > MaxSize) {
Module::global_iterator GI = M.GetModule()->global_end();
GlobalVariable *GV = &*GI;
Expand Down
5 changes: 5 additions & 0 deletions tools/clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,11 @@ def HLSLNumThreads: InheritableAttr {
let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">];
let Documentation = [Undocumented];
}
// HLSL attribute spelled "GroupSharedLimit". It takes a single integer
// argument named Limit.
def HLSLGroupSharedLimit: InheritableAttr {
let Spellings = [CXX11<"", "GroupSharedLimit", 2017>];
let Args = [IntArgument<"Limit">];
// No documentation entry is attached to this attribute.
let Documentation = [Undocumented];
}
def HLSLRootSignature: InheritableAttr {
let Spellings = [CXX11<"", "RootSignature", 2015>];
let Args = [StringArgument<"SignatureName">];
Expand Down
15 changes: 15 additions & 0 deletions tools/clang/lib/CodeGen/CGHLSLMS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1646,6 +1646,21 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
}
}

if (const HLSLGroupSharedLimitAttr *Attr = FD->getAttr<HLSLGroupSharedLimitAttr>()) {
funcProps->groupSharedLimitBytes = Attr->getLimit();

if (isEntry && !SM->IsCS() && !SM->IsMS() && !SM->IsAS()) {
unsigned DiagID = Diags.getCustomDiagID(
DiagnosticsEngine::Error,
"attribute GroupSharedLimit only valid for CS/MS/AS.");
Diags.Report(Attr->getLocation(), DiagID);
return;
}

} else {
funcProps->groupSharedLimitBytes = DXIL::kMaxTGSMSize;
}

// Hull shader.
if (const HLSLPatchConstantFuncAttr *Attr =
FD->getAttr<HLSLPatchConstantFuncAttr>()) {
Expand Down
1 change: 1 addition & 0 deletions tools/clang/lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,7 @@ void Parser::ParseGNUAttributeArgs(IdentifierInfo *AttrName,
case AttributeList::AT_HLSLMaxVertexCount:
case AttributeList::AT_HLSLUnroll:
case AttributeList::AT_HLSLWaveSize:
case AttributeList::AT_HLSLGroupSharedLimit:
case AttributeList::AT_NoInline:
// The following are not accepted in [attribute(param)] syntax:
// case AttributeList::AT_HLSLCentroid:
Expand Down
5 changes: 5 additions & 0 deletions tools/clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14656,6 +14656,11 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
S.Context.getAddrSpaceQualType(VD->getType(), DXIL::kTGSMAddrSpace));
}
break;
case AttributeList::AT_HLSLGroupSharedLimit:
declAttr = ::new (S.Context) HLSLGroupSharedLimitAttr(
A.getRange(), S.Context, ValidateAttributeIntArg(S, A),
A.getAttributeSpellingListIndex());
break;
case AttributeList::AT_HLSLUniform:
declAttr = ::new (S.Context) HLSLUniformAttr(
A.getRange(), S.Context, A.getAttributeSpellingListIndex());
Expand Down
Loading