Skip to content

Commit 24909ab

Browse files
committed
self-review: merge DXILValidate into DXILTranslate
it seemed like we could create this nice level of abstraction, however, this will just cause us to write duplicate logic for iterating through the metadata to first transform and then validate. So we may as well transform in place This problem arises as we want to strip certain llvm.loop metadata and validate on the other types
1 parent f027a93 commit 24909ab

File tree

12 files changed

+145
-276
lines changed

12 files changed

+145
-276
lines changed

llvm/lib/Target/DirectX/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ add_llvm_target(DirectXCodeGen
3535
DXILResourceImplicitBinding.cpp
3636
DXILShaderFlags.cpp
3737
DXILTranslateMetadata.cpp
38-
DXILValidateMetadata.cpp
3938
DXILRootSignature.cpp
4039
DXILLegalizePass.cpp
4140

llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp

Lines changed: 131 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,39 @@ using namespace llvm::dxil;
3737

3838
namespace {
3939

40+
/// A simple wrapper of DiagnosticInfo that generates module-level diagnostic
41+
/// for the DXILValidateMetadata pass
42+
class DiagnosticInfoValidateMD : public DiagnosticInfo {
43+
private:
44+
const Twine &Msg;
45+
const Module &Mod;
46+
47+
public:
48+
/// \p M is the module for which the diagnostic is being emitted. \p Msg is
49+
/// the message to show. Note that this class does not copy this message, so
50+
/// this reference must be valid for the whole life time of the diagnostic.
51+
DiagnosticInfoValidateMD(const Module &M,
52+
const Twine &Msg LLVM_LIFETIME_BOUND,
53+
DiagnosticSeverity Severity = DS_Error)
54+
: DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}
55+
56+
void print(DiagnosticPrinter &DP) const override {
57+
DP << Mod.getName() << ": " << Msg << '\n';
58+
}
59+
};
60+
61+
static bool reportError(Module &M, Twine Message,
62+
DiagnosticSeverity Severity = DS_Error) {
63+
M.getContext().diagnose(DiagnosticInfoValidateMD(M, Message, Severity));
64+
return true;
65+
}
66+
67+
static bool reportLoopError(Module &M, Twine Message,
68+
DiagnosticSeverity Severity = DS_Error) {
69+
return reportError(M, Twine("Invalid \"llvm.loop\" metadata: ") + Message,
70+
Severity);
71+
}
72+
4073
enum class EntryPropsTag {
4174
ShaderFlags = 0,
4275
GSState,
@@ -294,6 +327,83 @@ static void translateBranchMetadata(Module &M, Instruction *BBTerminatorInst) {
294327
BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr);
295328
}
296329

330+
static void translateLoopMetadata(Module &M, MDNode *LoopMD) {
331+
// DXIL only accepts the following loop hints:
332+
// llvm.loop.unroll.disable, llvm.loop.unroll.full, llvm.loop.unroll.count
333+
std::array<StringLiteral, 3> ValidHintNames = {"llvm.loop.unroll.count",
334+
"llvm.loop.unroll.disable",
335+
"llvm.loop.unroll.full"};
336+
337+
// llvm.loop metadata must have its first operand be a self-reference, so we
338+
// require at least 1 operand.
339+
//
340+
// It only makes sense to specify up to 1 of the hints on a branch, so we can
341+
// have at most 2 operands.
342+
343+
if (LoopMD->getNumOperands() != 1 && LoopMD->getNumOperands() != 2) {
344+
reportLoopError(M, "Requires exactly 1 or 2 operands");
345+
return;
346+
}
347+
348+
if (LoopMD != LoopMD->getOperand(0)) {
349+
reportLoopError(M, "First operand must be a self-reference");
350+
return;
351+
}
352+
353+
// A node only containing a self-reference is a valid use to denote a loop
354+
if (LoopMD->getNumOperands() == 1)
355+
return;
356+
357+
LoopMD = dyn_cast<MDNode>(LoopMD->getOperand(1));
358+
if (!LoopMD) {
359+
reportLoopError(M, "Second operand must be a metadata node");
360+
return;
361+
}
362+
363+
if (LoopMD->getNumOperands() != 1 && LoopMD->getNumOperands() != 2) {
364+
reportLoopError(M, "Requires exactly 1 or 2 operands");
365+
return;
366+
}
367+
368+
// It is valid to have a chain of self-referential loop metadata nodes so if
369+
// we have another self-reference, recurse.
370+
//
371+
// Eg:
372+
// !0 = !{!0, !1}
373+
// !1 = !{!1, !2}
374+
// !2 = !{"llvm.loop.unroll.disable"}
375+
if (LoopMD == LoopMD->getOperand(0))
376+
return translateLoopMetadata(M, LoopMD);
377+
378+
// Otherwise, we are at our base hint metadata node
379+
auto *HintStr = dyn_cast<MDString>(LoopMD->getOperand(0));
380+
if (!HintStr || !llvm::is_contained(ValidHintNames, HintStr->getString())) {
381+
reportLoopError(M,
382+
"First operand must be a valid \"llvm.loop.unroll\" hint");
383+
return;
384+
}
385+
386+
// Ensure count node is a constant integer value
387+
auto ValidCountNode = [](MDNode *HintMD) -> bool {
388+
if (HintMD->getNumOperands() == 2)
389+
if (auto *CountMD = dyn_cast<ConstantAsMetadata>(HintMD->getOperand(1)))
390+
if (isa<ConstantInt>(CountMD->getValue()))
391+
return true;
392+
return false;
393+
};
394+
395+
if (HintStr->getString() == "llvm.loop.unroll.count") {
396+
if (!ValidCountNode(LoopMD)) {
397+
reportLoopError(M, "Second operand of \"llvm.loop.unroll.count\" "
398+
"must be a constant integer");
399+
return;
400+
}
401+
} else if (LoopMD->getNumOperands() != 1) {
402+
reportLoopError(M, "Can't have a second operand");
403+
return;
404+
}
405+
}
406+
297407
using InstructionMDList = std::array<unsigned, 7>;
298408

299409
static InstructionMDList getCompatibleInstructionMDs(llvm::Module &M) {
@@ -307,6 +417,7 @@ static InstructionMDList getCompatibleInstructionMDs(llvm::Module &M) {
307417
static void translateInstructionMetadata(Module &M) {
308418
// construct allowlist of valid metadata node kinds
309419
InstructionMDList DXILCompatibleMDs = getCompatibleInstructionMDs(M);
420+
unsigned char MDLoopKind = M.getContext().getMDKindID("llvm.loop");
310421

311422
for (Function &F : M) {
312423
for (BasicBlock &BB : F) {
@@ -316,6 +427,10 @@ static void translateInstructionMetadata(Module &M) {
316427
translateBranchMetadata(M, I);
317428

318429
for (auto &I : make_early_inc_range(BB)) {
430+
if (isa<BranchInst>(I)) {
431+
if (MDNode *LoopMD = I.getMetadata(MDLoopKind))
432+
translateLoopMetadata(M, LoopMD);
433+
}
319434
I.dropUnknownNonDebugMetadata(DXILCompatibleMDs);
320435
}
321436
}
@@ -372,17 +487,24 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,
372487
uint64_t CombinedMask = ShaderFlags.getCombinedFlags();
373488
EntryFnMDNodes.emplace_back(
374489
emitTopLevelLibraryNode(M, ResourceMD, CombinedMask));
375-
}
490+
} else if (1 < MMDI.EntryPropertyVec.size())
491+
reportError(M, "Non-library shader: One and only one entry expected");
376492

377493
for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
378-
const ComputedShaderFlags &EntrySFMask =
379-
ShaderFlags.getFunctionFlags(EntryProp.Entry);
380-
381-
// If ShaderProfile is Library, mask is already consolidated in the
382-
// top-level library node. Hence it is not emitted.
383-
uint64_t EntryShaderFlags =
384-
MMDI.ShaderProfile == Triple::EnvironmentType::Library ? 0
385-
: EntrySFMask;
494+
uint64_t EntryShaderFlags = 0;
495+
if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
496+
EntryShaderFlags = ShaderFlags.getFunctionFlags(EntryProp.Entry);
497+
if (EntryProp.ShaderStage != MMDI.ShaderProfile)
498+
reportError(
499+
M,
500+
"Shader stage '" +
501+
Twine(Twine(getShortShaderStage(EntryProp.ShaderStage)) +
502+
"' for entry '" + Twine(EntryProp.Entry->getName()) +
503+
"' different from specified target profile '" +
504+
Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
505+
"'")));
506+
}
507+
386508
EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
387509
EntryShaderFlags,
388510
MMDI.ShaderProfile));

0 commit comments

Comments
 (0)