-import pytest
 import llama_cpp
 
 MODEL = "./vendor/llama.cpp/models/ggml-vocab.bin"
@@ -15,15 +14,20 @@ def test_llama():
     assert llama.detokenize(llama.tokenize(text)) == text
 
 
-@pytest.mark.skip(reason="need to update sample mocking")
+# @pytest.mark.skip(reason="need to update sample mocking")
 def test_llama_patch(monkeypatch):
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
+    n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
 
     ## Set up mock function
     def mock_eval(*args, **kwargs):
         return 0
+
+    def mock_get_logits(*args, **kwargs):
+        return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
 
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
 
     output_text = " jumps over the lazy dog."
     output_tokens = llama.tokenize(output_text.encode("utf-8"))
@@ -38,7 +42,7 @@ def mock_sample(*args, **kwargs):
         else:
             return token_eos
 
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
 
     text = "The quick brown fox"
 
@@ -97,15 +101,19 @@ def test_llama_pickle():
 
     assert llama.detokenize(llama.tokenize(text)) == text
 
-@pytest.mark.skip(reason="need to update sample mocking")
 def test_utf8(monkeypatch):
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
+    n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
 
     ## Set up mock function
     def mock_eval(*args, **kwargs):
         return 0
 
+    def mock_get_logits(*args, **kwargs):
+        return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
+
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
 
     output_text = "😀"
     output_tokens = llama.tokenize(output_text.encode("utf-8"))
@@ -120,7 +128,7 @@ def mock_sample(*args, **kwargs):
         else:
             return token_eos
 
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
 
     ## Test basic completion with utf8 multibyte
     n = 0  # reset
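
Note on the new mock_get_logits helper: it hands the sampling path a ctypes array of n_vocab C floats so the patched-out native llama_get_logits has something to return. A minimal standalone sketch of that pattern, using plain ctypes rather than the llama_cpp re-exports (the value of n_vocab here is hypothetical, chosen only for illustration):

    import ctypes

    n_vocab = 32000  # hypothetical vocabulary size for illustration

    # (ctypes.c_float * n_vocab) is an array type of n_vocab C floats.
    # Instantiating it with no arguments zero-initializes every element,
    # so the explicit c_float(0) comprehension in the test builds an
    # equivalent, all-zero logits buffer.
    LogitsArray = ctypes.c_float * n_vocab
    logits = LogitsArray()

    assert len(logits) == n_vocab
    assert all(v == 0.0 for v in logits)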
0 commit comments
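The patching itself leans on pytest's monkeypatch fixture resolving a dotted string to the attribute it replaces and restoring the original at test teardown. A small self-contained sketch of that behavior, with math.cos standing in for the native bindings under llama_cpp.llama_cpp:

    import math

    def test_cos_is_stubbed(monkeypatch):
        # Replace math.cos for the duration of this test only;
        # monkeypatch undoes the setattr automatically at teardown.
        monkeypatch.setattr("math.cos", lambda x: 1.0)
        assert math.cos(123.0) == 1.0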