@@ -249,13 +249,14 @@ class CreateCompletionRequest(BaseModel):
     )
     presence_penalty: Optional[float] = presence_penalty_field
     frequency_penalty: Optional[float] = frequency_penalty_field
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
 
     # ignored or currently unsupported
     model: Optional[str] = model_field
     n: Optional[int] = 1
     logprobs: Optional[int] = Field(None)
     best_of: Optional[int] = 1
-    logit_bias: Optional[Dict[str, float]] = Field(None)
     user: Optional[str] = Field(None)
 
     # llama.cpp specific parameters
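The hunk above moves `logit_bias` out of the ignored list and pairs it with a new `logit_bias_type` selector. For context, a request body exercising the new fields might look like the following sketch (the token ids are made up for illustration):

```python
# Hypothetical request payload; the token ids are illustrative only.
payload = {
    "prompt": "The capital of France is",
    "max_tokens": 16,
    "logit_bias": {"1234": 10.0, "5678": -100.0},  # token id -> additive logit bias
    "logit_bias_type": "input_ids",                # keys are raw token ids
}
```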
@@ -274,6 +275,39 @@ class Config:
 CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
 
 
+def make_logit_bias_processor(
+    llama: llama_cpp.Llama,
+    logit_bias: Dict[str, float],
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]],
+):
+    if logit_bias_type is None:
+        logit_bias_type = "input_ids"
+
+    to_bias: Dict[int, float] = {}
+    if logit_bias_type == "input_ids":
+        for input_id, score in logit_bias.items():
+            input_id = int(input_id)
+            to_bias[input_id] = score
+
+    elif logit_bias_type == "tokens":
+        for token, score in logit_bias.items():
+            token = token.encode('utf-8')
+            for input_id in llama.tokenize(token, add_bos=False):
+                to_bias[input_id] = score
+
+    def logit_bias_processor(
+        input_ids: List[int],
+        scores: List[float],
+    ) -> List[float]:
+        new_scores = [None] * len(scores)
+        for input_id, score in enumerate(scores):
+            new_scores[input_id] = score + to_bias.get(input_id, 0.0)
+
+        return new_scores
+
+    return logit_bias_processor
+
+
 @router.post(
     "/v1/completions",
     response_model=CreateCompletionResponse,
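`make_logit_bias_processor` resolves the string-keyed bias map into token ids once, up front, and returns a closure that adds the bias to the matching logits on every sampling step (in `"tokens"` mode each key is tokenized first and the bias applied to every resulting id). A minimal sketch of its behavior in `"input_ids"` mode, where the `llama` handle is never consulted:

```python
# Sketch only: in "input_ids" mode the llama handle is unused, so None suffices.
processor = make_logit_bias_processor(None, {"7": -100.0, "42": 5.0}, "input_ids")

scores = [0.0] * 100            # pretend the model has a 100-token vocabulary
biased = processor([], scores)  # the input_ids history argument is ignored here

assert biased[42] == 5.0 and biased[7] == -100.0 and biased[0] == 0.0
```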
@@ -291,9 +325,16 @@ async def create_completion(
         "n",
         "best_of",
         "logit_bias",
+        "logit_bias_type",
         "user",
     }
     kwargs = body.dict(exclude=exclude)
+
+    if body.logit_bias is not None:
+        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
+            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+        ])
+
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
 
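With the processor wired into `kwargs`, the bias applies to both streaming and non-streaming completions. An end-to-end sketch against a locally running server (the URL and token id are assumptions):

```python
import requests

# Assumes the server is running locally on port 8000; token id 198 is illustrative.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "prompt": "Once upon a time",
        "max_tokens": 32,
        "logit_bias": {"198": -100.0},  # strongly suppress one token id
        "logit_bias_type": "input_ids",
    },
)
print(resp.json()["choices"][0]["text"])
```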
@@ -372,11 +413,12 @@ class CreateChatCompletionRequest(BaseModel):
     stream: bool = stream_field
     presence_penalty: Optional[float] = presence_penalty_field
     frequency_penalty: Optional[float] = frequency_penalty_field
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
 
     # ignored or currently unsupported
     model: Optional[str] = model_field
     n: Optional[int] = 1
-    logit_bias: Optional[Dict[str, float]] = Field(None)
     user: Optional[str] = Field(None)
 
     # llama.cpp specific parameters
@@ -413,9 +455,16 @@ async def create_chat_completion(
     exclude = {
         "n",
         "logit_bias",
+        "logit_bias_type",
         "user",
     }
     kwargs = body.dict(exclude=exclude)
+
+    if body.logit_bias is not None:
+        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
+            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+        ])
+
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
 
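The chat endpoint gets the same wiring, so `"tokens"` mode can be used to bias whole strings without knowing the vocabulary: the server tokenizes each key and biases every resulting id. A hedged client-side sketch:

```python
import requests

# Assumes a local server; "tokens" mode lets the server do the tokenization.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Name a color."}],
        "logit_bias": {" blue": 10.0},  # bias every token of the string " blue"
        "logit_bias_type": "tokens",
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```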