1515import os
1616import uuid
1717import shutil
18- from typed_python .compiler .loaded_module import LoadedModule
19- from typed_python .compiler .binary_shared_object import BinarySharedObject
2018
19+ from typing import Optional , List
20+
21+ from typed_python .compiler .binary_shared_object import LoadedBinarySharedObject , BinarySharedObject
22+ from typed_python .compiler .directed_graph import DirectedGraph
23+ from typed_python .compiler .typed_call_target import TypedCallTarget
2124from typed_python .SerializationContext import SerializationContext
2225from typed_python import Dict , ListOf
2326
@@ -52,146 +55,173 @@ def __init__(self, cacheDir):
5255
5356 ensureDirExists (cacheDir )
5457
55- self .loadedModules = Dict (str , LoadedModule )()
58+ self .loadedBinarySharedObjects = Dict (str , LoadedBinarySharedObject )()
5659 self .nameToModuleHash = Dict (str , str )()
5760
58- self .modulesMarkedValid = set ()
59- self .modulesMarkedInvalid = set ()
61+ self .moduleManifestsLoaded = set ()
6062
6163 for moduleHash in os .listdir (self .cacheDir ):
6264 if len (moduleHash ) == 40 :
6365 self .loadNameManifestFromStoredModuleByHash (moduleHash )
6466
65- self .targetsLoaded = {}
67+ # the set of functions with an associated module in loadedBinarySharedObjects
68+ self .targetsLoaded : Dict [str , TypedCallTarget ] = {}
6669
67- def hasSymbol ( self , linkName ):
68- return linkName in self .nameToModuleHash
70+ # the set of functions with linked and validated globals (i.e. ready to be run).
71+ self .targetsValidated = set ()
6972
70- def getTarget (self , linkName ):
71- assert self .hasSymbol (linkName )
73+ self .function_dependency_graph = DirectedGraph ()
74+ # dict from function linkname to list of global names (should be llvm keys in serialisedGlobalDefinitions)
75+ self .global_dependencies = Dict (str , ListOf (str ))()
7276
73- self .loadForSymbol (linkName )
def hasSymbol(self, linkName: str) -> bool:
    """Report whether `linkName` appears in the name-to-module-hash manifest.

    NB a True result only means the name is listed in `nameToModuleHash`; it
    does not guarantee that the target can ultimately be retrieved.
    """
    knownNames = self.nameToModuleHash
    return linkName in knownNames
7480
def getTarget(self, linkName: str) -> TypedCallTarget:
    """Return the loaded call target registered under `linkName`.

    The module defining the symbol is loaded (via `loadForSymbol`) before the
    target is looked up in `targetsLoaded`.

    Raises:
        ValueError: if `hasSymbol(linkName)` is false.
    """
    if self.hasSymbol(linkName):
        self.loadForSymbol(linkName)
        return self.targetsLoaded[linkName]

    raise ValueError(f'symbol not found for linkName { linkName } ')
7686
77- def markModuleHashInvalid (self , hashstr ) :
78- with open ( os . path . join ( self . cacheDir , hashstr , "marked_invalid" ), "w" ):
79- pass
def dependencies(self, linkName: str) -> List[str]:
    """Return the function names recorded as dependencies of `linkName`.

    The result is whatever `function_dependency_graph.outgoing(linkName)`
    yields, copied into a fresh list so callers can extend it freely.

    Args:
        linkName: link name of the function whose dependencies are wanted.

    Returns:
        A list of link names. This method never returns None; callers
        concatenate the result with other lists.
    """
    # NOTE(review): presumably `outgoing` yields direct successors only - confirm
    # whether the stored edge list is already transitively closed.
    return list(self.function_dependency_graph.outgoing(linkName))
