@@ -136,7 +136,13 @@ def load_shared_library(lib_base_name: str, base_path: pathlib.Path):
136136
137137module_name = "ggml"
138138lib_base_name = pathlib .Path (os .path .abspath (os .path .dirname (__file__ ))) / "lib"
139- lib = load_shared_library (module_name , lib_base_name )
139+ if sys .platform == "win32" :
140+ libs = []
141+ for l in lib_base_name .iterdir ():
142+ if l .suffix == ".dll" :
143+ libs .append (load_shared_library (l .stem , lib_base_name ))
144+ else :
145+ libs = [load_shared_library (module_name , lib_base_name )]
140146
141147
142148#####################################################
@@ -169,30 +175,41 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCDa
169175
170176F = TypeVar ("F" , bound = Callable [..., Any ])
171177
172- def ctypes_function_for_shared_library (lib : ctypes . CDLL ):
178+ def ctypes_function_for_shared_library (libraries ):
173179 def ctypes_function (
174180 name : str , argtypes : List [Any ], restype : Any , enabled : bool = True
175181 ):
176182 def decorator (f : F ) -> F :
177183 if enabled :
178- func = getattr (lib , name )
179- func .argtypes = argtypes
180- func .restype = restype
181- functools .wraps (f )(func )
182- return func
183- else :
184- def f_ (* args : Any , ** kwargs : Any ):
185- raise RuntimeError (
186- f"Function '{ name } ' is not available in the shared library (enabled=False)"
187- )
188- return cast (F , f_ )
184+ for lib in libraries :
185+ try :
186+ func = getattr (lib , name )
187+ func .argtypes = argtypes
188+ func .restype = restype
189+ functools .wraps (f )(func )
190+ return func
191+ except AttributeError :
192+ pass
193+
194+ def f_ (* args : Any , ** kwargs : Any ):
195+ raise RuntimeError (
196+ f"Function '{ name } ' is not available in the shared library (enabled=False)"
197+ )
198+ return cast (F , f_ )
189199
190200 return decorator
191201
192202 return ctypes_function
193203
194204
195- ggml_function = ctypes_function_for_shared_library (lib )
205+ def findattr (libraries , name ):
206+ for lib in libraries :
207+ if hasattr (lib , name ):
208+ return getattr (lib , name )
209+
210+ return None
211+
212+ ggml_function = ctypes_function_for_shared_library (libs )
196213
197214
198215#####################################################
@@ -8912,8 +8929,8 @@ def ggml_backend_cpu_hbm_buffer_type() -> Optional[ggml_backend_buffer_type_t]:
89128929 ...
89138930
89148931
8915- if hasattr ( lib , "ggml_backend_cpu_hbm_buffer_type" ):
8916- ggml_backend_cpu_hbm_buffer_type = lib . ggml_backend_cpu_hbm_buffer_type
8932+ if findattr ( libs , "ggml_backend_cpu_hbm_buffer_type" ):
8933+ ggml_backend_cpu_hbm_buffer_type = findattr ( libs , " ggml_backend_cpu_hbm_buffer_type" )
89178934 ggml_backend_cpu_hbm_buffer_type .argtypes = []
89188935 ggml_backend_cpu_hbm_buffer_type .restype = ggml_backend_buffer_type_t
89198936
@@ -9709,7 +9726,7 @@ def ggml_backend_register(
97099726#####################################################
97109727
97119728
9712- GGML_USE_CUDA = hasattr ( lib , "ggml_backend_cuda_init" )
9729+ GGML_USE_CUDA = findattr ( libs , "ggml_backend_cuda_init" ) is not None
97139730
97149731
97159732GGML_CUDA_MAX_DEVICES = 16
@@ -9863,7 +9880,7 @@ def ggml_backend_cuda_unregister_host_buffer(
98639880#####################################################
98649881
98659882
9866- GGML_USE_METAL = hasattr ( lib , "ggml_backend_metal_init" )
9883+ GGML_USE_METAL = findattr ( libs , "ggml_backend_metal_init" ) is not None
98679884
98689885
98699886# // max memory buffers that can be mapped to the device
@@ -9978,7 +9995,7 @@ def ggml_backend_metal_capture_next_compute(backend: Union[ggml_backend_t, int],
99789995#####################################################
99799996
99809997
9981- GGML_USE_CLBLAST = hasattr ( lib , "ggml_cl_init" )
9998+ GGML_USE_CLBLAST = findattr ( libs , "ggml_cl_init" ) is not None
99829999
998310000
998410001# GGML_API void ggml_cl_init(void);
@@ -10136,7 +10153,7 @@ def ggml_backend_opencl_host_buffer_type() -> (
1013610153# source: src/ggml-vulkan.h
1013710154#####################################################
1013810155
10139- GGML_USE_VULKAN = hasattr ( lib , "ggml_vk_init_cpu_assist" )
10156+ GGML_USE_VULKAN = findattr ( libs , "ggml_vk_init_cpu_assist" ) is not None
1014010157
1014110158# #define GGML_VK_NAME "Vulkan"
1014210159# #define GGML_VK_MAX_DEVICES 16
@@ -10326,7 +10343,7 @@ def ggml_backend_vk_host_buffer_type() -> Optional[ggml_backend_buffer_type_t]:
1032610343#####################################################
1032710344
1032810345
10329- GGML_USE_RPC = hasattr ( lib , "ggml_backend_rpc_init" )
10346+ GGML_USE_RPC = findattr ( libs , "ggml_backend_rpc_init" ) is not None
1033010347
1033110348
1033210349#define GGML_RPC_MAX_SERVERS 16
0 commit comments