@@ -127,7 +127,9 @@ def __init__(
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
         self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
-        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
+        self.eval_logits: Deque[List[llama_cpp.c_float]] = deque(
+            maxlen=n_ctx if logits_all else 1
+        )
 
         self.cache: Optional[LlamaCache] = None
 
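The deque resize above means that when `logits_all` is off, only the logits for the most recently evaluated token are retained. A minimal standalone sketch of that behaviour (values invented for illustration, not part of the commit):

from collections import deque

# With logits_all=False the deque holds a single row, so memory stays
# O(n_vocab) instead of O(n_ctx * n_vocab).
n_ctx, logits_all = 512, False
eval_logits = deque(maxlen=n_ctx if logits_all else 1)
eval_logits.extend([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])  # two toy "rows" of logits
print(len(eval_logits))  # 1 -- the older row was evicted automatically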
@@ -236,17 +238,90 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
             )
             if int(return_code) != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
+            # Save tokens
             self.eval_tokens.extend(batch)
-            if self.params.logits_all:
-                n_vocab = llama_cpp.llama_n_vocab(self.ctx)
-                cols = int(n_vocab)
-                rows = n_tokens
-                logits_view = llama_cpp.llama_get_logits(self.ctx)
-                logits = [
-                    [logits_view[i * cols + j] for j in range(cols)]
-                    for i in range(rows)
-                ]
-                self.eval_logits.extend(logits)
+            # Save logits
+            rows = n_tokens if self.params.logits_all else 1
+            n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+            cols = int(n_vocab)
+            logits_view = llama_cpp.llama_get_logits(self.ctx)
+            logits: List[List[llama_cpp.c_float]] = [
+                [logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
+            ]
+            self.eval_logits.extend(logits)
+
+    def _sample_top_p_top_k(
+        self,
+        last_n_tokens_data,  # type: llama_cpp.Array[llama_cpp.llama_token]
+        last_n_tokens_size: llama_cpp.c_int,
+        top_k: llama_cpp.c_int,
+        top_p: llama_cpp.c_float,
+        temp: llama_cpp.c_float,
+        repeat_penalty: llama_cpp.c_float,
+    ):
+        assert self.ctx is not None
+        assert len(self.eval_logits) > 0
+        n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
+        logits = self.eval_logits[-1]
+        data = (llama_cpp.llama_token_data * n_vocab)(
+            *[
+                llama_cpp.llama_token_data(
+                    id=llama_cpp.llama_token(i),
+                    logit=logits[i],
+                    p=llama_cpp.c_float(0.0),
+                )
+                for i in range(n_vocab)
+            ]
+        )
+        size = llama_cpp.c_size_t(n_vocab)
+        sorted = False
+        candidates = llama_cpp.llama_token_data_array(
+            data=data,
+            size=size,
+            sorted=sorted,
+        )
+        llama_cpp.llama_sample_repetition_penalty(
+            ctx=self.ctx,
+            last_tokens_data=last_n_tokens_data,
+            last_tokens_size=last_n_tokens_size,
+            candidates=llama_cpp.ctypes.pointer(candidates),
+            penalty=repeat_penalty,
+        )
+        if temp == 0.0:
+            return llama_cpp.llama_sample_token_greedy(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+            )
+        else:
+            llama_cpp.llama_sample_top_k(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                k=top_k,
+            )
+            llama_cpp.llama_sample_tail_free(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                z=llama_cpp.c_float(1.0),
+            )
+            llama_cpp.llama_sample_typical(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                p=llama_cpp.c_float(1.0)
+            )
+            llama_cpp.llama_sample_top_p(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                p=top_p,
+            )
+            llama_cpp.llama_sample_temperature(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                temp=temp,
+            )
+            return llama_cpp.llama_sample_token(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+            )
 
     def sample(
         self,
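The new `_sample_top_p_top_k` helper replaces the removed `llama_sample_top_p_top_k` binding with llama.cpp's candidate-array sampling API: apply the repetition penalty, then top-k, tail-free, locally-typical and top-p filtering, rescale by temperature, and draw a token (or take the arg-max when `temp == 0.0`). For reference, here is a plain-Python sketch of just the top-k / top-p / temperature part; it is an approximation for illustration, not the library's implementation, and it omits the repetition-penalty, tail-free and typical steps:

import math
import random

def sample_top_p_top_k(logits, top_k=40, top_p=0.95, temp=0.8, rng=random):
    """Illustrative top-k / nucleus / temperature sampling over a list of logits."""
    if temp == 0.0:
        # Greedy branch, mirroring llama_sample_token_greedy above.
        return max(range(len(logits)), key=lambda i: logits[i])
    # Top-k: keep the k highest-logit candidate ids.
    cand = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    # Softmax over the survivors to get probabilities for the nucleus cut.
    m = max(logits[i] for i in cand)
    exps = [math.exp(logits[i] - m) for i in cand]
    probs = [e / sum(exps) for e in exps]
    # Top-p: keep the smallest prefix whose cumulative mass reaches top_p.
    kept, mass = [], 0.0
    for i, p in zip(cand, probs):
        kept.append(i)
        mass += p
        if mass >= top_p:
            break
    # Temperature: re-softmax the kept logits scaled by 1/temp, then draw one id.
    scaled = [math.exp((logits[i] - logits[kept[0]]) / temp) for i in kept]
    total = sum(scaled)
    return rng.choices(kept, weights=[s / total for s in scaled], k=1)[0]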
@@ -270,8 +345,7 @@ def sample(
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
             0, self.last_n_tokens_size - len(self.eval_tokens)
         ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
-        return llama_cpp.llama_sample_top_p_top_k(
-            ctx=self.ctx,
+        return self._sample_top_p_top_k(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
             ),
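Callers are unaffected by this change: `sample()` keeps its public signature and simply routes through the new internal helper. A hedged usage sketch (the model path is a placeholder and the parameter values are arbitrary, not taken from this commit):

llama = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path
llama.eval(llama.tokenize(b"Q: Name the planets in the solar system. A: "))
token = llama.sample(top_k=40, top_p=0.95, temp=0.8, repeat_penalty=1.1)
print(llama.detokenize([token]).decode("utf-8", errors="ignore"))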
@@ -470,15 +544,15 @@ def _create_completion(
             all_text = self.detokenize(completion_tokens)
 
             # Contains multi-byte UTF8
-            for k,char in enumerate(all_text[-3:]):
+            for k, char in enumerate(all_text[-3:]):
                 k = 3 - k
-                for num,pattern in [(2, 192), (3, 224), (4, 240)]:
+                for num, pattern in [(2, 192), (3, 224), (4, 240)]:
                     # Bitwise AND check
-                    if (num > k and pattern & char == pattern):
+                    if num > k and pattern & char == pattern:
                         multibyte_fix = num - k
 
             # Stop incomplete bytes from passing
-            if (multibyte_fix > 0):
+            if multibyte_fix > 0:
                 multibyte_fix -= 1
                 continue
 
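The `(num, pattern)` pairs above are the UTF-8 lead-byte masks: a byte matching 0b110xxxxx (192) opens a 2-byte sequence, 0b1110xxxx (224) a 3-byte one, and 0b11110xxx (240) a 4-byte one, so if such a lead byte sits closer to the end of the buffer than the sequence is long, decoding has to wait for more tokens. A standalone sketch of the same check (the helper name and example inputs are invented for illustration):

def pending_utf8_bytes(data: bytes) -> int:
    """How many bytes are still missing from a trailing multi-byte UTF-8 sequence."""
    pending = 0
    for k, char in enumerate(data[-3:]):
        k = 3 - k  # distance of this byte from the end of the buffer
        for num, pattern in [(2, 192), (3, 224), (4, 240)]:
            # Lead byte of an incomplete num-byte sequence within the last k bytes?
            if num > k and pattern & char == pattern:
                pending = num - k
    return pending

print(pending_utf8_bytes(b"ab" + "é".encode()[:1]))                 # 1: one byte of a 2-byte char missing
print(pending_utf8_bytes(b"a" + "\N{GRINNING FACE}".encode()[:2]))  # 2: two bytes of a 4-byte char missing
print(pending_utf8_bytes(b"abc"))                                   # 0: pure ASCII, nothing pending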
@@ -531,7 +605,9 @@ def _create_completion(
                 "model": self.model_path,
                 "choices": [
                     {
-                        "text": text[returned_characters:].decode("utf-8", errors="ignore"),
+                        "text": text[returned_characters:].decode(
+                            "utf-8", errors="ignore"
+                        ),
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": finish_reason,
@@ -558,7 +634,8 @@ def _create_completion(
 
             all_tokens = prompt_tokens + completion_tokens
             all_token_strs = [
-                self.detokenize([token]).decode("utf-8", errors="ignore") for token in all_tokens
+                self.detokenize([token]).decode("utf-8", errors="ignore")
+                for token in all_tokens
             ]
             all_logprobs = [
                 [Llama.logit_to_logprob(logit) for logit in row]
@@ -577,7 +654,9 @@ def _create_completion(
                 )
                 token_logprobs.append(sorted_logprobs[int(token)][0])
                 top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8", errors="ignore"): logprob
+                    self.detokenize([llama_cpp.llama_token(i)]).decode(
+                        "utf-8", errors="ignore"
+                    ): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
                 top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
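The dict comprehension above turns each generated position into an OpenAI-style record: the top-`logprobs` candidate strings mapped to their log-probabilities, with the sampled token forced into the mapping afterwards. A toy, self-contained sketch of that bookkeeping (token strings and values are invented; real entries come from the model's logits):

row = {"Hello": -0.1, " world": -2.3, "!": -4.0, "foo": -7.5}
sorted_logprobs = sorted(((lp, tok) for tok, lp in row.items()), reverse=True)
logprobs = 2            # number of alternatives requested by the caller
chosen = "!"            # the token that was actually sampled
top_logprob = {tok: lp for lp, tok in sorted_logprobs[:logprobs]}
top_logprob.update({chosen: row[chosen]})  # always include the sampled token
print(top_logprob)  # {'Hello': -0.1, ' world': -2.3, '!': -4.0}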