@@ -37,6 +37,39 @@ using namespace llvm::dxil;
3737
3838namespace {
3939
40+ // / A simple wrapper of DiagnosticInfo that generates module-level diagnostic
41+ // / for the DXILValidateMetadata pass
42+ class DiagnosticInfoValidateMD : public DiagnosticInfo {
43+ private:
44+ const Twine &Msg;
45+ const Module &Mod;
46+
47+ public:
48+ // / \p M is the module for which the diagnostic is being emitted. \p Msg is
49+ // / the message to show. Note that this class does not copy this message, so
50+ // / this reference must be valid for the whole life time of the diagnostic.
51+ DiagnosticInfoValidateMD (const Module &M,
52+ const Twine &Msg LLVM_LIFETIME_BOUND,
53+ DiagnosticSeverity Severity = DS_Error)
54+ : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}
55+
56+ void print (DiagnosticPrinter &DP) const override {
57+ DP << Mod.getName () << " : " << Msg << ' \n ' ;
58+ }
59+ };
60+
61+ static bool reportError (Module &M, Twine Message,
62+ DiagnosticSeverity Severity = DS_Error) {
63+ M.getContext ().diagnose (DiagnosticInfoValidateMD (M, Message, Severity));
64+ return true ;
65+ }
66+
67+ static bool reportLoopError (Module &M, Twine Message,
68+ DiagnosticSeverity Severity = DS_Error) {
69+ return reportError (M, Twine (" Invalid \" llvm.loop\" metadata: " ) + Message,
70+ Severity);
71+ }
72+
4073enum class EntryPropsTag {
4174 ShaderFlags = 0 ,
4275 GSState,
@@ -294,6 +327,83 @@ static void translateBranchMetadata(Module &M, Instruction *BBTerminatorInst) {
294327 BBTerminatorInst->setMetadata (" hlsl.controlflow.hint" , nullptr );
295328}
296329
330+ static void translateLoopMetadata (Module &M, MDNode *LoopMD) {
331+ // DXIL only accepts the following loop hints:
332+ // llvm.loop.unroll.disable, llvm.loop.unroll.full, llvm.loop.unroll.count
333+ std::array<StringLiteral, 3 > ValidHintNames = {" llvm.loop.unroll.count" ,
334+ " llvm.loop.unroll.disable" ,
335+ " llvm.loop.unroll.full" };
336+
337+ // llvm.loop metadata must have its first operand be a self-reference, so we
338+ // require at least 1 operand.
339+ //
340+ // It only makes sense to specify up to 1 of the hints on a branch, so we can
341+ // have at most 2 operands.
342+
343+ if (LoopMD->getNumOperands () != 1 && LoopMD->getNumOperands () != 2 ) {
344+ reportLoopError (M, " Requires exactly 1 or 2 operands" );
345+ return ;
346+ }
347+
348+ if (LoopMD != LoopMD->getOperand (0 )) {
349+ reportLoopError (M, " First operand must be a self-reference" );
350+ return ;
351+ }
352+
353+ // A node only containing a self-reference is a valid use to denote a loop
354+ if (LoopMD->getNumOperands () == 1 )
355+ return ;
356+
357+ LoopMD = dyn_cast<MDNode>(LoopMD->getOperand (1 ));
358+ if (!LoopMD) {
359+ reportLoopError (M, " Second operand must be a metadata node" );
360+ return ;
361+ }
362+
363+ if (LoopMD->getNumOperands () != 1 && LoopMD->getNumOperands () != 2 ) {
364+ reportLoopError (M, " Requires exactly 1 or 2 operands" );
365+ return ;
366+ }
367+
368+ // It is valid to have a chain of self-referential loop metadata nodes so if
369+ // we have another self-reference, recurse.
370+ //
371+ // Eg:
372+ // !0 = !{!0, !1}
373+ // !1 = !{!1, !2}
374+ // !2 = !{"llvm.loop.unroll.disable"}
375+ if (LoopMD == LoopMD->getOperand (0 ))
376+ return translateLoopMetadata (M, LoopMD);
377+
378+ // Otherwise, we are at our base hint metadata node
379+ auto *HintStr = dyn_cast<MDString>(LoopMD->getOperand (0 ));
380+ if (!HintStr || !llvm::is_contained (ValidHintNames, HintStr->getString ())) {
381+ reportLoopError (M,
382+ " First operand must be a valid \" llvm.loop.unroll\" hint" );
383+ return ;
384+ }
385+
386+ // Ensure count node is a constant integer value
387+ auto ValidCountNode = [](MDNode *HintMD) -> bool {
388+ if (HintMD->getNumOperands () == 2 )
389+ if (auto *CountMD = dyn_cast<ConstantAsMetadata>(HintMD->getOperand (1 )))
390+ if (isa<ConstantInt>(CountMD->getValue ()))
391+ return true ;
392+ return false ;
393+ };
394+
395+ if (HintStr->getString () == " llvm.loop.unroll.count" ) {
396+ if (!ValidCountNode (LoopMD)) {
397+ reportLoopError (M, " Second operand of \" llvm.loop.unroll.count\" "
398+ " must be a constant integer" );
399+ return ;
400+ }
401+ } else if (LoopMD->getNumOperands () != 1 ) {
402+ reportLoopError (M, " Can't have a second operand" );
403+ return ;
404+ }
405+ }
406+
297407using InstructionMDList = std::array<unsigned , 7 >;
298408
299409static InstructionMDList getCompatibleInstructionMDs (llvm::Module &M) {
@@ -307,6 +417,7 @@ static InstructionMDList getCompatibleInstructionMDs(llvm::Module &M) {
307417static void translateInstructionMetadata (Module &M) {
308418 // construct allowlist of valid metadata node kinds
309419 InstructionMDList DXILCompatibleMDs = getCompatibleInstructionMDs (M);
420+ unsigned char MDLoopKind = M.getContext ().getMDKindID (" llvm.loop" );
310421
311422 for (Function &F : M) {
312423 for (BasicBlock &BB : F) {
@@ -316,6 +427,10 @@ static void translateInstructionMetadata(Module &M) {
316427 translateBranchMetadata (M, I);
317428
318429 for (auto &I : make_early_inc_range (BB)) {
430+ if (isa<BranchInst>(I)) {
431+ if (MDNode *LoopMD = I.getMetadata (MDLoopKind))
432+ translateLoopMetadata (M, LoopMD);
433+ }
319434 I.dropUnknownNonDebugMetadata (DXILCompatibleMDs);
320435 }
321436 }
@@ -372,17 +487,24 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,
372487 uint64_t CombinedMask = ShaderFlags.getCombinedFlags ();
373488 EntryFnMDNodes.emplace_back (
374489 emitTopLevelLibraryNode (M, ResourceMD, CombinedMask));
375- }
490+ } else if (1 < MMDI.EntryPropertyVec .size ())
491+ reportError (M, " Non-library shader: One and only one entry expected" );
376492
377493 for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec ) {
378- const ComputedShaderFlags &EntrySFMask =
379- ShaderFlags.getFunctionFlags (EntryProp.Entry );
380-
381- // If ShaderProfile is Library, mask is already consolidated in the
382- // top-level library node. Hence it is not emitted.
383- uint64_t EntryShaderFlags =
384- MMDI.ShaderProfile == Triple::EnvironmentType::Library ? 0
385- : EntrySFMask;
494+ uint64_t EntryShaderFlags = 0 ;
495+ if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
496+ EntryShaderFlags = ShaderFlags.getFunctionFlags (EntryProp.Entry );
497+ if (EntryProp.ShaderStage != MMDI.ShaderProfile )
498+ reportError (
499+ M,
500+ " Shader stage '" +
501+ Twine (Twine (getShortShaderStage (EntryProp.ShaderStage )) +
502+ " ' for entry '" + Twine (EntryProp.Entry ->getName ()) +
503+ " ' different from specified target profile '" +
504+ Twine (Triple::getEnvironmentTypeName (MMDI.ShaderProfile ) +
505+ " '" )));
506+ }
507+
386508 EntryFnMDNodes.emplace_back (emitEntryMD (EntryProp, Signatures, ResourceMD,
387509 EntryShaderFlags,
388510 MMDI.ShaderProfile ));
0 commit comments