2121#include " mlir/IR/PatternMatch.h"
2222#include " mlir/Transforms/DialectConversion.h"
2323#include " mlir/Transforms/Passes.h"
24+ #include " llvm/Support/LogicalResult.h"
2425
2526namespace mlir {
2627#define GEN_PASS_DEF_SCFTOEMITC
@@ -106,7 +107,7 @@ static void assignValues(ValueRange values, ValueRange variables,
106107 emitc::AssignOp::create (rewriter, loc, var, value);
107108}
108109
109- SmallVector<Value> loadValues (const SmallVector <Value> & variables,
110+ SmallVector<Value> loadValues (ArrayRef <Value> variables,
110111 PatternRewriter &rewriter, Location loc) {
111112 return llvm::map_to_vector<>(variables, [&](Value var) {
112113 Type type = cast<emitc::LValueType>(var.getType ()).getValueType ();
@@ -116,16 +117,15 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables,
116117
117118static LogicalResult lowerYield (Operation *op, ValueRange resultVariables,
118119 ConversionPatternRewriter &rewriter,
119- scf::YieldOp yield) {
120+ scf::YieldOp yield, bool createYield = true ) {
120121 Location loc = yield.getLoc ();
121122
122123 OpBuilder::InsertionGuard guard (rewriter);
123124 rewriter.setInsertionPoint (yield);
124125
125126 SmallVector<Value> yieldOperands;
126- if (failed (rewriter.getRemappedValues (yield.getOperands (), yieldOperands))) {
127+ if (failed (rewriter.getRemappedValues (yield.getOperands (), yieldOperands)))
127128 return rewriter.notifyMatchFailure (op, " failed to lower yield operands" );
128- }
129129
130130 assignValues (yieldOperands, resultVariables, rewriter, loc);
131131
@@ -336,11 +336,177 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite(
336336 return success ();
337337}
338338
339+ // Lower scf::while to emitc::do using mutable variables to maintain loop state
340+ // across iterations. The do-while structure ensures the condition is evaluated
341+ // after each iteration, matching SCF while semantics.
342+ struct WhileLowering : public OpConversionPattern <WhileOp> {
343+ using OpConversionPattern::OpConversionPattern;
344+
345+ LogicalResult
346+ matchAndRewrite (WhileOp whileOp, OpAdaptor adaptor,
347+ ConversionPatternRewriter &rewriter) const override {
348+ Location loc = whileOp.getLoc ();
349+ MLIRContext *context = loc.getContext ();
350+
351+ // Create an emitc::variable op for each result. These variables will be
352+ // assigned to by emitc::assign ops within the loop body.
353+ SmallVector<Value> resultVariables;
354+ if (failed (createVariablesForResults (whileOp, getTypeConverter (), rewriter,
355+ resultVariables)))
356+ return rewriter.notifyMatchFailure (whileOp,
357+ " Failed to create result variables" );
358+
359+ // Create variable storage for loop-carried values to enable imperative
360+ // updates while maintaining SSA semantics at conversion boundaries.
361+ SmallVector<Value> loopVariables;
362+ if (failed (createVariablesForLoopCarriedValues (
363+ whileOp, rewriter, loopVariables, loc, context)))
364+ return failure ();
365+
366+ if (failed (lowerDoWhile (whileOp, loopVariables, resultVariables, context,
367+ rewriter, loc)))
368+ return failure ();
369+
370+ rewriter.setInsertionPointAfter (whileOp);
371+
372+ // Load the final result values from result variables.
373+ SmallVector<Value> finalResults =
374+ loadValues (resultVariables, rewriter, loc);
375+ rewriter.replaceOp (whileOp, finalResults);
376+
377+ return success ();
378+ }
379+
380+ private:
381+ // Initialize variables for loop-carried values to enable state updates
382+ // across iterations without SSA argument passing.
383+ LogicalResult createVariablesForLoopCarriedValues (
384+ WhileOp whileOp, ConversionPatternRewriter &rewriter,
385+ SmallVectorImpl<Value> &loopVars, Location loc,
386+ MLIRContext *context) const {
387+ OpBuilder::InsertionGuard guard (rewriter);
388+ rewriter.setInsertionPoint (whileOp);
389+
390+ emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get (context, " " );
391+
392+ for (Value init : whileOp.getInits ()) {
393+ Type convertedType = getTypeConverter ()->convertType (init.getType ());
394+ if (!convertedType)
395+ return rewriter.notifyMatchFailure (whileOp, " type conversion failed" );
396+
397+ emitc::VariableOp var = rewriter.create <emitc::VariableOp>(
398+ loc, emitc::LValueType::get (convertedType), noInit);
399+ rewriter.create <emitc::AssignOp>(loc, var.getResult (), init);
400+ loopVars.push_back (var);
401+ }
402+
403+ return success ();
404+ }
405+
406+ // Lower scf.while to emitc.do.
407+ LogicalResult lowerDoWhile (WhileOp whileOp, ArrayRef<Value> loopVars,
408+ ArrayRef<Value> resultVars, MLIRContext *context,
409+ ConversionPatternRewriter &rewriter,
410+ Location loc) const {
411+ // Create a global boolean variable to store the loop condition state.
412+ Type i1Type = IntegerType::get (context, 1 );
413+ auto globalCondition =
414+ rewriter.create <emitc::VariableOp>(loc, emitc::LValueType::get (i1Type),
415+ emitc::OpaqueAttr::get (context, " " ));
416+ Value conditionVal = globalCondition.getResult ();
417+
418+ auto loweredDo = rewriter.create <emitc::DoOp>(loc);
419+
420+ // Convert region types to match the target dialect type system.
421+ if (failed (rewriter.convertRegionTypes (&whileOp.getBefore (),
422+ *getTypeConverter (), nullptr )) ||
423+ failed (rewriter.convertRegionTypes (&whileOp.getAfter (),
424+ *getTypeConverter (), nullptr ))) {
425+ return rewriter.notifyMatchFailure (whileOp,
426+ " region types conversion failed" );
427+ }
428+
429+ // Prepare the before region (condition evaluation) for merging.
430+ Block *beforeBlock = &whileOp.getBefore ().front ();
431+ Block *bodyBlock = rewriter.createBlock (&loweredDo.getBodyRegion ());
432+ rewriter.setInsertionPointToStart (bodyBlock);
433+
434+ // Load current variable values to use as initial arguments for the
435+ // condition block.
436+ SmallVector<Value> replacingValues = loadValues (loopVars, rewriter, loc);
437+ rewriter.mergeBlocks (beforeBlock, bodyBlock, replacingValues);
438+
439+ Operation *condTerminator =
440+ loweredDo.getBodyRegion ().back ().getTerminator ();
441+ scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator);
442+ rewriter.setInsertionPoint (condOp);
443+
444+ // Update result variables with values from scf::condition.
445+ SmallVector<Value> conditionArgs;
446+ for (Value arg : condOp.getArgs ()) {
447+ conditionArgs.push_back (rewriter.getRemappedValue (arg));
448+ }
449+ assignValues (conditionArgs, resultVars, rewriter, loc);
450+
451+ // Convert scf.condition to condition variable assignment.
452+ Value condition = rewriter.getRemappedValue (condOp.getCondition ());
453+ rewriter.create <emitc::AssignOp>(loc, conditionVal, condition);
454+
455+ // Wrap body region in conditional to preserve scf semantics. Only create
456+ // ifOp if after-region is non-empty.
457+ if (whileOp.getAfterBody ()->getOperations ().size () > 1 ) {
458+ auto ifOp = rewriter.create <emitc::IfOp>(loc, condition, false , false );
459+
460+ // Prepare the after region (loop body) for merging.
461+ Block *afterBlock = &whileOp.getAfter ().front ();
462+ Block *ifBodyBlock = rewriter.createBlock (&ifOp.getBodyRegion ());
463+
464+ // Replacement values for after block using condition op arguments.
465+ SmallVector<Value> afterReplacingValues;
466+ for (Value arg : condOp.getArgs ())
467+ afterReplacingValues.push_back (rewriter.getRemappedValue (arg));
468+
469+ rewriter.mergeBlocks (afterBlock, ifBodyBlock, afterReplacingValues);
470+
471+ if (failed (lowerYield (whileOp, loopVars, rewriter,
472+ cast<scf::YieldOp>(ifBodyBlock->getTerminator ()))))
473+ return failure ();
474+ }
475+
476+ rewriter.eraseOp (condOp);
477+
478+ // Create condition region that loads from the flag variable.
479+ Region &condRegion = loweredDo.getConditionRegion ();
480+ Block *condBlock = rewriter.createBlock (&condRegion);
481+ rewriter.setInsertionPointToStart (condBlock);
482+
483+ auto exprOp = rewriter.create <emitc::ExpressionOp>(
484+ loc, i1Type, conditionVal, /* do_not_inline=*/ false );
485+ Block *exprBlock = rewriter.createBlock (&exprOp.getBodyRegion ());
486+
487+ // Set up the expression block to load the condition variable.
488+ exprBlock->addArgument (conditionVal.getType (), loc);
489+ rewriter.setInsertionPointToStart (exprBlock);
490+
491+ // Load the condition value and yield it as the expression result.
492+ Value cond =
493+ rewriter.create <emitc::LoadOp>(loc, i1Type, exprBlock->getArgument (0 ));
494+ rewriter.create <emitc::YieldOp>(loc, cond);
495+
496+ // Yield the expression as the condition region result.
497+ rewriter.setInsertionPointToEnd (condBlock);
498+ rewriter.create <emitc::YieldOp>(loc, exprOp);
499+
500+ return success ();
501+ }
502+ };
503+
339504void mlir::populateSCFToEmitCConversionPatterns (RewritePatternSet &patterns,
340505 TypeConverter &typeConverter) {
341506 patterns.add <ForLowering>(typeConverter, patterns.getContext ());
342507 patterns.add <IfLowering>(typeConverter, patterns.getContext ());
343508 patterns.add <IndexSwitchOpLowering>(typeConverter, patterns.getContext ());
509+ patterns.add <WhileLowering>(typeConverter, patterns.getContext ());
344510}
345511
346512void SCFToEmitCPass::runOnOperation () {
@@ -357,7 +523,8 @@ void SCFToEmitCPass::runOnOperation() {
357523
358524 // Configure conversion to lower out SCF operations.
359525 ConversionTarget target (getContext ());
360- target.addIllegalOp <scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
526+ target
527+ .addIllegalOp <scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::WhileOp>();
361528 target.markUnknownOpDynamicallyLegal ([](Operation *) { return true ; });
362529 if (failed (
363530 applyPartialConversion (getOperation (), target, std::move (patterns))))
0 commit comments