import sys
import uuid
import time
+import json
+import fnmatch
import multiprocessing
from typing import (
    List,
    Callable,
)
from collections import deque
+from pathlib import Path

import ctypes

    LlamaDiskCache,  # type: ignore
    LlamaRAMCache,  # type: ignore
)
-from .llama_tokenizer import (
-    BaseLlamaTokenizer,
-    LlamaTokenizer
-)
+from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer
import llama_cpp.llama_cpp as llama_cpp
import llama_cpp.llama_chat_format as llama_chat_format

    _LlamaSamplingContext,  # type: ignore
)
from ._logger import set_verbose
-from ._utils import (
-    suppress_stdout_stderr
-)
+from ._utils import suppress_stdout_stderr


class Llama:
@@ -189,7 +187,11 @@ def __init__(
        Llama.__backend_initialized = True

        if isinstance(numa, bool):
-            self.numa = llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE if numa else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
+            self.numa = (
+                llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE
+                if numa
+                else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
+            )
        else:
            self.numa = numa

@@ -246,17 +248,17 @@ def __init__(
                else:
                    raise ValueError(f"Unknown value type for {k}: {v}")

-            self._kv_overrides_array[
-                -1
-            ].key = b"\0"  # ensure sentinel element is zeroed
+            self._kv_overrides_array[-1].key = (
+                b"\0"  # ensure sentinel element is zeroed
+            )
            self.model_params.kv_overrides = self._kv_overrides_array

        self.n_batch = min(n_ctx, n_batch)  # ???
        self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
        self.n_threads_batch = n_threads_batch or max(
            multiprocessing.cpu_count() // 2, 1
        )
-
+
        # Context Params
        self.context_params = llama_cpp.llama_context_default_params()
        self.context_params.seed = seed
@@ -289,7 +291,9 @@ def __init__(
        )
        self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
        self.context_params.mul_mat_q = mul_mat_q
-        self.context_params.logits_all = logits_all if draft_model is None else True  # Must be set to True for speculative decoding
+        self.context_params.logits_all = (
+            logits_all if draft_model is None else True
+        )  # Must be set to True for speculative decoding
        self.context_params.embedding = embedding
        self.context_params.offload_kqv = offload_kqv

@@ -379,8 +383,14 @@ def __init__(
        if self.verbose:
            print(f"Model metadata: {self.metadata}", file=sys.stderr)

-        if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata:
-            chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata)
+        if (
+            self.chat_format is None
+            and self.chat_handler is None
+            and "tokenizer.chat_template" in self.metadata
+        ):
+            chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
+                self.metadata
+            )

            if chat_format is not None:
                self.chat_format = chat_format
@@ -406,9 +416,7 @@ def __init__(
                print(f"Using chat bos_token: {bos_token}", file=sys.stderr)

            self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
-                template=template,
-                eos_token=eos_token,
-                bos_token=bos_token
+                template=template, eos_token=eos_token, bos_token=bos_token
            ).to_chat_handler()

        if self.chat_format is None and self.chat_handler is None:
@@ -459,7 +467,9 @@ def tokenize(
        """
        return self.tokenizer_.tokenize(text, add_bos, special)

-    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
+    def detokenize(
+        self, tokens: List[int], prev_tokens: Optional[List[int]] = None
+    ) -> bytes:
        """Detokenize a list of tokens.

        Args:
@@ -565,7 +575,7 @@ def sample(
            logits[:] = (
                logits_processor(self._input_ids, logits)
                if idx is None
-                else logits_processor(self._input_ids[:idx + 1], logits)
+                else logits_processor(self._input_ids[: idx + 1], logits)
            )

        sampling_params = _LlamaSamplingParams(
@@ -707,7 +717,9 @@ def generate(

            if self.draft_model is not None:
                self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
-                draft_tokens = self.draft_model(self.input_ids[:self.n_tokens + len(tokens)])
+                draft_tokens = self.draft_model(
+                    self.input_ids[: self.n_tokens + len(tokens)]
+                )
                tokens.extend(
                    draft_tokens.astype(int)[
                        : self._n_ctx - self.n_tokens - len(tokens)
@@ -792,6 +804,7 @@ def embed(

        # decode and fetch embeddings
        data: List[List[float]] = []
+
        def decode_batch(n_seq: int):
            assert self._ctx.ctx is not None
            llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
@@ -800,9 +813,9 @@ def decode_batch(n_seq: int):

            # store embeddings
            for i in range(n_seq):
-                embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
-                    :n_embd
-                ]
+                embedding: List[float] = llama_cpp.llama_get_embeddings_ith(
+                    self._ctx.ctx, i
+                )[:n_embd]
                if normalize:
                    norm = float(np.linalg.norm(embedding))
                    embedding = [v / norm for v in embedding]
@@ -1669,12 +1682,13 @@ def create_chat_completion_openai_v1(
        """
        try:
            from openai.types.chat import ChatCompletion, ChatCompletionChunk
-            stream = kwargs.get("stream", False) # type: ignore
+
+            stream = kwargs.get("stream", False)  # type: ignore
            assert isinstance(stream, bool)
            if stream:
-                return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore
+                return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs))  # type: ignore
            else:
-                return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore
+                return ChatCompletion(**self.create_chat_completion(*args, **kwargs))  # type: ignore
        except ImportError:
            raise ImportError(
                "To use create_chat_completion_openai_v1, you must install the openai package."
@@ -1866,7 +1880,88 @@ def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
                break
        return longest_prefix

+    @classmethod
+    def from_pretrained(
+        cls,
+        repo_id: str,
+        filename: Optional[str],
+        local_dir: Optional[Union[str, os.PathLike[str]]] = ".",
+        local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
+        **kwargs: Any,
+    ) -> "Llama":
+        """Create a Llama model from a pretrained model name or path.
+        This method requires the huggingface-hub package.
+        You can install it with `pip install huggingface-hub`.
+
+        Args:
+            repo_id: The model repo id.
+            filename: A filename or glob pattern to match the model file in the repo.
+            local_dir: The local directory to save the model to.
+            local_dir_use_symlinks: Whether to use symlinks when downloading the model.
+            **kwargs: Additional keyword arguments to pass to the Llama constructor.
+
+        Returns:
+            A Llama model."""
+        try:
+            from huggingface_hub import hf_hub_download, HfFileSystem
+            from huggingface_hub.utils import validate_repo_id
+        except ImportError:
+            raise ImportError(
+                "Llama.from_pretrained requires the huggingface-hub package. "
+                "You can install it with `pip install huggingface-hub`."
+            )
+
+        validate_repo_id(repo_id)
+
+        hffs = HfFileSystem()
+
+        files = [
+            file["name"] if isinstance(file, dict) else file
+            for file in hffs.ls(repo_id)
+        ]
+
+        # split each file into repo_id, subfolder, filename
+        file_list: List[str] = []
+        for file in files:
+            rel_path = Path(file).relative_to(repo_id)
+            file_list.append(str(rel_path))

+        matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)]  # type: ignore
+
+        if len(matching_files) == 0:
+            raise ValueError(
+                f"No file found in {repo_id} that match {filename}\n\n"
+                f"Available Files:\n{json.dumps(file_list)}"
+            )
+
+        if len(matching_files) > 1:
+            raise ValueError(
+                f"Multiple files found in {repo_id} matching {filename}\n\n"
+                f"Available Files:\n{json.dumps(files)}"
+            )
+
+        (matching_file,) = matching_files
+
+        subfolder = str(Path(matching_file).parent)
+        filename = Path(matching_file).name
+
+        local_dir = "."
+
+        # download the file
+        hf_hub_download(
+            repo_id=repo_id,
+            local_dir=local_dir,
+            filename=filename,
+            subfolder=subfolder,
+            local_dir_use_symlinks=local_dir_use_symlinks,
+        )
+
+        model_path = os.path.join(local_dir, filename)
+
+        return cls(
+            model_path=model_path,
+            **kwargs,
+        )


class LlamaState:
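
For context, the new Llama.from_pretrained classmethod introduced in this diff resolves filename as an fnmatch-style glob against the repository's file listing, downloads the single matching GGUF file with hf_hub_download, and forwards any extra keyword arguments to the Llama constructor. A minimal usage sketch follows; the repo id and filename pattern are hypothetical placeholders, not part of this commit.

    from llama_cpp import Llama

    # Hypothetical repo id and quantization pattern, shown only to illustrate the call shape.
    llm = Llama.from_pretrained(
        repo_id="someuser/some-model-GGUF",  # hypothetical Hugging Face repo
        filename="*q8_0.gguf",  # glob matched against the repo's file list via fnmatch
        verbose=False,  # forwarded to the Llama constructor through **kwargs
    )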