Skip to content

Commit f0747ba

Browse files
author
William Grant
committed
Allow for partial module loads in compiler cache.
Previously we would always attempt to link and validate all global variables when loading a module from the cache. This caused linking errors, or validation errors, or deserialization errors, and meant we needed the mark_invalid mechanism for handling modules with outdated global variables. Here we add double-serialised global variables, and only deserialize,link&validate the subset required for the function required (and its dependencies). This requires the cache to store a function and global_var dependency graph. Also add utility methods for GlobalVariableDefinition.
1 parent 65b622c commit f0747ba

10 files changed

+270
-167
lines changed

typed_python/compiler/binary_shared_object.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626

2727

2828
class LoadedBinarySharedObject(LoadedModule):
29-
def __init__(self, binarySharedObject, diskPath, functionPointers, globalVariableDefinitions):
30-
super().__init__(functionPointers, globalVariableDefinitions)
29+
def __init__(self, binarySharedObject, diskPath, functionPointers, serializedGlobalVariableDefinitions):
30+
super().__init__(functionPointers, serializedGlobalVariableDefinitions)
3131

3232
self.binarySharedObject = binarySharedObject
3333
self.diskPath = diskPath
@@ -36,30 +36,32 @@ def __init__(self, binarySharedObject, diskPath, functionPointers, globalVariabl
3636
class BinarySharedObject:
3737
"""Models a shared object library (.so) loadable on linux systems."""
3838

39-
def __init__(self, binaryForm, functionTypes, globalVariableDefinitions):
39+
def __init__(self, binaryForm, functionTypes, serializedGlobalVariableDefinitions, globalDependencies):
4040
"""
4141
Args:
42-
binaryForm - a bytes object containing the actual compiled code for the module
43-
globalVariableDefinitions - a map from name to GlobalVariableDefinition
42+
binaryForm: a bytes object containing the actual compiled code for the module
43+
serializedGlobalVariableDefinitions: a map from name to GlobalVariableDefinition
44+
globalDependencies: a dict from function linkname to the list of global variables it depends on
4445
"""
4546
self.binaryForm = binaryForm
4647
self.functionTypes = functionTypes
47-
self.globalVariableDefinitions = globalVariableDefinitions
48+
self.serializedGlobalVariableDefinitions = serializedGlobalVariableDefinitions
49+
self.globalDependencies = globalDependencies
4850
self.hash = sha_hash(binaryForm)
4951

5052
@property
5153
def definedSymbols(self):
5254
return self.functionTypes.keys()
5355

5456
@staticmethod
55-
def fromDisk(path, globalVariableDefinitions, functionNameToType):
57+
def fromDisk(path, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies):
5658
with open(path, "rb") as f:
5759
binaryForm = f.read()
5860

59-
return BinarySharedObject(binaryForm, functionNameToType, globalVariableDefinitions)
61+
return BinarySharedObject(binaryForm, functionNameToType, serializedGlobalVariableDefinitions, globalDependencies)
6062

6163
@staticmethod
62-
def fromModule(module, globalVariableDefinitions, functionNameToType):
64+
def fromModule(module, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies):
6365
target_triple = llvm.get_process_triple()
6466
target = llvm.Target.from_triple(target_triple)
6567
target_machine_shared_object = target.create_target_machine(reloc='pic', codemodel='default')
@@ -80,7 +82,7 @@ def fromModule(module, globalVariableDefinitions, functionNameToType):
8082
)
8183

8284
with open(os.path.join(tf, "module.so"), "rb") as so_file:
83-
return BinarySharedObject(so_file.read(), functionNameToType, globalVariableDefinitions)
85+
return BinarySharedObject(so_file.read(), functionNameToType, serializedGlobalVariableDefinitions, globalDependencies)
8486

8587
def load(self, storageDir):
8688
"""Instantiate this .so in temporary storage and return a dict from symbol -> integer function pointer"""
@@ -127,8 +129,7 @@ def loadFromPath(self, modulePath):
127129
self,
128130
modulePath,
129131
functionPointers,
130-
self.globalVariableDefinitions
132+
self.serializedGlobalVariableDefinitions
131133
)
132-
loadedModule.linkGlobalVariables()
133134

134135
return loadedModule

typed_python/compiler/compiler_cache.py

Lines changed: 106 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
import os
1616
import uuid
1717
import shutil
18-
from typed_python.compiler.loaded_module import LoadedModule
19-
from typed_python.compiler.binary_shared_object import BinarySharedObject
2018

