From 7e320f7b491895f858c287a6ec82112d834044da Mon Sep 17 00:00:00 2001
From: dibyendumajumdar
Date: Sat, 15 Feb 2025 14:52:05 +0000
Subject: [PATCH] New pass that detects == operations against constant value in
 conditional branches - and uses the knowledge to propagate the constant value
 in the true branch.

---
 .../ConstantComparisonPropagation.java        | 168 ++++++++++++++++++
 .../ezlang/compiler/Optimizer.java            |   8 +-
 .../ezlang/compiler/Options.java              |   4 +-
 .../ezlang/compiler/SSAEdges.java             |  12 ++
 4 files changed, 190 insertions(+), 2 deletions(-)
 create mode 100644 optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ConstantComparisonPropagation.java

diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ConstantComparisonPropagation.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ConstantComparisonPropagation.java
new file mode 100644
index 0000000..f3ef77c
--- /dev/null
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ConstantComparisonPropagation.java
@@ -0,0 +1,168 @@
+package com.compilerprogramming.ezlang.compiler;
+
+import java.util.EnumSet;
+import java.util.Iterator;
+import java.util.Map;
+
+/**
+ * The goal of this pass is to detect conditional branching based
+ * on comparison with a constant. Example
+ *
+ * <pre>
+ *     if (i == 5)
+ *     {
+ *         // some use of i
+ *     }
+ * </pre>
+ *
+ * SCCP generates a lattice value for i, and if i changes outside the if block, then
+ * it is not classed as a constant. This means that inside the if block, we cannot exploit
+ * the knowledge that i is a constant locally.
+ *
+ * To enable this, we need to detect such comparisons and then insert a temp
+ * variable which is set to the constant value. The new temp variable is inserted
+ * inside the if block only and does not affect the meaning of i outside the if block.
+ * After this transformation we can run SCCP again to take advantage of the local
+ * knowledge.
+ *
+ * This technique can be extended to null checks too, but we do not do that yet.
+ */
+public class ConstantComparisonPropagation {
+
+    private final CompiledFunction function;
+    private DominatorTree domTree;
+    private Map<Register, SSAEdges.SSADef> ssaDefUse;
+    /* Set when at least one use was rewritten; reported by apply() */
+    private boolean updated = false;
+
+    public ConstantComparisonPropagation(CompiledFunction function) {
+        this.function = function;
+    }
+
+    /**
+     * Runs the pass when {@link Options#CCP} is enabled.
+     *
+     * @param options the enabled compiler options
+     * @return true if at least one register use was replaced by a temp
+     *         holding the compared constant, false otherwise
+     */
+    public boolean apply(EnumSet<Options> options) {
+        if (options.contains(Options.CCP)) {
+            updated = false;
+            domTree = new DominatorTree(function.entry);
+            ssaDefUse = SSAEdges.buildDefUseChains(function);
+            walkBlocks();
+        }
+        return updated;
+    }
+
+    private void walkBlocks() {
+        walkBlock(function.entry);
+    }
+
+    private void walkBlock(BasicBlock block) {
+        propagateConstantsInComparisons(block);
+        for (BasicBlock c : block.dominatedChildren) {
+            walkBlock(c);
+        }
+    }
+
+    private void propagateConstantsInComparisons(BasicBlock block) {
+        if (block == function.exit) return;
+        // Get terminal instruction
+        Instruction instruction = block.instructions.getLast();
+        if (instruction instanceof Instruction.ConditionalBranch cbr) {
+            if (cbr.condition() instanceof Operand.RegisterOperand conditionVar) {
+                SSAEdges.SSADef defUse = ssaDefUse.get(conditionVar.reg);
+                // No recorded definition for the condition register means
+                // there is no comparison to inspect
+                if (defUse == null) return;
+                // If the condition var was result of == with constant value
+                if (defUse.instruction.block.bid == block.bid
+                        && defUse.instruction instanceof Instruction.Binary binary
+                        && binary.binOp.equals("==")) {
+                    // Get the constant and the register operands
+                    // from the binary
+                    Operand.ConstantOperand constantOp = null;
+                    Operand.RegisterOperand registerOp = null;
+                    if (binary.left() instanceof Operand.ConstantOperand leftConstant
+                            && binary.right() instanceof Operand.RegisterOperand rightReg) {
+                        constantOp = leftConstant;
+                        registerOp = rightReg;
+                    } else if (binary.right() instanceof Operand.ConstantOperand rightConstant
+                            && binary.left() instanceof Operand.RegisterOperand leftReg) {
+                        constantOp = rightConstant;
+                        registerOp = leftReg;
+                    }
+                    if (constantOp != null) {
+                        // Since reg is constant in true branch
+                        // We can replace all uses of reg with the constant
+                        // in blocks dominated by the true block
+                        BasicBlock trueBlock = cbr.trueBlock;
+                        // Uses recorded for the compared register; absent if the
+                        // def use chains have no entry for it
+                        SSAEdges.SSADef registerDefUse = ssaDefUse.get(registerOp.reg);
+                        // I am not sure if this scenario can occur but
+                        // for safety we check that the register we will
+                        // replace is not immediately used inside the true block
+                        // in a phi, because we intend to add the definition of
+                        // the replacement after any phis
+                        if (registerDefUse != null && !checkUsedInPhi(trueBlock, registerOp.reg)) {
+                            // Create a temp and move constant to it.
+                            // We insert the new instruction at the top of the True Block,
+                            // where it should dominate all uses of it
+                            // We could just replace with a constant here instead of creating
+                            // a temp and a move instruction but that would be less general a solution as it
+                            // would not work for other scenarios such as null status which we will
+                            // be adding in future
+                            var replacementRegister = function.registerPool.newTempReg(registerOp.reg.type);
+                            var defInst = new Instruction.Move(constantOp, new Operand.TempRegisterOperand(replacementRegister));
+                            insertAtBeginning(trueBlock, defInst);
+                            var replacementRegisterUses = SSAEdges.addDef(ssaDefUse, replacementRegister, defInst); // Update SSA Def Use chains, add def for new reg
+                            Iterator<Instruction> useIter = registerDefUse.useList.iterator();
+                            while (useIter.hasNext()) {
+                                Instruction use = useIter.next();
+                                if (trueBlock.dominates(use.block)) {
+                                    use.replaceUse(registerOp.reg, replacementRegister);
+                                    // Update SSA Def use chains
+                                    useIter.remove(); // No longer a use of old register
+                                    replacementRegisterUses.addUse(use); // Use of new temp register
+                                    // Report the change so that the caller reruns SCCP
+                                    updated = true;
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    /* Check if the register is used in a phi instruction within the block */
+    private static boolean checkUsedInPhi(BasicBlock block, Register register) {
+        for (Instruction instruction : block.instructions) {
+            if (instruction instanceof Instruction.Phi phi) {
+                for (int i = 0; i < phi.numInputs(); i++) {
+                    if (phi.isRegisterInput(i)) {
+                        Register phiUse = phi.inputAsRegister(i);
+                        if (phiUse.equals(register)) return true;
+                    }
+                }
+            }
+            else break;
+        }
+        return false;
+    }
+
+    /* Insert instruction at the start of BB after any phis */
+    private static void insertAtBeginning(BasicBlock block, Instruction instruction) {
+        int pos = 0; // insertion point
+        for (; pos < block.instructions.size(); pos++) {
+            if (!(block.instructions.get(pos) instanceof Instruction.Phi))
+                break;
+        }
+        if (pos == block.instructions.size())
+            throw new IllegalStateException();
+        block.add(pos, instruction);
+    }
+}
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Optimizer.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Optimizer.java
index 54a63a8..c9e21d8 100644
--- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Optimizer.java
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Optimizer.java
@@ -7,8 +7,14 @@ public class Optimizer {
     public void optimize(CompiledFunction function, EnumSet<Options> options) {
         if (options.contains(Options.OPTIMIZE)) {
             new EnterSSA(function, options);
-            if (options.contains(Options.SCCP))
+            if (options.contains(Options.SCCP)) {
                 new SparseConditionalConstantPropagation().constantPropagation(function).apply(options);
+                if (new ConstantComparisonPropagation(function).apply(options)) {
+                    // Run SCCP again
+                    // We could repeat this until no further changes occur
+                    new SparseConditionalConstantPropagation().constantPropagation(function).apply(options);
+                }
+            }
             new ExitSSA(function, options);
         }
         if (options.contains(Options.REGALLOC))
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Options.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Options.java
index 06a0c42..56280e2 100644
--- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Options.java
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Options.java
@@ -5,12 +5,14 @@
 public enum Options {
     OPTIMIZE,
     SCCP,
+    CCP, // constant comparison propagation
     REGALLOC,
     DUMP_INITIAL_IR,
     DUMP_PRE_SSA_DOMTREE,
     DUMP_SSA_IR,
     DUMP_SCCP_PREAPPLY,
     DUMP_SCCP_POSTAPPLY,
+    DUMP_CCP_POSTAPPLY,
     DUMP_SSA_LIVENESS,
     DUMP_SSA_DOMTREE,
     DUMP_POST_SSA_IR,
@@ -19,6 +21,6 @@ public enum Options {
     DUMP_POST_CHAITIN_IR;
 
     public static final EnumSet<Options> NONE = EnumSet.noneOf(Options.class);
-    public static final EnumSet<Options> OPT = EnumSet.of(Options.OPTIMIZE,Options.SCCP,Options.REGALLOC);
+    public static final EnumSet<Options> OPT = EnumSet.of(Options.OPTIMIZE,Options.SCCP,Options.CCP,Options.REGALLOC);
     public static final EnumSet<Options> OPT_VERBOSE = EnumSet.allOf(Options.class);
 }
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SSAEdges.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SSAEdges.java
index ed8b8ea..e6fcdcb 100644
--- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SSAEdges.java
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SSAEdges.java
@@ -28,6 +28,10 @@ public SSADef(Instruction instruction) {
             this.instruction = instruction;
             this.useList = new ArrayList<>();
         }
+
+        public void addUse(Instruction instruction) {
+            useList.add(instruction);
+        }
     }
 
     public static Map<Register, SSADef> buildDefUseChains(CompiledFunction function) {
@@ -41,6 +45,14 @@ public static Map<Register, SSADef> buildDefUseChains(CompiledFunction function)
         return defUseChains;
     }
 
+    public static SSADef addDef(Map<Register, SSADef> defUseChains, Register register, Instruction instruction) {
+        if (defUseChains.get(register) != null)
+            throw new CompilerException("Duplicate definition for register " + register);
+        var ssaDef = new SSADef(instruction);
+        defUseChains.put(register, ssaDef);
+        return ssaDef;
+    }
+
     private static void recordDefs(CompiledFunction function, Map<Register, SSADef> defUseChains) {
         for (BasicBlock block : function.getBlocks()) {
             for (Instruction instruction : block.instructions) {