@@ -327,81 +327,90 @@ static void translateBranchMetadata(Module &M, Instruction *BBTerminatorInst) {
327327 BBTerminatorInst->setMetadata (" hlsl.controlflow.hint" , nullptr );
328328}
329329
330- static void translateLoopMetadata (Module &M, MDNode *LoopMD) {
330+ // Determines if the metadata node will be compatible with DXIL's loop metadata
331+ // representation.
332+ //
333+ // Reports an error for compatible metadata that is ill-formed.
334+ static bool isLoopMDCompatible (Module &M, Metadata *MD) {
331335 // DXIL only accepts the following loop hints:
332- // llvm.loop.unroll.disable, llvm.loop.unroll.full, llvm.loop.unroll.count
333336 std::array<StringLiteral, 3 > ValidHintNames = {" llvm.loop.unroll.count" ,
334337 " llvm.loop.unroll.disable" ,
335338 " llvm.loop.unroll.full" };
336339
337- // llvm.loop metadata must have its first operand be a self-reference, so we
338- // require at least 1 operand.
339- //
340- // It only makes sense to specify up to 1 of the hints on a branch, so we can
341- // have at most 2 operands.
340+ MDNode *HintMD = dyn_cast<MDNode>(MD);
341+ if (!HintMD || HintMD->getNumOperands () == 0 )
342+ return false ;
342343
343- if (LoopMD->getNumOperands () != 1 && LoopMD->getNumOperands () != 2 ) {
344- reportLoopError (M, " Requires exactly 1 or 2 operands" );
345- return ;
346- }
344+ auto *HintStr = dyn_cast<MDString>(HintMD->getOperand (0 ));
345+ if (!HintStr)
346+ return false ;
347347
348- if (LoopMD != LoopMD->getOperand (0 )) {
349- reportLoopError (M, " First operand must be a self-reference" );
350- return ;
351- }
348+ if (!llvm::is_contained (ValidHintNames, HintStr->getString ()))
349+ return false ;
352350
353- // A node only containing a self-reference is a valid use to denote a loop
354- if (LoopMD->getNumOperands () == 1 )
355- return ;
351+ auto ValidCountNode = [](MDNode *CountMD) -> bool {
352+ if (CountMD->getNumOperands () == 2 )
353+ if (auto *Count = dyn_cast<ConstantAsMetadata>(CountMD->getOperand (1 )))
354+ if (isa<ConstantInt>(Count->getValue ()))
355+ return true ;
356+ return false ;
357+ };
356358
357- LoopMD = dyn_cast<MDNode>(LoopMD->getOperand (1 ));
358- if (!LoopMD) {
359- reportLoopError (M, " Second operand must be a metadata node" );
360- return ;
361- }
359+ if (HintStr->getString () == " llvm.loop.unroll.count" ) {
360+ if (!ValidCountNode (HintMD))
361+ return reportLoopError (M, " Second operand of \" llvm.loop.unroll.count\" "
362+ " must be a constant integer" );
363+ } else if (HintMD->getNumOperands () != 1 )
364+ return reportLoopError (
365+ M, " \" llvm.loop.unroll.disable\" and \" llvm.loop.unroll.disable\" "
366+ " must be provided as a single operand" );
362367
363- if (LoopMD->getNumOperands () != 1 && LoopMD->getNumOperands () != 2 ) {
364- reportLoopError (M, " Requires exactly 1 or 2 operands" );
365- return ;
366- }
368+ return true ;
369+ }
370+
371+ static void translateLoopMetadata (Module &M, Instruction *I, MDNode *BaseMD) {
372+ // A distinct node has the self-referential form: !0 = !{ !0, ... }
373+ auto IsDistinctNode = [](MDNode *Node) -> bool {
374+ return Node && Node->getNumOperands () != 0 && Node == Node->getOperand (0 );
375+ };
376+
377+ // Strip empty metadata or a non-distinct node
378+ if (BaseMD->getNumOperands () == 0 || !IsDistinctNode (BaseMD))
379+ return I->setMetadata (" llvm.loop" , nullptr );
367380
368- // It is valid to have a chain of self-referential loop metadata nodes so if
369- // we have another self-reference, recurse.
381+ // It is valid to have a chain of self-refential loop metadata nodes, as
382+ // below. We will collapse these into just one when we reconstruct the
383+ // metadata.
370384 //
371385 // Eg:
372386 // !0 = !{!0, !1}
373387 // !1 = !{!1, !2}
374- // !2 = !{"llvm.loop.unroll.disable"}
375- if (LoopMD == LoopMD->getOperand (0 ))
376- return translateLoopMetadata (M, LoopMD);
377-
378- // Otherwise, we are at our base hint metadata node
379- auto *HintStr = dyn_cast<MDString>(LoopMD->getOperand (0 ));
380- if (!HintStr || !llvm::is_contained (ValidHintNames, HintStr->getString ())) {
381- reportLoopError (M,
382- " First operand must be a valid \" llvm.loop.unroll\" hint" );
383- return ;
384- }
385-
386- // Ensure count node is a constant integer value
387- auto ValidCountNode = [](MDNode *HintMD) -> bool {
388- if (HintMD->getNumOperands () == 2 )
389- if (auto *CountMD = dyn_cast<ConstantAsMetadata>(HintMD->getOperand (1 )))
390- if (isa<ConstantInt>(CountMD->getValue ()))
391- return true ;
392- return false ;
393- };
394-
395- if (HintStr->getString () == " llvm.loop.unroll.count" ) {
396- if (!ValidCountNode (LoopMD)) {
397- reportLoopError (M, " Second operand of \" llvm.loop.unroll.count\" "
398- " must be a constant integer" );
399- return ;
400- }
401- } else if (LoopMD->getNumOperands () != 1 ) {
402- reportLoopError (M, " Can't have a second operand" );
403- return ;
404- }
388+ // !2 = !{!"llvm.loop.unroll.disable"}
389+ //
390+ // So, traverse down a potential self-referential chain
391+ while (1 < BaseMD->getNumOperands () &&
392+ IsDistinctNode (dyn_cast<MDNode>(BaseMD->getOperand (1 ))))
393+ BaseMD = dyn_cast<MDNode>(BaseMD->getOperand (1 ));
394+
395+ // To reconstruct a distinct node we create a temporary node that we will
396+ // then update to create a self-reference.
397+ llvm::TempMDTuple TempNode = llvm::MDNode::getTemporary (M.getContext (), {});
398+ SmallVector<Metadata *> CompatibleOperands = {TempNode.get ()};
399+
400+ // Iterate and reconstruct the metadata nodes that contains any hints,
401+ // stripping any unrecognized metadata.
402+ ArrayRef<MDOperand> Operands = BaseMD->operands ();
403+ for (auto &Op : Operands.drop_front ())
404+ if (isLoopMDCompatible (M, Op.get ()))
405+ CompatibleOperands.push_back (Op.get ());
406+
407+ if (2 < CompatibleOperands.size ())
408+ reportLoopError (M, " Provided conflicting hints" );
409+
410+ MDNode *CompatibleLoopMD = MDNode::get (M.getContext (), CompatibleOperands);
411+ TempNode->replaceAllUsesWith (CompatibleLoopMD);
412+
413+ I->setMetadata (" llvm.loop" , CompatibleLoopMD);
405414}
406415
407416using InstructionMDList = std::array<unsigned , 7 >;
@@ -427,10 +436,9 @@ static void translateInstructionMetadata(Module &M) {
427436 translateBranchMetadata (M, I);
428437
429438 for (auto &I : make_early_inc_range (BB)) {
430- if (isa<BranchInst>(I)) {
439+ if (isa<BranchInst>(I))
431440 if (MDNode *LoopMD = I.getMetadata (MDLoopKind))
432- translateLoopMetadata (M, LoopMD);
433- }
441+ translateLoopMetadata (M, &I, LoopMD);
434442 I.dropUnknownNonDebugMetadata (DXILCompatibleMDs);
435443 }
436444 }
0 commit comments