Skip to content

Commit 852dde6

Browse files
niaow and deadprogram
authored and committed
compiler: use Tarjan's SCC algorithm to detect loops for defer
The compiler needs to know whether a defer is in a loop to determine whether to allocate stack or heap memory. Previously, this performed a DFS of the CFG every time a defer was found. This resulted in time complexity jointly proportional to the number of defers and the number of blocks in the function. Now, the compiler will instead use Tarjan's strongly connected components algorithm to find cycles in linear time. The search is performed lazily, so this has minimal performance impact on functions without defers. In order to implement Tarjan's SCC algorithm, additional state needed to be attached to the blocks. I chose to merge all of the per-block state into a single slice to simplify memory management.
1 parent bf61317 commit 852dde6

File tree

6 files changed

+399
-45
lines changed

6 files changed

+399
-45
lines changed

compiler/asserts.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ func (b *builder) createRuntimeAssert(assert llvm.Value, blockPrefix, assertFunc
245245
// current insert position.
246246
faultBlock := b.ctx.AddBasicBlock(b.llvmFn, blockPrefix+".throw")
247247
nextBlock := b.insertBasicBlock(blockPrefix + ".next")
248-
b.blockExits[b.currentBlock] = nextBlock // adjust outgoing block for phi nodes
248+
b.currentBlockInfo.exit = nextBlock // adjust outgoing block for phi nodes
249249

250250
// Now branch to the out-of-bounds or the regular block.
251251
b.CreateCondBr(assert, faultBlock, nextBlock)

compiler/compiler.go

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,12 @@ type builder struct {
152152
llvmFnType llvm.Type
153153
llvmFn llvm.Value
154154
info functionInfo
155-
locals map[ssa.Value]llvm.Value // local variables
156-
blockEntries map[*ssa.BasicBlock]llvm.BasicBlock // a *ssa.BasicBlock may be split up
157-
blockExits map[*ssa.BasicBlock]llvm.BasicBlock // these are the exit blocks
155+
locals map[ssa.Value]llvm.Value // local variables
156+
blockInfo []blockInfo
158157
currentBlock *ssa.BasicBlock
158+
currentBlockInfo *blockInfo
159+
tarjanStack []uint
160+
tarjanIndex uint
159161
phis []phiNode
160162
deferPtr llvm.Value
161163
deferFrame llvm.Value
@@ -187,11 +189,22 @@ func newBuilder(c *compilerContext, irbuilder llvm.Builder, f *ssa.Function) *bu
187189
info: c.getFunctionInfo(f),
188190
locals: make(map[ssa.Value]llvm.Value),
189191
dilocals: make(map[*types.Var]llvm.Metadata),
190-
blockEntries: make(map[*ssa.BasicBlock]llvm.BasicBlock),
191-
blockExits: make(map[*ssa.BasicBlock]llvm.BasicBlock),
192192
}
193193
}
194194