19+
from typing import Optional, List
20+
21+
from typed_python.compiler.binary_shared_object import LoadedBinarySharedObject, BinarySharedObject
22+
from typed_python.compiler.directed_graph import DirectedGraph
23+
from typed_python.compiler.typed_call_target import TypedCallTarget
2124
from typed_python.SerializationContext import SerializationContext
2225
from typed_python import Dict, ListOf
2326

@@ -52,146 +55,173 @@ def __init__(self, cacheDir):
5255

5356
ensureDirExists(cacheDir)
5457

55-
self.loadedModules = Dict(str, LoadedModule)()
58+
self.loadedBinarySharedObjects = Dict(str, LoadedBinarySharedObject)()
5659
self.nameToModuleHash = Dict(str, str)()
5760

58-
self.modulesMarkedValid = set()
59-
self.modulesMarkedInvalid = set()
61+
self.moduleManifestsLoaded = set()
6062

6163
for moduleHash in os.listdir(self.cacheDir):
6264
if len(moduleHash) == 40:
6365
self.loadNameManifestFromStoredModuleByHash(moduleHash)
6466

65-
self.targetsLoaded = {}
67+
# the set of functions with an associated module in loadedBinarySharedObjects
68+
self.targetsLoaded: Dict[str, TypedCallTarget] = {}
6669

67-
def hasSymbol(self, linkName):
68-
return linkName in self.nameToModuleHash
70+
# the set of functions with linked and validated globals (i.e. ready to be run).
71+
self.targetsValidated = set()
6972

70-
def getTarget(self, linkName):
71-
assert self.hasSymbol(linkName)
73+
self.function_dependency_graph = DirectedGraph()
74+
# dict from function linkname to list of global names (should be llvm keys in serialisedGlobalDefinitions)
75+
self.global_dependencies = Dict(str, ListOf(str))()
7276

73-
self.loadForSymbol(linkName)
77+
def hasSymbol(self, linkName: str) -> bool:
78+
"""NB this will return True even if the linkName is ultimately unretrievable."""
79+
return linkName in self.nameToModuleHash
7480

81+
def getTarget(self, linkName: str) -> TypedCallTarget:
82+
if not self.hasSymbol(linkName):
83+
raise ValueError(f'symbol not found for linkName {linkName}')
84+
self.loadForSymbol(linkName)
7585
return self.targetsLoaded[linkName]
7686

77-
def markModuleHashInvalid(self, hashstr):
78-
with open(os.path.join(self.cacheDir, hashstr, "marked_invalid"), "w"):
79-
pass
87+
def dependencies(self, linkName: str) -> Optional[List[str]]:
88+
"""Returns all the function names that `linkName` depends on"""
89+
return list(self.function_dependency_graph.outgoing(linkName))
8090

81-
def loadForSymbol(self, linkName):
91+
def loadForSymbol(self, linkName: str) -> None:
92+
"""Loads the whole module, and any submodules, into LoadedBinarySharedObjects"""
8293
moduleHash = self.nameToModuleHash[linkName]
8394

8495
self.loadModuleByHash(moduleHash)
8596

86-
def loadModuleByHash(self, moduleHash):
97+
if linkName not in self.targetsValidated:
98+
dependantFuncs = self.dependencies(linkName) + [linkName]
99+
globalsToLink = {} # dict from modulehash to list of globals.
100+
for funcName in dependantFuncs:
101+
if funcName not in self.targetsValidated:
102+
funcModuleHash = self.nameToModuleHash[funcName]
103+
# append to the list of globals to link for a given module. TODO: optimise this, don't double-link.
104+
globalsToLink[funcModuleHash] = globalsToLink.get(funcModuleHash, []) + self.global_dependencies.get(funcName, [])
105+
106+
for moduleHash, globs in globalsToLink.items(): # this works because loadModuleByHash loads submodules too.
107+
if globs:
108+
definitionsToLink = {x: self.loadedBinarySharedObjects[moduleHash].serializedGlobalVariableDefinitions[x]
109+
for x in globs
110+
}
111+
self.loadedBinarySharedObjects[moduleHash].linkGlobalVariables(definitionsToLink)
112+
if not self.loadedBinarySharedObjects[moduleHash].validateGlobalVariables(definitionsToLink):
113+
raise RuntimeError('failed to validate globals when loading:', linkName)
114+
115+
self.targetsValidated.update(dependantFuncs)
116+
117+
def loadModuleByHash(self, moduleHash: str) -> None:
87118
"""Load a module by name.
88119
89120
As we load, place all the newly imported typed call targets into
90121
'nameToTypedCallTarget' so that the rest of the system knows what functions
91122
have been uncovered.
92123
"""
93-
if moduleHash in self.loadedModules:
94-
return True
124+
if moduleHash in self.loadedBinarySharedObjects:
125+
return
95126

