@@ -255,13 +255,14 @@ class CreateCompletionRequest(BaseModel):
255255 )
256256 presence_penalty : Optional [float ] = presence_penalty_field
257257 frequency_penalty : Optional [float ] = frequency_penalty_field
258+ logit_bias : Optional [Dict [str , float ]] = Field (None )
259+ logit_bias_type : Optional [Literal ["input_ids" , "tokens" ]] = Field (None )
258260
259261 # ignored or currently unsupported
260262 model : Optional [str ] = model_field
261263 n : Optional [int ] = 1
262264 logprobs : Optional [int ] = Field (None )
263265 best_of : Optional [int ] = 1
264- logit_bias : Optional [Dict [str , float ]] = Field (None )
265266 user : Optional [str ] = Field (None )
266267
267268 # llama.cpp specific parameters
@@ -280,6 +281,39 @@ class Config:
280281CreateCompletionResponse = create_model_from_typeddict (llama_cpp .Completion )
281282
282283
def make_logit_bias_processor(
    llama: llama_cpp.Llama,
    logit_bias: Dict[str, float],
    logit_bias_type: Optional[Literal["input_ids", "tokens"]],
):
    """Build a logits processor that adds per-token biases to the model's scores.

    Args:
        llama: The model, used only to tokenize when ``logit_bias_type == "tokens"``.
        logit_bias: Mapping of token identifiers to bias values. Keys are either
            stringified token ids (``"input_ids"`` mode, the default) or literal
            token text to be tokenized (``"tokens"`` mode).
        logit_bias_type: Which interpretation to use for the keys; ``None``
            defaults to ``"input_ids"``.

    Returns:
        A callable ``(input_ids, scores) -> new_scores`` suitable for use as a
        logits processor: each vocabulary entry's score has its bias added
        (0.0 for tokens with no bias).
    """
    if logit_bias_type is None:
        logit_bias_type = "input_ids"

    # Normalize both key styles into {vocab token id -> bias}.
    to_bias: Dict[int, float] = {}
    if logit_bias_type == "input_ids":
        to_bias = {int(input_id): score for input_id, score in logit_bias.items()}
    elif logit_bias_type == "tokens":
        for token_text, score in logit_bias.items():
            # A single text key may tokenize to several ids; bias each of them.
            for input_id in llama.tokenize(token_text.encode("utf-8"), add_bos=False):
                to_bias[input_id] = score

    def logit_bias_processor(
        input_ids: List[int],
        scores: List[float],
    ) -> List[float]:
        # `scores` is indexed by vocabulary token id, so enumerate yields
        # (token_id, score) pairs; `input_ids` (the prompt so far) is unused.
        return [score + to_bias.get(token_id, 0.0) for token_id, score in enumerate(scores)]

    return logit_bias_processor
315+
316+
283317@router .post (
284318 "/v1/completions" ,
285319 response_model = CreateCompletionResponse ,
@@ -297,9 +331,16 @@ async def create_completion(
297331 "n" ,
298332 "best_of" ,
299333 "logit_bias" ,
334+ "logit_bias_type" ,
300335 "user" ,
301336 }
302337 kwargs = body .dict (exclude = exclude )
338+
339+ if body .logit_bias is not None :
340+ kwargs ['logits_processor' ] = llama_cpp .LogitsProcessorList ([
341+ make_logit_bias_processor (llama , body .logit_bias , body .logit_bias_type ),
342+ ])
343+
303344 if body .stream :
304345 send_chan , recv_chan = anyio .create_memory_object_stream (10 )
305346
@@ -378,11 +419,12 @@ class CreateChatCompletionRequest(BaseModel):
378419 stream : bool = stream_field
379420 presence_penalty : Optional [float ] = presence_penalty_field
380421 frequency_penalty : Optional [float ] = frequency_penalty_field
422+ logit_bias : Optional [Dict [str , float ]] = Field (None )
423+ logit_bias_type : Optional [Literal ["input_ids" , "tokens" ]] = Field (None )
381424
382425 # ignored or currently unsupported
383426 model : Optional [str ] = model_field
384427 n : Optional [int ] = 1
385- logit_bias : Optional [Dict [str , float ]] = Field (None )
386428 user : Optional [str ] = Field (None )
387429
388430 # llama.cpp specific parameters
@@ -419,9 +461,16 @@ async def create_chat_completion(
419461 exclude = {
420462 "n" ,
421463 "logit_bias" ,
464+ "logit_bias_type" ,
422465 "user" ,
423466 }
424467 kwargs = body .dict (exclude = exclude )
468+
469+ if body .logit_bias is not None :
470+ kwargs ['logits_processor' ] = llama_cpp .LogitsProcessorList ([
471+ make_logit_bias_processor (llama , body .logit_bias , body .logit_bias_type ),
472+ ])
473+
425474 if body .stream :
426475 send_chan , recv_chan = anyio .create_memory_object_stream (10 )
427476
0 commit comments