195+
// blockInfo holds the per-block compiler state for one *ssa.BasicBlock.
// The builder keeps these in a slice indexed by the block's Index.
type blockInfo struct {
196+
// entry is the LLVM basic block corresponding to the start of this *ssa.BasicBlock.
197+
entry llvm.BasicBlock
198+
199+
// exit is the LLVM basic block corresponding to the end of this *ssa.BasicBlock.
200+
// It will be different from entry if any of the block's instructions contain internal branches.
201+
exit llvm.BasicBlock
202+
203+
// tarjan holds state for applying Tarjan's strongly connected components algorithm to the CFG.
204+
// This is used by defer.go to determine whether to stack- or heap-allocate defer data.
205+
tarjan tarjanNode
206+
}
207+
195208
type deferBuiltin struct {
196209
callName string
197210
pos token.Pos
@@ -1220,14 +1233,29 @@ func (b *builder) createFunctionStart(intrinsic bool) {
12201233
// intrinsic (like an atomic operation). Create the entry block
12211234
// manually.
12221235
entryBlock = b.ctx.AddBasicBlock(b.llvmFn, "entry")
1236+
// Intrinsics may create internal branches (e.g. nil checks).
1237+
// They will attempt to access b.currentBlockInfo to update the exit block.
1238+
// Create some fake block info for them to access.
1239+
blockInfo := []blockInfo{
1240+
{
1241+
entry: entryBlock,
1242+
exit: entryBlock,
1243+
},
1244+
}
1245+
b.blockInfo = blockInfo
1246+
b.currentBlockInfo = &blockInfo[0]
12231247
} else {
1248+
blocks := b.fn.Blocks
1249+
blockInfo := make([]blockInfo, len(blocks))
12241250
for _, block := range b.fn.DomPreorder() {
1251+
info := &blockInfo[block.Index]
12251252
llvmBlock := b.ctx.AddBasicBlock(b.llvmFn, block.Comment)
1226-
b.blockEntries[block] = llvmBlock
1227-
b.blockExits[block] = llvmBlock
1253+
info.entry = llvmBlock
1254+
info.exit = llvmBlock
12281255
}
1256+
b.blockInfo = blockInfo
12291257
// Normal functions have an entry block.
1230-
entryBlock = b.blockEntries[b.fn.Blocks[0]]
1258+
entryBlock = blockInfo[0].entry
12311259
}
12321260
b.SetInsertPointAtEnd(entryBlock)
12331261

@@ -1323,8 +1351,9 @@ func (b *builder) createFunction() {
13231351
if b.DumpSSA {
13241352
fmt.Printf("%d: %s:\n", block.Index, block.Comment)
13251353
}
1326-
b.SetInsertPointAtEnd(b.blockEntries[block])
13271354
b.currentBlock = block
1355+
b.currentBlockInfo = &b.blockInfo[block.Index]
1356+
b.SetInsertPointAtEnd(b.currentBlockInfo.entry)
13281357
for _, instr := range block.Instrs {
13291358
if instr, ok := instr.(*ssa.DebugRef); ok {
13301359
if !b.Debug {
@@ -1384,7 +1413,7 @@ func (b *builder) createFunction() {
13841413
block := phi.ssa.Block()
13851414
for i, edge := range phi.ssa.Edges {
13861415
llvmVal := b.getValue(edge, getPos(phi.ssa))
1387-
llvmBlock := b.blockExits[block.Preds[i]]
1416+
llvmBlock := b.blockInfo[block.Preds[i].Index].exit
13881417
phi.llvm.AddIncoming([]llvm.Value{llvmVal}, []llvm.BasicBlock{llvmBlock})
13891418
}
13901419
}
@@ -1498,11 +1527,11 @@ func (b *builder) createInstruction(instr ssa.Instruction) {
14981527
case *ssa.If:
14991528
cond := b.getValue(instr.Cond, getPos(instr))
15001529
block := instr.Block()
1501-
blockThen := b.blockEntries[block.Succs[0]]
1502-
blockElse := b.blockEntries[block.Succs[1]]
1530+
blockThen := b.blockInfo[block.Succs[0].Index].entry
1531+
blockElse := b.blockInfo[block.Succs[1].Index].entry
15031532
b.CreateCondBr(cond, blockThen, blockElse)
15041533
case *ssa.Jump:
1505-
blockJump := b.blockEntries[instr.Block().Succs[0]]
1534+
blockJump := b.blockInfo[instr.Block().Succs[0].Index].entry
15061535
b.CreateBr(blockJump)
15071536
case *ssa.MapUpdate:
15081537
m := b.getValue(instr.Map, getPos(instr))

compiler/defer.go

Lines changed: 100 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func (b *builder) createLandingPad() {
100100

101101
// Continue at the 'recover' block, which returns to the parent in an
102102
// appropriate way.
103-
b.CreateBr(b.blockEntries[b.fn.Recover])
103+
b.CreateBr(b.blockInfo[b.fn.Recover.Index].entry)
104104
}
105105

106106
// Create a checkpoint (similar to setjmp). This emits inline assembly that
@@ -234,41 +234,108 @@ func (b *builder) createInvokeCheckpoint() {
234234
continueBB := b.insertBasicBlock("")
235235
b.CreateCondBr(isZero, continueBB, b.landingpad)
236236
b.SetInsertPointAtEnd(continueBB)
237-
b.blockExits[b.currentBlock] = continueBB
237+
b.currentBlockInfo.exit = continueBB
238238
}
239239

240-
// isInLoop checks if there is a path from a basic block to itself.
241-
func isInLoop(start *ssa.BasicBlock) bool {
242-
// Use a breadth-first search to scan backwards through the block graph.
243-
queue := []*ssa.BasicBlock{start}
244-
checked := map[*ssa.BasicBlock]struct{}{}
245-
246-
for len(queue) > 0 {
247-
// pop a block off of the queue
248-
block := queue[len(queue)-1]
249-
queue = queue[:len(queue)-1]
250-
251-
// Search through predecessors.
252-
// Searching backwards means that this is pretty fast when the block is close to the start of the function.
253-
// Defers are often placed near the start of the function.
254-
for _, pred := range block.Preds {
255-
if pred == start {
256-
// cycle found
257-
return true
258-
}
240+
// isInLoop checks if there is a path from the current block to itself.
241+
// Use Tarjan's strongly connected components algorithm to search for cycles.
242+
// A one-node SCC is a cycle iff there is an edge from the node to itself.
243+
// A multi-node SCC is always a cycle.
244+
// https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
245+
func (b *builder) isInLoop() bool {
246+
if b.currentBlockInfo.tarjan.lowLink == 0 {
247+
b.strongConnect(b.currentBlock)
248+
}
249+
return b.currentBlockInfo.tarjan.cyclic
250+
}
259251

260-
if _, ok := checked[pred]; ok {
261-
// block already checked
262-
continue
263-
}
252+
func (b *builder) strongConnect(block *ssa.BasicBlock) {
253+
// Assign a new index.
254+
// Indices start from 1 so that 0 can be used as a sentinel.
255+
assignedIndex := b.tarjanIndex + 1
256+
b.tarjanIndex = assignedIndex
257+
258+
// Apply the new index.
259+
blockIndex := block.Index
260+
node := &b.blockInfo[blockIndex].tarjan
261+
node.lowLink = assignedIndex
262+
263+
// Push the node onto the stack.
264+
node.onStack = true
265+
b.tarjanStack = append(b.tarjanStack, uint(blockIndex))
266+
267+
// Process the successors.
268+
for _, successor := range block.Succs {
269+
// Look up the successor's state.
270+
successorIndex := successor.Index
271+
if successorIndex == blockIndex {
272+
// Handle a self-cycle specially.
273+
node.cyclic = true
274+
continue
275+
}
276+
successorNode := &b.blockInfo[successorIndex].tarjan
264277

265-
// add to queue and checked map
266-
queue = append(queue, pred)
267-
checked[pred] = struct{}{}
278+
switch {
279+
case successorNode.lowLink == 0:
280+
// This node has not yet been visited.
281+
b.strongConnect(successor)
282+
283+
case !successorNode.onStack:
284+
// This node has been visited, but is in a different SCC.
285+
// Ignore it, and do not update lowLink.
286+
continue
287+
}
288+
289+
// Update the lowLink index.
290+
// This always uses the min-of-lowlink instead of using index in the on-stack case.
291+
// This is done for two reasons:
292+
// 1. The lowLink update can be shared between the new-node and on-stack cases.
293+
// 2. The assigned index does not need to be saved - it is only needed for root node detection.
294+
if successorNode.lowLink < node.lowLink {
295+
node.lowLink = successorNode.lowLink
268296
}
269297
}
270298

271-
return false
299+
if node.lowLink == assignedIndex {
300+
// This is a root node.
301+
// Pop the SCC off the stack.
302+
stack := b.tarjanStack
303+
top := stack[len(stack)-1]
304+
stack = stack[:len(stack)-1]
305+
blocks := b.blockInfo
306+
topNode := &blocks[top].tarjan
307+
topNode.onStack = false
308+
309+
if top != uint(blockIndex) {
310+
// The root node is not the only node in the SCC.
311+
// Mark all nodes in this SCC as cyclic.
312+
topNode.cyclic = true
313+
for top != uint(blockIndex) {
314+
top = stack[len(stack)-1]
315+
stack = stack[:len(stack)-1]
316+
topNode = &blocks[top].tarjan
317+
topNode.onStack = false
318+
topNode.cyclic = true
319+
}
320+
}
321+
322+
b.tarjanStack = stack
323+
}
324+
}
325+
326+
// tarjanNode holds per-block state for isInLoop and strongConnect.
327+
type tarjanNode struct {
328+
// lowLink is the lowest search index known to be reachable from this block.
329+
// The lowLink indices are assigned by the SCC search, and do not correspond to the block's Index field.
330+
// A lowLink of 0 is used as a sentinel to mark a node which has not yet been visited.
331+
lowLink uint
332+
333+
// onStack tracks whether this node is currently on the SCC search stack.
334+
onStack bool
335+
336+
// cyclic indicates whether this block is in a loop.
337+
// If lowLink is 0, strongConnect must be called before reading this field.
338+
cyclic bool
272339
}
273340

274341
// createDefer emits a single defer instruction, to be run when this function
@@ -410,7 +477,10 @@ func (b *builder) createDefer(instr *ssa.Defer) {
410477

411478
// Put this struct in an allocation.
412479
var alloca llvm.Value
413-
if !isInLoop(instr.Block()) {
480+
if instr.Block() != b.currentBlock {
481+
panic("block mismatch")
482+
}
483+
if !b.isInLoop() {
414484
// This can safely use a stack allocation.
415485
alloca = llvmutil.CreateEntryBlockAlloca(b.Builder, deferredCallType, "defer.alloca")
416486
} else {

compiler/interface.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,7 @@ func (b *builder) createTypeAssert(expr *ssa.TypeAssert) llvm.Value {
737737
prevBlock := b.GetInsertBlock()
738738
okBlock := b.insertBasicBlock("typeassert.ok")
739739
nextBlock := b.insertBasicBlock("typeassert.next")
740-
b.blockExits[b.currentBlock] = nextBlock // adjust outgoing block for phi nodes
740+
b.currentBlockInfo.exit = nextBlock // adjust outgoing block for phi nodes
741741
b.CreateCondBr(commaOk, okBlock, nextBlock)
742742

743743
// Retrieve the value from the interface if the type assert was

0 commit comments

Comments
 (0)