96127
targetDir = os.path.join(self.cacheDir, moduleHash)
97128

98-
try:
99-
with open(os.path.join(targetDir, "type_manifest.dat"), "rb") as f:
100-
callTargets = SerializationContext().deserialize(f.read())
129+
# TODO (Will) - store these names as module consts, use one .dat only
130+
with open(os.path.join(targetDir, "type_manifest.dat"), "rb") as f:
131+
callTargets = SerializationContext().deserialize(f.read())
101132

102-
with open(os.path.join(targetDir, "globals_manifest.dat"), "rb") as f:
103-
globalVarDefs = SerializationContext().deserialize(f.read())
133+
with open(os.path.join(targetDir, "globals_manifest.dat"), "rb") as f:
134+
serializedGlobalVarDefs = SerializationContext().deserialize(f.read())
104135

105-
with open(os.path.join(targetDir, "native_type_manifest.dat"), "rb") as f:
106-
functionNameToNativeType = SerializationContext().deserialize(f.read())
136+
with open(os.path.join(targetDir, "native_type_manifest.dat"), "rb") as f:
137+
functionNameToNativeType = SerializationContext().deserialize(f.read())
138+
139+
with open(os.path.join(targetDir, "submodules.dat"), "rb") as f:
140+
submodules = SerializationContext().deserialize(f.read(), ListOf(str))
107141

108-
with open(os.path.join(targetDir, "submodules.dat"), "rb") as f:
109-
submodules = SerializationContext().deserialize(f.read(), ListOf(str))
110-
except Exception:
111-
self.markModuleHashInvalid(moduleHash)
112-
return False
142+
with open(os.path.join(targetDir, "function_dependencies.dat"), "rb") as f:
143+
dependency_edgelist = SerializationContext().deserialize(f.read())
113144

114-
if not LoadedModule.validateGlobalVariables(globalVarDefs):
115-
self.markModuleHashInvalid(moduleHash)
116-
return False
145+
with open(os.path.join(targetDir, "global_dependencies.dat"), "rb") as f:
146+
globalDependencies = SerializationContext().deserialize(f.read())
117147

118148
# load the submodules first
119149
for submodule in submodules:
120-
if not self.loadModuleByHash(submodule):
121-
return False
150+
self.loadModuleByHash(submodule)
122151

123152
modulePath = os.path.join(targetDir, "module.so")
124153

125154
loaded = BinarySharedObject.fromDisk(
126155
modulePath,
127-
globalVarDefs,
128-
functionNameToNativeType
156+
serializedGlobalVarDefs,
157+
functionNameToNativeType,
158+
globalDependencies
159+
129160
).loadFromPath(modulePath)
130161

131-
self.loadedModules[moduleHash] = loaded
162+
self.loadedBinarySharedObjects[moduleHash] = loaded
132163

133164
self.targetsLoaded.update(callTargets)
134165

135-
return True
166+
assert not any(key in self.global_dependencies for key in globalDependencies) # should only happen if there's a hash collision.
167+
self.global_dependencies.update(globalDependencies)
168+
169+
# update the cache's dependency graph with our new edges.
170+
for function_name, dependant_function_name in dependency_edgelist:
171+
self.function_dependency_graph.addEdge(source=function_name, dest=dependant_function_name)
136172

137-
def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies):
173+
def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies, dependencyEdgelist):
138174
"""Add new code to the compiler cache.
139175
140176
Args:
141-
binarySharedObject - a BinarySharedObject containing the actual assembler
142-
we've compiled
143-
nameToTypedCallTarget - a dict from linkname to TypedCallTarget telling us
144-
the formal python types for all the objects
145-
linkDependencies - a set of linknames we depend on directly.
177+
binarySharedObject: a BinarySharedObject containing the actual assembler
178+
we've compiled.
179+
nameToTypedCallTarget: a dict from linkname to TypedCallTarget telling us
180+
the formal python types for all the objects.
181+
linkDependencies: a set of linknames we depend on directly.
182+
dependencyEdgelist (list): a list of source, dest pairs giving the set of dependency graph for the
183+
module.
146184
"""
147185
dependentHashes = set()
148186

149187
for name in linkDependencies:
150188
dependentHashes.add(self.nameToModuleHash[name])
151189

