@@ -261,14 +261,16 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
261261 ]
262262 self .eval_logits .extend (logits )
263263
264- def _sample_top_p_top_k (
264+ def _sample (
265265 self ,
266266 last_n_tokens_data , # type: llama_cpp.Array[llama_cpp.llama_token]
267267 last_n_tokens_size : llama_cpp .c_int ,
268268 top_k : llama_cpp .c_int ,
269269 top_p : llama_cpp .c_float ,
270270 temp : llama_cpp .c_float ,
271271 repeat_penalty : llama_cpp .c_float ,
272+ frequency_penalty : llama_cpp .c_float ,
273+ presence_penalty : llama_cpp .c_float ,
272274 ):
273275 assert self .ctx is not None
274276 assert len (self .eval_logits ) > 0
@@ -298,6 +300,14 @@ def _sample_top_p_top_k(
298300 candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
299301 penalty = repeat_penalty ,
300302 )
303+ llama_cpp .llama_sample_frequency_and_presence_penalties (
304+ ctx = self .ctx ,
305+ candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
306+ last_tokens_data = last_n_tokens_data ,
307+ last_tokens_size = last_n_tokens_size ,
308+ alpha_frequency = frequency_penalty ,
309+ alpha_presence = presence_penalty ,
310+ )
301311 if float (temp .value ) == 0.0 :
302312 return llama_cpp .llama_sample_token_greedy (
303313 ctx = self .ctx ,
@@ -344,6 +354,8 @@ def sample(
344354 top_p : float ,
345355 temp : float ,
346356 repeat_penalty : float ,
357+ frequency_penalty : float = 0.0 ,
358+ presence_penalty : float = 0.0 ,
347359 ):
348360 """Sample a token from the model.
349361
@@ -360,7 +372,7 @@ def sample(
360372 last_n_tokens_data = [llama_cpp .llama_token (0 )] * max (
361373 0 , self .last_n_tokens_size - len (self .eval_tokens )
362374 ) + list (self .eval_tokens )[- self .last_n_tokens_size :]
363- return self ._sample_top_p_top_k (
375+ return self ._sample (
364376 last_n_tokens_data = (llama_cpp .llama_token * self .last_n_tokens_size )(
365377 * last_n_tokens_data
366378 ),
@@ -369,6 +381,8 @@ def sample(
369381 top_p = llama_cpp .c_float (top_p ),
370382 temp = llama_cpp .c_float (temp ),
371383 repeat_penalty = llama_cpp .c_float (repeat_penalty ),
384+ frequency_penalty = llama_cpp .c_float (frequency_penalty ),
385+ presence_penalty = llama_cpp .c_float (presence_penalty ),
372386 )
373387
374388 def generate (
@@ -378,6 +392,8 @@ def generate(
378392 top_p : float ,
379393 temp : float ,
380394 repeat_penalty : float ,
395+ frequency_penalty : float = 0.0 ,
396+ presence_penalty : float = 0.0 ,
381397 reset : bool = True ,
382398 ) -> Generator [
383399 llama_cpp .llama_token , Optional [Sequence [llama_cpp .llama_token ]], None
@@ -431,6 +447,8 @@ def generate(
431447 top_k = top_k ,
432448 top_p = top_p ,
433449 temp = temp ,
450+ frequency_penalty = frequency_penalty ,
451+ presence_penalty = presence_penalty ,
434452 repeat_penalty = repeat_penalty ,
435453 )
436454 tokens_or_none = yield token
@@ -505,6 +523,8 @@ def _create_completion(
505523 logprobs : Optional [int ] = None ,
506524 echo : bool = False ,
507525 stop : Optional [List [str ]] = [],
526+ frequency_penalty : float = 0.0 ,
527+ presence_penalty : float = 0.0 ,
508528 repeat_penalty : float = 1.1 ,
509529 top_k : int = 40 ,
510530 stream : bool = False ,
@@ -563,6 +583,8 @@ def _create_completion(
563583 top_k = top_k ,
564584 top_p = top_p ,
565585 temp = temperature ,
586+ frequency_penalty = frequency_penalty ,
587+ presence_penalty = presence_penalty ,
566588 repeat_penalty = repeat_penalty ,
567589 ):
568590 if token == llama_cpp .llama_token_eos ():
@@ -737,6 +759,8 @@ def create_completion(
737759 logprobs : Optional [int ] = None ,
738760 echo : bool = False ,
739761 stop : Optional [List [str ]] = [],
762+ frequency_penalty : float = 0.0 ,
763+ presence_penalty : float = 0.0 ,
740764 repeat_penalty : float = 1.1 ,
741765 top_k : int = 40 ,
742766 stream : bool = False ,
@@ -772,6 +796,8 @@ def create_completion(
772796 logprobs = logprobs ,
773797 echo = echo ,
774798 stop = stop ,
799+ frequency_penalty = frequency_penalty ,
800+ presence_penalty = presence_penalty ,
775801 repeat_penalty = repeat_penalty ,
776802 top_k = top_k ,
777803 stream = stream ,
@@ -792,6 +818,8 @@ def __call__(
792818 logprobs : Optional [int ] = None ,
793819 echo : bool = False ,
794820 stop : Optional [List [str ]] = [],
821+ frequency_penalty : float = 0.0 ,
822+ presence_penalty : float = 0.0 ,
795823 repeat_penalty : float = 1.1 ,
796824 top_k : int = 40 ,
797825 stream : bool = False ,
@@ -827,6 +855,8 @@ def __call__(
827855 logprobs = logprobs ,
828856 echo = echo ,
829857 stop = stop ,
858+ frequency_penalty = frequency_penalty ,
859+ presence_penalty = presence_penalty ,
830860 repeat_penalty = repeat_penalty ,
831861 top_k = top_k ,
832862 stream = stream ,
@@ -899,6 +929,8 @@ def create_chat_completion(
899929 stream : bool = False ,
900930 stop : Optional [List [str ]] = [],
901931 max_tokens : int = 256 ,
932+ presence_penalty : float = 0.0 ,
933+ frequency_penalty : float = 0.0 ,
902934 repeat_penalty : float = 1.1 ,
903935 ) -> Union [ChatCompletion , Iterator [ChatCompletionChunk ]]:
904936 """Generate a chat completion from a list of messages.
@@ -932,6 +964,8 @@ def create_chat_completion(
932964 stream = stream ,
933965 max_tokens = max_tokens ,
934966 repeat_penalty = repeat_penalty ,
967+ presence_penalty = presence_penalty ,
968+ frequency_penalty = frequency_penalty ,
935969 )
936970 if stream :
937971 chunks : Iterator [CompletionChunk ] = completion_or_chunks # type: ignore