8090
81- def loadForSymbol (self , linkName ):
def loadForSymbol(self, linkName: str) -> None:
    """Load the module defining `linkName` and link/validate the globals it needs.

    The module registered for `linkName` is loaded through `loadModuleByHash`.
    If `linkName` has not been validated yet, the globals listed in
    `global_dependencies` for `linkName` and for every name returned by
    `dependencies(linkName)` are linked and validated, grouped by the module
    that owns each function. All of those names are then added to
    `targetsValidated`.

    Args:
        linkName: link name of the function to make ready.

    Raises:
        RuntimeError: if a module fails to validate the globals just linked.
    """
    self.loadModuleByHash(self.nameToModuleHash[linkName])

    if linkName in self.targetsValidated:
        return

    dependantFuncs = self.dependencies(linkName) + [linkName]

    # module hash -> ordered, de-duplicated global names to link in that module.
    # A dict is used as an ordered set; duplicates would collapse anyway when the
    # definitions dict is built below.
    globalsByModule = {}
    for funcName in dependantFuncs:
        if funcName in self.targetsValidated:
            continue
        perModule = globalsByModule.setdefault(self.nameToModuleHash[funcName], {})
        for globalName in self.global_dependencies.get(funcName, []):
            perModule[globalName] = None

    # Indexing loadedBinarySharedObjects by the dependency's module hash relies on
    # loadModuleByHash having loaded the submodules too.
    for funcModuleHash, globalNames in globalsByModule.items():
        if not globalNames:
            continue

        loaded = self.loadedBinarySharedObjects[funcModuleHash]
        definitionsToLink = {
            name: loaded.serializedGlobalVariableDefinitions[name]
            for name in globalNames
        }
        loaded.linkGlobalVariables(definitionsToLink)
        if not loaded.validateGlobalVariables(definitionsToLink):
            # a single formatted message: passing two args renders as a tuple.
            raise RuntimeError(f'failed to validate globals when loading: {linkName}')

    self.targetsValidated.update(dependantFuncs)
116+
def loadModuleByHash(self, moduleHash: str) -> None:
    """Load the stored module with hash `moduleHash`, and its submodules, from disk.

    The module's manifests are read from `<cacheDir>/<moduleHash>`, every hash
    listed in `submodules.dat` is loaded recursively, and then `module.so` is
    loaded. The loaded object is stored in `loadedBinarySharedObjects`, and the
    module's call targets, global dependencies and function-dependency edges are
    merged into `targetsLoaded`, `global_dependencies` and
    `function_dependency_graph` so the rest of the system knows what functions
    have been uncovered.

    Does nothing if the module is already loaded. Errors raised while reading or
    deserializing a manifest are not caught here.

    Args:
        moduleHash: name of the module's directory inside the cache directory.
    """
    if moduleHash in self.loadedBinarySharedObjects:
        return

    targetDir = os.path.join(self.cacheDir, moduleHash)

    def readManifest(fileName, *serializeType):
        # each manifest is deserialized with its own fresh SerializationContext
        with open(os.path.join(targetDir, fileName), "rb") as f:
            return SerializationContext().deserialize(f.read(), *serializeType)

    # TODO (Will) - store these names as module consts, use one .dat only
    callTargets = readManifest("type_manifest.dat")
    serializedGlobalVarDefs = readManifest("globals_manifest.dat")
    functionNameToNativeType = readManifest("native_type_manifest.dat")
    submodules = readManifest("submodules.dat", ListOf(str))
    dependency_edgelist = readManifest("function_dependencies.dat")
    globalDependencies = readManifest("global_dependencies.dat")

    # load the submodules first
    for submodule in submodules:
        self.loadModuleByHash(submodule)

    # Checked before anything is registered for this module, so a failure does not
    # leave the module half-loaded. Should only happen if there's a hash collision.
    assert not any(key in self.global_dependencies for key in globalDependencies), (
        "global dependencies already registered for a function in module " + moduleHash
    )

    modulePath = os.path.join(targetDir, "module.so")

    loaded = BinarySharedObject.fromDisk(
        modulePath,
        serializedGlobalVarDefs,
        functionNameToNativeType,
        globalDependencies
    ).loadFromPath(modulePath)

    self.loadedBinarySharedObjects[moduleHash] = loaded
    self.targetsLoaded.update(callTargets)
    self.global_dependencies.update(globalDependencies)

    # update the cache's dependency graph with our new edges.
    for function_name, dependant_function_name in dependency_edgelist:
        self.function_dependency_graph.addEdge(source=function_name, dest=dependant_function_name)