152-
path, hashToUse = self.writeModuleToDisk(binarySharedObject, nameToTypedCallTarget, dependentHashes)
190+
path, hashToUse = self.writeModuleToDisk(binarySharedObject, nameToTypedCallTarget, dependentHashes, dependencyEdgelist)
153191

154-
self.loadedModules[hashToUse] = (
192+
self.loadedBinarySharedObjects[hashToUse] = (
155193
binarySharedObject.loadFromPath(os.path.join(path, "module.so"))
156194
)
157195

158196
for n in binarySharedObject.definedSymbols:
159197
self.nameToModuleHash[n] = hashToUse
160198

161-
def loadNameManifestFromStoredModuleByHash(self, moduleHash):
162-
if moduleHash in self.modulesMarkedValid:
163-
return True
164-
165-
targetDir = os.path.join(self.cacheDir, moduleHash)
199+
# link & validate all globals for the new module
200+
self.loadedBinarySharedObjects[hashToUse].linkGlobalVariables()
201+
if not self.loadedBinarySharedObjects[hashToUse].validateGlobalVariables(
202+
self.loadedBinarySharedObjects[hashToUse].serializedGlobalVariableDefinitions):
203+
raise RuntimeError('failed to validate globals in new module:', hashToUse)
166204

167-
# ignore 'marked invalid'
168-
if os.path.exists(os.path.join(targetDir, "marked_invalid")):
169-
# just bail - don't try to read it now
205+
def loadNameManifestFromStoredModuleByHash(self, moduleHash) -> None:
206+
if moduleHash in self.moduleManifestsLoaded:
207+
return
170208

171-
# for the moment, we don't try to clean up the cache, because
172-
# we can't be sure that some process is not still reading the
173-
# old files.
174-
self.modulesMarkedInvalid.add(moduleHash)
175-
return False
209+
targetDir = os.path.join(self.cacheDir, moduleHash)
176210

177211
with open(os.path.join(targetDir, "submodules.dat"), "rb") as f:
178212
submodules = SerializationContext().deserialize(f.read(), ListOf(str))
179213

180214
for subHash in submodules:
181-
if not self.loadNameManifestFromStoredModuleByHash(subHash):
182-
self.markModuleHashInvalid(subHash)
183-
return False
215+
self.loadNameManifestFromStoredModuleByHash(subHash)
184216

185217
with open(os.path.join(targetDir, "name_manifest.dat"), "rb") as f:
186218
self.nameToModuleHash.update(
187219
SerializationContext().deserialize(f.read(), Dict(str, str))
188220
)
189221

190-
self.modulesMarkedValid.add(moduleHash)
191-
192-
return True
222+
self.moduleManifestsLoaded.add(moduleHash)
193223

194-
def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodules):
224+
def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodules, dependencyEdgelist):
195225
"""Write out a disk representation of this module.
196226
197227
This includes writing both the shared object, a manifest of the function names
@@ -244,11 +274,17 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
244274

245275
# write the type manifest
246276
with open(os.path.join(tempTargetDir, "globals_manifest.dat"), "wb") as f:
247-
f.write(SerializationContext().serialize(binarySharedObject.globalVariableDefinitions))
277+
f.write(SerializationContext().serialize(binarySharedObject.serializedGlobalVariableDefinitions))
248278

249279
with open(os.path.join(tempTargetDir, "submodules.dat"), "wb") as f:
250280
f.write(SerializationContext().serialize(ListOf(str)(submodules), ListOf(str)))
251281

282+
with open(os.path.join(tempTargetDir, "function_dependencies.dat"), "wb") as f:
283+
f.write(SerializationContext().serialize(dependencyEdgelist)) # might need a listof
284+
285+
with open(os.path.join(tempTargetDir, "global_dependencies.dat"), "wb") as f:
286+
f.write(SerializationContext().serialize(binarySharedObject.globalDependencies))
287+
252288
try:
253289
os.rename(tempTargetDir, targetDir)
254290
except IOError:
@@ -264,7 +300,7 @@ def function_pointer_by_name(self, linkName):
264300
if moduleHash is None:
265301
raise Exception("Can't find a module for " + linkName)
266302

267-
if moduleHash not in self.loadedModules:
303+
if moduleHash not in self.loadedBinarySharedObjects:
268304
self.loadForSymbol(linkName)
269305

270-
return self.loadedModules[moduleHash].functionPointers[linkName]
306+
return self.loadedBinarySharedObjects[moduleHash].functionPointers[linkName]

0 commit comments

Comments
 (0)