@@ -49,77 +49,88 @@ class CompilerCache:
4949 when we first boot up, which could be slow. We could improve this substantially
5050 by making it possible to determine if a given function is in the cache by organizing
5151 the manifests by, say, function name.
52+
53+ Due to the potential for race conditions, we must distinguish between the following:
54+ func_name - The identifier for the function, based on its identity hash.
55+ link_name - The identifier for the specific realization of that function, which lives in a specific
56+ cache module.
5257 """
5358 def __init__ (self , cacheDir ):
5459 self .cacheDir = cacheDir
5560
5661 ensureDirExists (cacheDir )
5762
5863 self .loadedBinarySharedObjects = Dict (str , LoadedBinarySharedObject )()
59- self .nameToModuleHash = Dict (str , str )()
60-
64+ self .link_name_to_module_hash = Dict (str , str )()
6165 self .moduleManifestsLoaded = set ()
62-
66+ # link_names with an associated module in loadedBinarySharedObjects
67+ self .targetsLoaded : Dict [str , TypedCallTarget ] = {}
68+ # the set of link_names for functions with linked and validated globals (i.e. ready to be run).
69+ self .targetsValidated = set ()
70+ # link_name -> link_name
71+ self .function_dependency_graph = DirectedGraph ()
72+ # dict from link_name to list of global names (should be llvm keys in serialisedGlobalDefinitions)
73+ self .global_dependencies = Dict (str , ListOf (str ))()
74+ self .func_name_to_link_names = Dict (str , ListOf (str ))()
6375 for moduleHash in os .listdir (self .cacheDir ):
6476 if len (moduleHash ) == 40 :
6577 self .loadNameManifestFromStoredModuleByHash (moduleHash )
6678
67- # the set of functions with an associated module in loadedBinarySharedObjects
68- self . targetsLoaded : Dict [ str , TypedCallTarget ] = {}
79+ def hasSymbol ( self , func_name : str ) -> bool :
80+ """Returns true if there are any versions of `func_name` in the cache.
6981
70- # the set of functions with linked and validated globals (i.e. ready to be run).
71- self .targetsValidated = set ()
82+ There may be multiple copies in different modules with different link_names.
83+ """
84+ return any (link_name in self .link_name_to_module_hash for link_name in self .func_name_to_link_names .get (func_name , []))
7285
73- self .function_dependency_graph = DirectedGraph ()
74- # dict from function linkname to list of global names (should be llvm keys in serialisedGlobalDefinitions)
75- self .global_dependencies = Dict (str , ListOf (str ))()
86+ def getTarget (self , func_name : str ) -> TypedCallTarget :
87+ if not self .hasSymbol (func_name ):
88+ raise ValueError (f'symbol not found for func_name { func_name } ' )
89+ link_name = self ._select_link_name (func_name )
90+ self .loadForSymbol (link_name )
91+ return self .targetsLoaded [link_name ]
7692
77- def hasSymbol (self , linkName : str ) -> bool :
78- """NB this will return True even if the linkName is ultimately unretrievable."""
79- return linkName in self .nameToModuleHash
93+ def _generate_link_name (self , func_name : str , module_hash : str ) -> str :
94+ return func_name + "." + module_hash
8095
81- def getTarget (self , linkName : str ) -> TypedCallTarget :
82- if not self .hasSymbol (linkName ):
83- raise ValueError (f'symbol not found for linkName { linkName } ' )
84- self .loadForSymbol (linkName )
85- return self .targetsLoaded [linkName ]
96+ def _select_link_name (self , func_name ) -> str :
97+ """choose a link name for a given func name.
8698
87- def dependencies (self , linkName : str ) -> Optional [List [str ]]:
88- """Returns all the function names that `linkName` depends on"""
89- return list (self .function_dependency_graph .outgoing (linkName ))
99+ Currently we just choose the first available option.
100+ Throws a KeyError if func_name isn't in the cache.
101+ """
102+ link_name_candidates = self .func_name_to_link_names [func_name ]
103+ return link_name_candidates [0 ]
104+
105+ def dependencies (self , link_name : str ) -> Optional [List [str ]]:
106+ """Returns all the function names that `link_name` depends on"""
107+ return list (self .function_dependency_graph .outgoing (link_name ))
90108
91109 def loadForSymbol (self , linkName : str ) -> None :
92- """Loads the whole module, and any submodules , into LoadedBinarySharedObjects"""
93- moduleHash = self .nameToModuleHash [linkName ]
110+ """Loads the whole module, and any dependant modules , into LoadedBinarySharedObjects"""
111+ moduleHash = self .link_name_to_module_hash [linkName ]
94112
95113 self .loadModuleByHash (moduleHash )
96114
97115 if linkName not in self .targetsValidated :
98- dependantFuncs = self .dependencies (linkName ) + [linkName ]
99- globalsToLink = {} # dict from modulehash to list of globals.
100- for funcName in dependantFuncs :
101- if funcName not in self .targetsValidated :
102- funcModuleHash = self .nameToModuleHash [funcName ]
103- # append to the list of globals to link for a given module. TODO: optimise this, don't double-link.
104- globalsToLink [funcModuleHash ] = globalsToLink .get (funcModuleHash , []) + self .global_dependencies .get (funcName , [])
105-
106- for moduleHash , globs in globalsToLink .items (): # this works because loadModuleByHash loads submodules too.
107- if globs :
108- definitionsToLink = {x : self .loadedBinarySharedObjects [moduleHash ].serializedGlobalVariableDefinitions [x ]
109- for x in globs
110- }
111- self .loadedBinarySharedObjects [moduleHash ].linkGlobalVariables (definitionsToLink )
112- if not self .loadedBinarySharedObjects [moduleHash ].validateGlobalVariables (definitionsToLink ):
113- raise RuntimeError ('failed to validate globals when loading:' , linkName )
114-
115- self .targetsValidated .update (dependantFuncs )
116+ self .targetsValidated .add (linkName )
117+ for dependant_func in self .dependencies (linkName ):
118+ self .loadForSymbol (dependant_func )
119+
120+ globalsToLink = self .global_dependencies .get (linkName , [])
121+ if globalsToLink :
122+ definitionsToLink = {x : self .loadedBinarySharedObjects [moduleHash ].serializedGlobalVariableDefinitions [x ]
123+ for x in globalsToLink
124+ }
125+ self .loadedBinarySharedObjects [moduleHash ].linkGlobalVariables (definitionsToLink )
126+ if not self .loadedBinarySharedObjects [moduleHash ].validateGlobalVariables (definitionsToLink ):
127+ raise RuntimeError ('failed to validate globals when loading:' , linkName )
116128
117129 def loadModuleByHash (self , moduleHash : str ) -> None :
118130 """Load a module by name.
119131
120- As we load, place all the newly imported typed call targets into
121- 'nameToTypedCallTarget' so that the rest of the system knows what functions
122- have been uncovered.
132+ Add the module contents to targetsLoaded, generate a LoadedBinarySharedObject,
133+ and update the function and global dependency graphs.
123134 """
124135 if moduleHash in self .loadedBinarySharedObjects :
125136 return
@@ -128,6 +139,7 @@ def loadModuleByHash(self, moduleHash: str) -> None:
128139
129140 # TODO (Will) - store these names as module consts, use one .dat only
130141 with open (os .path .join (targetDir , "type_manifest.dat" ), "rb" ) as f :
142+ # func_name -> typedcalltarget
131143 callTargets = SerializationContext ().deserialize (f .read ())
132144
133145 with open (os .path .join (targetDir , "globals_manifest.dat" ), "rb" ) as f :
@@ -156,45 +168,68 @@ def loadModuleByHash(self, moduleHash: str) -> None:
156168 serializedGlobalVarDefs ,
157169 functionNameToNativeType ,
158170 globalDependencies
159-
160171 ).loadFromPath (modulePath )
161172
162173 self .loadedBinarySharedObjects [moduleHash ] = loaded
163174
164- self .targetsLoaded .update (callTargets )
175+ for func_name , callTarget in callTargets .items ():
176+ link_name = self ._generate_link_name (func_name , moduleHash )
177+ assert link_name not in self .targetsLoaded
178+ self .targetsLoaded [link_name ] = callTarget
165179
166- assert not any (key in self .global_dependencies for key in globalDependencies ) # should only happen if there's a hash collision.
167- self .global_dependencies .update (globalDependencies )
180+ link_name_global_dependencies = {self ._generate_link_name (x , moduleHash ): y for x , y in globalDependencies .items ()}
168181
182+ assert not any (key in self .global_dependencies for key in link_name_global_dependencies )
183+
184+ self .global_dependencies .update (link_name_global_dependencies )
169185 # update the cache's dependency graph with our new edges.
170186 for function_name , dependant_function_name in dependency_edgelist :
171187 self .function_dependency_graph .addEdge (source = function_name , dest = dependant_function_name )
172188
173189 def addModule (self , binarySharedObject , nameToTypedCallTarget , linkDependencies , dependencyEdgelist ):
174190 """Add new code to the compiler cache.
175191
192+
176193 Args:
177194 binarySharedObject: a BinarySharedObject containing the actual assembler
178195 we've compiled.
179- nameToTypedCallTarget: a dict from linkname to TypedCallTarget telling us
196+ nameToTypedCallTarget: a dict from func_name to TypedCallTarget telling us
180197 the formal python types for all the objects.
181- linkDependencies: a set of linknames we depend on directly.
198+ linkDependencies: a set of func_names we depend on directly. (this becomes submodules)
182199 dependencyEdgelist (list): a list of source, dest pairs giving the set of dependency graph for the
183200 module.
201+
202+ TODO (Will): the notion of submodules/linkDependencies can be refactored out.
184203 """
185- dependentHashes = set ()
186204
205+ hashToUse = SerializationContext ().sha_hash (str (uuid .uuid4 ())).hexdigest
206+
207+ # the linkDependencies and dependencyEdgelist are in terms of func_name.
208+ dependentHashes = set ()
187209 for name in linkDependencies :
188- dependentHashes .add (self .nameToModuleHash [name ])
210+ link_name = self ._select_link_name (name )
211+ dependentHashes .add (self .link_name_to_module_hash [link_name ])
212+
213+ link_name_dependency_edgelist = []
214+ for source , dest in dependencyEdgelist :
215+ assert source in binarySharedObject .definedSymbols
216+ source_link_name = self ._generate_link_name (source , hashToUse )
217+ if dest in binarySharedObject .definedSymbols :
218+ dest_link_name = self ._generate_link_name (dest , hashToUse )
219+ else :
220+ dest_link_name = self ._select_link_name (dest )
221+ link_name_dependency_edgelist .append ([source_link_name , dest_link_name ])
189222
190- path , hashToUse = self .writeModuleToDisk (binarySharedObject , nameToTypedCallTarget , dependentHashes , dependencyEdgelist )
223+ path = self .writeModuleToDisk (binarySharedObject , hashToUse , nameToTypedCallTarget , dependentHashes , link_name_dependency_edgelist )
191224
192225 self .loadedBinarySharedObjects [hashToUse ] = (
193226 binarySharedObject .loadFromPath (os .path .join (path , "module.so" ))
194227 )
195228
196- for n in binarySharedObject .definedSymbols :
197- self .nameToModuleHash [n ] = hashToUse
229+ for func_name in binarySharedObject .definedSymbols :
230+ link_name = self ._generate_link_name (func_name , hashToUse )
231+ self .link_name_to_module_hash [link_name ] = hashToUse
232+ self .func_name_to_link_names .setdefault (func_name , []).append (link_name )
198233
199234 # link & validate all globals for the new module
200235 self .loadedBinarySharedObjects [hashToUse ].linkGlobalVariables ()
@@ -208,20 +243,18 @@ def loadNameManifestFromStoredModuleByHash(self, moduleHash) -> None:
208243
209244 targetDir = os .path .join (self .cacheDir , moduleHash )
210245
211- with open (os .path .join (targetDir , "submodules.dat" ), "rb" ) as f :
212- submodules = SerializationContext ().deserialize (f .read (), ListOf (str ))
213-
214- for subHash in submodules :
215- self .loadNameManifestFromStoredModuleByHash (subHash )
216-
246+ # TODO (Will) the name_manifest module_hash is the same throughout so this doesn't need to be a dict.
217247 with open (os .path .join (targetDir , "name_manifest.dat" ), "rb" ) as f :
218- self .nameToModuleHash .update (
219- SerializationContext ().deserialize (f .read (), Dict (str , str ))
220- )
248+ func_name_to_module_hash = SerializationContext ().deserialize (f .read (), Dict (str , str ))
249+
250+ for func_name , module_hash in func_name_to_module_hash .items ():
251+ link_name = self ._generate_link_name (func_name , module_hash )
252+ self .func_name_to_link_names .setdefault (func_name , []).append (link_name )
253+ self .link_name_to_module_hash [link_name ] = module_hash
221254
222255 self .moduleManifestsLoaded .add (moduleHash )
223256
224- def writeModuleToDisk (self , binarySharedObject , nameToTypedCallTarget , submodules , dependencyEdgelist ):
257+ def writeModuleToDisk (self , binarySharedObject , hashToUse , nameToTypedCallTarget , submodules , dependencyEdgelist ):
225258 """Write out a disk representation of this module.
226259
227260 This includes writing both the shared object, a manifest of the function names
@@ -235,7 +268,6 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
235268 to interact with the compiler cache simultaneously without relying on
236269 individual file-level locking.
237270 """
238- hashToUse = SerializationContext ().sha_hash (str (uuid .uuid4 ())).hexdigest
239271
240272 targetDir = os .path .join (
241273 self .cacheDir ,
@@ -264,23 +296,20 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
264296 for sourceName in manifest :
265297 f .write (sourceName + "\n " )
266298
267- # write the type manifest
268299 with open (os .path .join (tempTargetDir , "type_manifest.dat" ), "wb" ) as f :
269300 f .write (SerializationContext ().serialize (nameToTypedCallTarget ))
270301
271- # write the nativetype manifest
272302 with open (os .path .join (tempTargetDir , "native_type_manifest.dat" ), "wb" ) as f :
273303 f .write (SerializationContext ().serialize (binarySharedObject .functionTypes ))
274304
275- # write the type manifest
276305 with open (os .path .join (tempTargetDir , "globals_manifest.dat" ), "wb" ) as f :
277306 f .write (SerializationContext ().serialize (binarySharedObject .serializedGlobalVariableDefinitions ))
278307
279308 with open (os .path .join (tempTargetDir , "submodules.dat" ), "wb" ) as f :
280309 f .write (SerializationContext ().serialize (ListOf (str )(submodules ), ListOf (str )))
281310
282311 with open (os .path .join (tempTargetDir , "function_dependencies.dat" ), "wb" ) as f :
283- f .write (SerializationContext ().serialize (dependencyEdgelist )) # might need a listof
312+ f .write (SerializationContext ().serialize (dependencyEdgelist ))
284313
285314 with open (os .path .join (tempTargetDir , "global_dependencies.dat" ), "wb" ) as f :
286315 f .write (SerializationContext ().serialize (binarySharedObject .globalDependencies ))
@@ -293,14 +322,15 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
293322 else :
294323 shutil .rmtree (tempTargetDir )
295324
296- return targetDir , hashToUse
325+ return targetDir
297326
298- def function_pointer_by_name (self , linkName ):
299- moduleHash = self .nameToModuleHash .get (linkName )
327+ def function_pointer_by_name (self , func_name ):
328+ linkName = self ._select_link_name (func_name )
329+ moduleHash = self .link_name_to_module_hash .get (linkName )
300330 if moduleHash is None :
301331 raise Exception ("Can't find a module for " + linkName )
302332
303333 if moduleHash not in self .loadedBinarySharedObjects :
304334 self .loadForSymbol (linkName )
305335
306- return self .loadedBinarySharedObjects [moduleHash ].functionPointers [linkName ]
336+ return self .loadedBinarySharedObjects [moduleHash ].functionPointers [func_name ]
0 commit comments