136172
137- def addModule (self , binarySharedObject , nameToTypedCallTarget , linkDependencies ):
def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies, dependencyEdgelist):
    """Add new code to the compiler cache.

    The module is written to disk, loaded from the written `module.so`,
    registered under its hash, and all of its globals are linked and validated.

    Args:
        binarySharedObject: a BinarySharedObject containing the actual assembler
            we've compiled.
        nameToTypedCallTarget: a dict from linkname to TypedCallTarget telling us
            the formal python types for all the objects.
        linkDependencies: a set of linknames we depend on directly. Each must
            already have an entry in `nameToModuleHash`.
        dependencyEdgelist (list): a list of (source, dest) pairs giving the
            dependency graph for the module.

    Raises:
        RuntimeError: if the new module's globals fail validation.
    """
    dependentHashes = {self.nameToModuleHash[name] for name in linkDependencies}

    path, hashToUse = self.writeModuleToDisk(
        binarySharedObject, nameToTypedCallTarget, dependentHashes, dependencyEdgelist
    )

    loaded = binarySharedObject.loadFromPath(os.path.join(path, "module.so"))
    self.loadedBinarySharedObjects[hashToUse] = loaded

    for n in binarySharedObject.definedSymbols:
        self.nameToModuleHash[n] = hashToUse

    # link & validate all globals for the new module
    loaded.linkGlobalVariables()
    if not loaded.validateGlobalVariables(loaded.serializedGlobalVariableDefinitions):
        # a single formatted message: passing two args renders as a tuple.
        raise RuntimeError(f'failed to validate globals in new module: {hashToUse}')
166204
167- # ignore 'marked invalid'
168- if os . path . exists ( os . path . join ( targetDir , "marked_invalid" )) :
169- # just bail - don't try to read it now
def loadNameManifestFromStoredModuleByHash(self, moduleHash) -> None:
    """Merge the name manifest of a stored module into `nameToModuleHash`.

    The manifests of the hashes listed in the module's `submodules.dat` are
    merged first (recursively), then the module's own `name_manifest.dat`.
    Each module hash is processed at most once; processed hashes are recorded
    in `moduleManifestsLoaded`.
    """
    alreadyProcessed = moduleHash in self.moduleManifestsLoaded
    if alreadyProcessed:
        return

    moduleDir = os.path.join(self.cacheDir, moduleHash)

    submodulesPath = os.path.join(moduleDir, "submodules.dat")
    with open(submodulesPath, "rb") as submodulesFile:
        childHashes = SerializationContext().deserialize(submodulesFile.read(), ListOf(str))

    # submodule names go in before this module's own names
    for childHash in childHashes:
        self.loadNameManifestFromStoredModuleByHash(childHash)

    manifestPath = os.path.join(moduleDir, "name_manifest.dat")
    with open(manifestPath, "rb") as manifestFile:
        ownNames = SerializationContext().deserialize(manifestFile.read(), Dict(str, str))
    self.nameToModuleHash.update(ownNames)

    self.moduleManifestsLoaded.add(moduleHash)
193223
194- def writeModuleToDisk (self , binarySharedObject , nameToTypedCallTarget , submodules ):
224+ def writeModuleToDisk (self , binarySharedObject , nameToTypedCallTarget , submodules , dependencyEdgelist ):
195225 """Write out a disk representation of this module.
196226
197227 This includes writing both the shared object, a manifest of the function names
@@ -244,11 +274,17 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
244274
245275 # write the type manifest
246276 with open (os .path .join (tempTargetDir , "globals_manifest.dat" ), "wb" ) as f :
247- f .write (SerializationContext ().serialize (binarySharedObject .globalVariableDefinitions ))
277+ f .write (SerializationContext ().serialize (binarySharedObject .serializedGlobalVariableDefinitions ))
248278
249279 with open (os .path .join (tempTargetDir , "submodules.dat" ), "wb" ) as f :
250280 f .write (SerializationContext ().serialize (ListOf (str )(submodules ), ListOf (str )))
251281
282+ with open (os .path .join (tempTargetDir , "function_dependencies.dat" ), "wb" ) as f :
283+ f .write (SerializationContext ().serialize (dependencyEdgelist )) # might need a listof
284+
285+ with open (os .path .join (tempTargetDir , "global_dependencies.dat" ), "wb" ) as f :
286+ f .write (SerializationContext ().serialize (binarySharedObject .globalDependencies ))
287+
252288 try :
253289 os .rename (tempTargetDir , targetDir )
254290 except IOError :
@@ -264,7 +300,7 @@ def function_pointer_by_name(self, linkName):
264300 if moduleHash is None :
265301 raise Exception ("Can't find a module for " + linkName )
266302
267- if moduleHash not in self .loadedModules :
303+ if moduleHash not in self .loadedBinarySharedObjects :
268304 self .loadForSymbol (linkName )
269305
270- return self .loadedModules [moduleHash ].functionPointers [linkName ]
306+ return self .loadedBinarySharedObjects [moduleHash ].functionPointers [linkName ]
0 commit comments