diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java index caf691e..4028bfd 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java @@ -15,7 +15,7 @@ public class CompiledFunction { public BasicBlock entry; public BasicBlock exit; private int bid = 0; - private BasicBlock currentBlock; + public BasicBlock currentBlock; private BasicBlock currentBreakTarget; private BasicBlock currentContinueTarget; private Type.TypeFunction functionType; @@ -52,6 +52,17 @@ public CompiledFunction(Symbol.FunctionTypeSymbol functionSymbol) { this.frameSlots = registerPool.numRegisters(); } + public CompiledFunction(Type.TypeFunction functionType) { + this.functionType = (Type.TypeFunction) functionType; + this.registerPool = new RegisterPool("%ret", functionType == null?null:functionType.returnType); + this.bid = 0; + this.entry = this.currentBlock = createBlock(); + this.exit = createBlock(); + this.currentBreakTarget = null; + this.currentContinueTarget = null; + this.frameSlots = registerPool.numRegisters(); + } + private void generateArgInstructions(Scope scope) { if (scope.isFunctionParameterScope) { for (Symbol symbol: scope.getLocalSymbols()) { @@ -84,7 +95,7 @@ private void setVirtualRegisters(Scope scope) { } } - private BasicBlock createBlock() { + public BasicBlock createBlock() { return new BasicBlock(bid++); } @@ -111,7 +122,7 @@ else if (virtualStack.size() > 1) jumpTo(exit); } - private void code(Instruction instruction) { + public void code(Instruction instruction) { currentBlock.add(instruction); } @@ -213,13 +224,13 @@ private boolean isBlockTerminated(BasicBlock block) { block.instructions.getLast().isTerminal()); } - private void jumpTo(BasicBlock block) { + public void jumpTo(BasicBlock block) { assert !isBlockTerminated(currentBlock); currentBlock.add(new Instruction.Jump(block)); currentBlock.addSuccessor(block); } - private void startBlock(BasicBlock block) { + public void startBlock(BasicBlock block) { if (!isBlockTerminated(currentBlock)) { jumpTo(block); } @@ -541,11 +552,12 @@ private boolean vstackEmpty() { return virtualStack.isEmpty(); } - public void toStr(StringBuilder sb, boolean verbose) { + public StringBuilder toStr(StringBuilder sb, boolean verbose) { if (verbose) { sb.append(this.functionType.describe()).append("\n"); registerPool.toStr(sb); } BasicBlock.toStr(sb, entry, new BitSet(), verbose); + return sb; } } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java index 310a9f2..a81aef7 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java @@ -34,7 +34,7 @@ private void removePhis() { /* Algorithm for iterating through blocks to perform phi replacement */ private void insertCopies(BasicBlock block) { - List pushed = new ArrayList<>(); + List pushed = new ArrayList<>(); for (Instruction i: block.instructions) { // replace all uses u with stacks[i] if (i.usesVars()) { @@ -45,8 +45,8 @@ private void insertCopies(BasicBlock block) { for (BasicBlock c: block.dominatedChildren) { insertCopies(c); } - for (Register name: pushed) { - stacks[name.id].pop(); + for (Integer name: pushed) { + stacks[name].pop(); } } @@ -67,8 +67,8 @@ private void replaceUses(Instruction i) { } static class CopyItem { - Register src; - Register dest; + final Register src; + final Register dest; boolean removed; public CopyItem(Register src, Register dest) { @@ -78,7 +78,7 @@ public CopyItem(Register src, Register dest) { } } - private void scheduleCopies(BasicBlock block, List pushed) { + private void scheduleCopies(BasicBlock block, List pushed) { /* Pass 1 - Initialize data structures */ /* In this pass we count the number of times a name is used by other phi-nodes */ List copySet = new ArrayList<>(); @@ -100,7 +100,7 @@ private void scheduleCopies(BasicBlock block, List pushed) { /* In this pass we build a worklist of names that are not used in other phi nodes */ List workList = new ArrayList<>(); for (CopyItem copyItem: copySet) { - if (!usedByAnother.get(copyItem.dest.id)) { + if (usedByAnother.get(copyItem.dest.id) != true) { copyItem.removed = true; workList.add(copyItem); } @@ -112,17 +112,17 @@ private void scheduleCopies(BasicBlock block, List pushed) { /* Each time we insert a copy operation we add the source of that op to the worklist */ while (!workList.isEmpty() || !copySet.isEmpty()) { while (!workList.isEmpty()) { - CopyItem copyItem = workList.remove(0); - Register src = copyItem.src; - Register dest = copyItem.dest; + final CopyItem copyItem = workList.remove(0); + final Register src = copyItem.src; + final Register dest = copyItem.dest; if (block.liveOut.get(dest.id)) { /* Insert a copy from dest to a new temp t at phi node defining dest */ - Register t = insertCopy(block, dest); + final Register t = addMoveToTempAfterPhi(block, dest); stacks[dest.id].push(t); - pushed.add(t); + pushed.add(dest.id); } - /* Insert a copy operation from map[src] to dst at end of BB */ - appendCopy(block, map.get(src.id), dest); + /* Insert a copy operation from map[src] to dest at end of BB */ + addMoveAtBBEnd(block, map.get(src.id), dest); map.put(src.id, dest); /* If src is the name of a dest in copySet add item to worklist */ CopyItem item = isDest(copySet, src); @@ -133,7 +133,7 @@ private void scheduleCopies(BasicBlock block, List pushed) { if (!copySet.isEmpty()) { CopyItem copyItem = copySet.remove(0); /* Insert a copy from dst to new temp at the end of Block */ - Register t = appendCopy(block, copyItem.dest); + Register t = addMoveToTempAtBBEnd(block, copyItem.dest); map.put(copyItem.dest.id, t); workList.add(copyItem); } @@ -142,30 +142,33 @@ private void scheduleCopies(BasicBlock block, List pushed) { private void insertAtEnd(BasicBlock bb, Instruction i) { assert bb.instructions.size() > 0; + // Last instruction is a branch - so new instruction will + // go before that int pos = bb.instructions.size()-1; bb.instructions.add(pos, i); } private void insertAfterPhi(BasicBlock bb, Register phiDef, Instruction newInst) { assert bb.instructions.size() > 0; - int pos = 0; - while (pos < bb.instructions.size()) { - Instruction i = bb.instructions.get(pos++); - if (i instanceof Instruction.Phi phi && - phi.def().id == phiDef.id) { - break; + int insertionPos = -1; + for (int pos = 0; pos < bb.instructions.size(); pos++) { + Instruction i = bb.instructions.get(pos); + if (i instanceof Instruction.Phi phi) { + if (phi.def().id == phiDef.id) { + insertionPos = pos+1; // After phi + break; + } } } - if (pos == bb.instructions.size()) { + if (insertionPos < 0) { throw new IllegalStateException(); } - bb.instructions.add(pos, newInst); + bb.instructions.add(insertionPos, newInst); } - /* Insert a copy from dest to new temp at end of BB, and return temp */ - private Register appendCopy(BasicBlock block, Register dest) { - var temp = function.registerPool.newReg(dest.name(), dest.type); + private Register addMoveToTempAtBBEnd(BasicBlock block, Register dest) { + var temp = function.registerPool.newTempReg(dest.name(), dest.type); var inst = new Instruction.Move(new Operand.RegisterOperand(dest), new Operand.RegisterOperand(temp)); insertAtEnd(block, inst); return temp; @@ -181,16 +184,16 @@ private CopyItem isDest(List copySet, Register src) { } /* Insert a copy from src to dst at end of BB */ - private void appendCopy(BasicBlock block, Register src, Register dest) { + private void addMoveAtBBEnd(BasicBlock block, Register src, Register dest) { var inst = new Instruction.Move(new Operand.RegisterOperand(src), new Operand.RegisterOperand(dest)); insertAtEnd(block, inst); } /* Insert a copy dest to a new temp at phi node defining dest, return temp */ - private Register insertCopy(BasicBlock block, Register dst) { - var temp = function.registerPool.newReg(dst.name(), dst.type); - var inst = new Instruction.Move(new Operand.RegisterOperand(dst), new Operand.RegisterOperand(temp)); - insertAfterPhi(block, dst, inst); + private Register addMoveToTempAfterPhi(BasicBlock block, Register dest) { + var temp = function.registerPool.newTempReg(dest.name(), dest.type); + var inst = new Instruction.Move(new Operand.RegisterOperand(dest), new Operand.RegisterOperand(temp)); + insertAfterPhi(block, dest, inst); return temp; } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/RegisterPool.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/RegisterPool.java index 7a011e0..23ec8f3 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/RegisterPool.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/RegisterPool.java @@ -38,6 +38,13 @@ public Register newTempReg(Type type) { registers.add(reg); return reg; } + public Register newTempReg(String baseName, Type type) { + var id = registers.size(); + var name = baseName+"_"+id; + var reg = new Register(id, name, type); + registers.add(reg); + return reg; + } public Register.SSARegister ssaReg(Register original, int version) { var id = registers.size(); diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java index f690a36..a2b5091 100644 --- a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java @@ -1,9 +1,12 @@ package com.compilerprogramming.ezlang.compiler; import com.compilerprogramming.ezlang.types.Symbol; +import com.compilerprogramming.ezlang.types.Type; +import com.compilerprogramming.ezlang.types.TypeDictionary; import org.junit.Assert; import org.junit.Test; +import java.util.Arrays; import java.util.BitSet; public class TestSSATransform { @@ -625,4 +628,144 @@ func bar(arg: Int)->Int { goto L1 """, result); } + + /** + * This test case is based on the example snippet from Briggs paper + * illustrating the lost copy problem. + */ + private CompiledFunction buildLostCopyTest() { + TypeDictionary typeDictionary = new TypeDictionary(); + Type.TypeFunction functionType = new Type.TypeFunction("foo"); + functionType.addArg(new Symbol.ParameterSymbol("p", typeDictionary.INT)); + functionType.setReturnType(typeDictionary.INT); + CompiledFunction function = new CompiledFunction(functionType); + RegisterPool regPool = function.registerPool; + Register p = regPool.newReg("p", typeDictionary.INT); + Register x1 = regPool.newReg("x1", typeDictionary.INT); + function.code(new Instruction.ArgInstruction(new Operand.LocalRegisterOperand(p))); + function.code(new Instruction.Move( + new Operand.ConstantOperand(1, typeDictionary.INT), + new Operand.RegisterOperand(x1))); + BasicBlock B2 = function.createBlock(); + function.startBlock(B2); + Register x3 = regPool.newReg("x3", typeDictionary.INT); + Register x2 = regPool.newReg("x2", typeDictionary.INT); + function.code(new Instruction.Phi(x2, Arrays.asList(x1, x3))); + function.code(new Instruction.Binary("+", + new Operand.RegisterOperand(x3), + new Operand.RegisterOperand(x2), + new Operand.ConstantOperand(1, typeDictionary.INT))); + function.code(new Instruction.ConditionalBranch(B2, + new Operand.RegisterOperand(p), B2, function.exit)); + function.startBlock(function.exit); + function.code(new Instruction.Return(new Operand.RegisterOperand(x2), regPool.returnRegister)); + function.isSSA = true; + return function; + } + + @Test + public void testLostCopyProblem() { + CompiledFunction function = buildLostCopyTest(); + String expected = """ +L0: + arg p + x1 = 1 + goto L2 +L2: + x2 = phi(x1, x3) + x3 = x2+1 + if p goto L2 else goto L1 +L1: + %ret = x2 +"""; + Assert.assertEquals(expected, function.toStr(new StringBuilder(), false).toString()); + new ExitSSA(function); + expected = """ +L0: + arg p + x1 = 1 + x2 = x1 + goto L2 +L2: + x2_5 = x2 + x3 = x2+1 + x2 = x3 + if p goto L2 else goto L1 +L1: + %ret = x2_5 +"""; + Assert.assertEquals(expected, function.toStr(new StringBuilder(), false).toString()); + } + + /** + * This test case is based on the example snippet from Briggs paper + * illustrating the swap problem. + */ + private CompiledFunction buildSwapTest() { + TypeDictionary typeDictionary = new TypeDictionary(); + Type.TypeFunction functionType = new Type.TypeFunction("foo"); + functionType.addArg(new Symbol.ParameterSymbol("p", typeDictionary.INT)); + CompiledFunction function = new CompiledFunction(functionType); + RegisterPool regPool = function.registerPool; + Register p = regPool.newReg("p", typeDictionary.INT); + Register a1 = regPool.newReg("a1", typeDictionary.INT); + Register a2 = regPool.newReg("a2", typeDictionary.INT); + Register a3 = regPool.newReg("a3", typeDictionary.INT); + Register b1 = regPool.newReg("b1", typeDictionary.INT); + Register b2 = regPool.newReg("b2", typeDictionary.INT); + function.code(new Instruction.ArgInstruction(new Operand.LocalRegisterOperand(p))); + function.code(new Instruction.Move( + new Operand.ConstantOperand(42, typeDictionary.INT), + new Operand.RegisterOperand(a1))); + function.code(new Instruction.Move( + new Operand.ConstantOperand(24, typeDictionary.INT), + new Operand.RegisterOperand(b1))); + BasicBlock B2 = function.createBlock(); + function.startBlock(B2); + function.code(new Instruction.Phi(a2, Arrays.asList(a1, b2))); + function.code(new Instruction.Phi(b2, Arrays.asList(b1, a2))); + function.code(new Instruction.ConditionalBranch(B2, + new Operand.RegisterOperand(p), B2, function.exit)); + function.startBlock(function.exit); + function.isSSA = true; + return function; + } + + @Test + public void testSwapProblem() { + CompiledFunction function = buildSwapTest(); + String expected = """ +L0: + arg p + a1 = 42 + b1 = 24 + goto L2 +L2: + a2 = phi(a1, b2) + b2 = phi(b1, a2) + if p goto L2 else goto L1 +L1: +"""; + Assert.assertEquals(expected, function.toStr(new StringBuilder(), false).toString()); + new ExitSSA(function); + expected = """ +L0: + arg p + a1 = 42 + b1 = 24 + a2 = a1 + b2 = b1 + goto L2 +L2: + a2_6 = a2 + a2 = b2 + b2 = a2_6 + b2_7 = b2 + b2 = b2 + if p goto L2 else goto L1 +L1: +"""; + Assert.assertEquals(expected, function.toStr(new StringBuilder(), false).toString()); + } + } \ No newline at end of file