Skip to content

Commit 99e6a9d

Browse files
Merge pull request #17 from CompilerProgramming/chaitin
The graph coloring register allocator now pre-assigns function args
2 parents 2091a70 + 13f01c6 commit 99e6a9d

File tree

10 files changed

+211
-30
lines changed

10 files changed

+211
-30
lines changed

optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ChaitinGraphColoringRegisterAllocator.java

Lines changed: 70 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,78 @@
44
import java.util.stream.IntStream;
55

66
/**
7-
* Implement the original graph coloring algorithm described by Chaitin.
7+
* Implements the original graph coloring algorithm described by Chaitin.
8+
* Since we are targeting an abstract machine where there are no limits on
9+
* number of registers except how we set them, our goal here is to get to
10+
* the minimum number of registers required to execute the function.
11+
* <p>
12+
* We do want to implement spilling even though we do not need it for the
13+
* abstract machine, but it is not yet implemented. We would spill to a
14+
* stack attached to the abstract machine.
815
*
916
* TODO spilling
1017
*/
1118
public class ChaitinGraphColoringRegisterAllocator {
1219

13-
public ChaitinGraphColoringRegisterAllocator() {
14-
}
15-
1620
public Map<Integer, Integer> assignRegisters(CompiledFunction function, int numRegisters) {
1721
if (function.isSSA) throw new IllegalStateException("Register allocation should be done after exiting SSA");
18-
var g = coalesce(function);
19-
var registers = registersInIR(function);
20-
var colors = IntStream.range(0, numRegisters).boxed().toList();
21-
// TODO pre-assign regs to args
22+
// Remove useless copy operations
23+
InterferenceGraph g = coalesce(function);
24+
// Get used registers
25+
Set<Integer> registers = registersInIR(function);
26+
// Create color set
27+
List<Integer> colors = new ArrayList<>(IntStream.range(0, numRegisters).boxed().toList());
28+
// Function args are pre-assigned colors
29+
// and we remove them from the register set
30+
Map<Integer, Integer> assignments = preAssignArgsToColors(function, registers, colors);
2231
// TODO spilling
23-
var assignments = colorGraph(g, registers, new HashSet<>(colors));
32+
// execute graph coloring on remaining registers
33+
assignments = colorGraph(g, registers, new HashSet<>(colors), assignments);
34+
// update all instructions
35+
// We simply set the slot on each register - rather than actually trying to replace them
36+
updateInstructions(function, assignments);
37+
// Compute and set the new framesize
38+
function.setFrameSize(computeFrameSize(assignments));
2439
return assignments;
2540
}
2641

42+
/**
43+
* Frame size = max number of registers needed to execute the function
44+
*/
45+
private int computeFrameSize(Map<Integer, Integer> assignments) {
46+
return assignments.values().stream().mapToInt(k->k).max().orElse(0);
47+
}
48+
49+
/**
50+
* Due to the way function args are received by the abstract machine, we need
51+
* to assign them register slots starting from 0. After assigning colors/slots
52+
* we remove these from the set so that the graph coloring algo does
53+
*/
54+
private Map<Integer, Integer> preAssignArgsToColors(CompiledFunction function, Set<Integer> registers, List<Integer> colors) {
55+
int count = 0;
56+
Map<Integer, Integer> assignments = new HashMap<>();
57+
for (Instruction instruction : function.entry.instructions) {
58+
if (instruction instanceof Instruction.ArgInstruction argInstruction) {
59+
Integer color = colors.get(count);
60+
Register reg = argInstruction.arg().reg;
61+
registers.remove(reg.nonSSAId()); // Remove register from set before changing slot
62+
assignments.put(reg.nonSSAId(), color);
63+
count++;
64+
}
65+
else break;
66+
}
67+
return assignments;
68+
}
69+
70+
private void updateInstructions(CompiledFunction function, Map<Integer, Integer> assignments) {
71+
var regPool = function.registerPool;
72+
for (var entry : assignments.entrySet()) {
73+
int reg = entry.getKey();
74+
int slot = entry.getValue();
75+
regPool.getReg(reg).updateSlot(slot);
76+
}
77+
}
78+
2779
/**
2880
* Chaitin: coalesce_nodes - coalesce away copy operations
2981
*/
@@ -85,9 +137,7 @@ private void rewriteInstructions(CompiledFunction function, Instruction deadInst
85137
private Set<Integer> registersInIR(CompiledFunction function) {
86138
Set<Integer> registers = new HashSet<>();
87139
for (var block: function.getBlocks()) {
88-
Iterator<Instruction> iter = block.instructions.iterator();
89-
while (iter.hasNext()) {
90-
Instruction instruction = iter.next();
140+
for (Instruction instruction: block.instructions) {
91141
if (instruction.definesVar())
92142
registers.add(instruction.def().id);
93143
for (Register use: instruction.uses())
@@ -112,7 +162,7 @@ private Integer findNodeWithNeighborCountLessThan(InterferenceGraph g, Set<Integ
112162
private Set<Integer> getNeighborColors(InterferenceGraph g, Integer node, Map<Integer,Integer> assignedColors) {
113163
Set<Integer> colors = new HashSet<>();
114164
for (var neighbour: g.neighbors(node)) {
115-
var c = assignedColors.get(neighbour);
165+
Integer c = assignedColors.get(neighbour);
116166
if (c != null) {
117167
colors.add(c);
118168
}
@@ -137,18 +187,18 @@ private static HashSet<Integer> subtract(Set<Integer> originalSet, Integer node)
137187
/**
138188
* Chaitin: color_graph
139189
*/
140-
private Map<Integer, Integer> colorGraph(InterferenceGraph g, Set<Integer> nodes, Set<Integer> colors) {
190+
private Map<Integer, Integer> colorGraph(InterferenceGraph g, Set<Integer> nodes, Set<Integer> colors, Map<Integer, Integer> preAssignedColors) {
141191
if (nodes.size() == 0)
142-
return new HashMap<>();
143-
var numColors = colors.size();
144-
var node = findNodeWithNeighborCountLessThan(g, nodes, numColors);
192+
return preAssignedColors;
193+
int numColors = colors.size();
194+
Integer node = findNodeWithNeighborCountLessThan(g, nodes, numColors);
145195
if (node == null)
146196
return null;
147-
var coloring = colorGraph(g.dup().subtract(node), subtract(nodes, node), colors);
197+
Map<Integer, Integer> coloring = colorGraph(g.dup().subtract(node), subtract(nodes, node), colors, preAssignedColors);
148198
if (coloring == null)
149199
return null;
150-
var neighbourColors = getNeighborColors(g, node, coloring);
151-
var color = chooseSomeColorNotAssignedToNeighbors(colors, neighbourColors);
200+
Set<Integer> neighbourColors = getNeighborColors(g, node, coloring);
201+
Integer color = chooseSomeColorNotAssignedToNeighbors(colors, neighbourColors);
152202
coloring.put(node, color);
153203
return coloring;
154204
}

optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public class CompiledFunction {
2121
private Type.TypeFunction functionType;
2222
public final RegisterPool registerPool;
2323

24-
private final int frameSlots;
24+
private int frameSlots;
2525

2626
public boolean isSSA;
2727
public boolean hasLiveness;
@@ -76,6 +76,9 @@ private void generateArgInstructions(Scope scope) {
7676
public int frameSize() {
7777
return frameSlots;
7878
}
79+
public void setFrameSize(int size) {
80+
frameSlots = size;
81+
}
7982

8083
private void exitBlockIfNeeded() {
8184
if (currentBlock != null &&
@@ -134,6 +137,7 @@ private void compileStatement(AST.Stmt statement) {
134137
case AST.VarStmt letStmt -> {
135138
compileLet(letStmt);
136139
}
140+
case AST.VarDeclStmt varDeclStmt -> {}
137141
case AST.IfElseStmt ifElseStmt -> {
138142
compileIf(ifElseStmt);
139143
}

optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,32 @@
88
import com.compilerprogramming.ezlang.types.Type;
99
import com.compilerprogramming.ezlang.types.TypeDictionary;
1010

11-
import java.util.BitSet;
12-
1311
public class Compiler {
1412

15-
private void compile(TypeDictionary typeDictionary) {
13+
private void compile(TypeDictionary typeDictionary, boolean opt) {
1614
for (Symbol symbol: typeDictionary.getLocalSymbols()) {
1715
if (symbol instanceof Symbol.FunctionTypeSymbol functionSymbol) {
1816
Type.TypeFunction functionType = (Type.TypeFunction) functionSymbol.type;
19-
functionType.code = new CompiledFunction(functionSymbol);
17+
var function = new CompiledFunction(functionSymbol);
18+
functionType.code = function;
19+
if (opt) {
20+
new Optimizer().optimize(function);
21+
}
2022
}
2123
}
2224
}
2325
public TypeDictionary compileSrc(String src) {
26+
return compileSrc(src, false);
27+
}
28+
public TypeDictionary compileSrc(String src, boolean opt) {
2429
Parser parser = new Parser();
2530
var program = parser.parse(new Lexer(src));
2631
var typeDict = new TypeDictionary();
2732
var sema = new SemaDefineTypes(typeDict);
2833
sema.analyze(program);
2934
var sema2 = new SemaAssignTypes(typeDict);
3035
sema2.analyze(program);
31-
compile(typeDict);
36+
compile(typeDict, opt);
3237
return typeDict;
3338
}
3439
public static String dumpIR(TypeDictionary typeDictionary) {

optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package com.compilerprogramming.ezlang.compiler;
22

3+
import com.compilerprogramming.ezlang.exceptions.CompilerException;
4+
35
import java.util.*;
46

57
/**
@@ -178,7 +180,11 @@ static class BBSet {
178180
static class VersionStack {
179181
List<Register.SSARegister> stack = new ArrayList<>();
180182
void push(Register.SSARegister r) { stack.add(r); }
181-
Register.SSARegister top() { return stack.getLast(); }
183+
Register.SSARegister top() {
184+
if (stack.isEmpty())
185+
throw new CompilerException("Variable may not be initialized");
186+
return stack.getLast();
187+
}
182188
void pop() { stack.removeLast(); }
183189
}
184190

optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ public ExitSSA(CompiledFunction function) {
2323
initStack();
2424
insertCopies(function.entry);
2525
removePhis();
26+
function.isSSA = false;
2627
}
2728

2829
private void removePhis() {

optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public String toString() {
2121
}
2222

2323
public static class RegisterOperand extends Operand {
24-
public Register reg;
24+
Register reg;
2525
public RegisterOperand(Register reg) {
2626
this.reg = reg;
2727
if (reg == null)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package com.compilerprogramming.ezlang.compiler;
2+
3+
public class Optimizer {
4+
5+
public void optimize(CompiledFunction function) {
6+
new EnterSSA(function);
7+
new ExitSSA(function);
8+
new ChaitinGraphColoringRegisterAllocator().assignRegisters(function, 64);
9+
}
10+
}

optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Register.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@ public class Register {
2121
* The type of a register
2222
*/
2323
public final Type type;
24+
private int slot;
2425

2526
public Register(int id, String name, Type type) {
2627
this.id = id;
2728
this.name = name;
2829
this.type = type;
30+
this.slot = id;
2931
}
3032
@Override
3133
public boolean equals(Object o) {
@@ -44,7 +46,10 @@ public String name() {
4446
return name;
4547
}
4648
public int nonSSAId() {
47-
return id;
49+
return slot;
50+
}
51+
public void updateSlot(int slot) {
52+
this.slot = slot;
4853
}
4954

5055
/**

optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import com.compilerprogramming.ezlang.types.Type;
55
import com.compilerprogramming.ezlang.types.TypeDictionary;
66
import org.junit.Assert;
7+
import org.junit.Ignore;
78
import org.junit.Test;
89

910
import java.util.Arrays;
@@ -766,4 +767,51 @@ public void testSwapProblem() {
766767
Assert.assertEquals(expected, function.toStr(new StringBuilder(), false).toString());
767768
}
768769

770+
@Test
771+
public void testLiveness() {
772+
String src = """
773+
func bar(x: Int)->Int {
774+
var y = 0
775+
var z = 0
776+
while( x>1 ){
777+
y = x/2;
778+
if (y > 3) {
779+
x = x-y;
780+
}
781+
z = x-4;
782+
if (z > 0) {
783+
x = x/2;
784+
}
785+
z = z-1;
786+
}
787+
return x;
788+
}
789+
790+
func foo() {
791+
return bar(10);
792+
}
793+
""";
794+
String result = compileSrc(src);
795+
System.out.println(result);
796+
}
797+
798+
@Test
799+
@Ignore
800+
public void testInit() {
801+
// see issue #16
802+
String src = """
803+
func foo(x: Int) {
804+
var z: Int
805+
while (x > 0) {
806+
z = 5
807+
if (x == 1)
808+
z = z+1
809+
x = x - 1
810+
}
811+
}
812+
""";
813+
String result = compileSrc(src);
814+
System.out.println(result);
815+
}
816+
769817
}

0 commit comments

Comments
 (0)