@@ -70,6 +70,14 @@ class OpenAIHTTPBackend(Backend):
7070 the values of these keys will be used as the parameters for the respective
7171 endpoint.
7272 If not provided, no extra query parameters are added.
73+ :param extra_body: Body parameters to include in requests to the OpenAI server.
74+ If "chat_completions", "models", or "text_completions" are included as keys,
75+ the values of these keys will be included in the body for the respective
76+ endpoint.
77+ If not provided, no extra body parameters are added.
78+ :param remove_from_body: Names of body parameters to remove from each request
79+ after all other body parameters have been merged; absent names are ignored.
80+ If not provided, no parameters are removed from the body.
7381 """
7482
7583 def __init__ (
@@ -85,6 +93,7 @@ def __init__(
8593 max_output_tokens : Optional [int ] = None ,
8694 extra_query : Optional [dict ] = None ,
8795 extra_body : Optional [dict ] = None ,
96+ remove_from_body : Optional [list [str ]] = None ,
8897 ):
8998 super ().__init__ (type_ = "openai_http" )
9099 self ._target = target or settings .openai .base_url
@@ -122,6 +131,7 @@ def __init__(
122131 )
123132 self .extra_query = extra_query
124133 self .extra_body = extra_body
134+ self .remove_from_body = remove_from_body
125135 self ._async_client : Optional [httpx .AsyncClient ] = None
126136
127137 @property
@@ -253,9 +263,8 @@ async def text_completions( # type: ignore[override]
253263
254264 headers = self ._headers ()
255265 params = self ._params (TEXT_COMPLETIONS )
256- body = self ._body (TEXT_COMPLETIONS )
257266 payload = self ._completions_payload (
258- body = body ,
267+ endpoint_type = TEXT_COMPLETIONS ,
259268 orig_kwargs = kwargs ,
260269 max_output_tokens = output_token_count ,
261270 prompt = prompt ,
@@ -330,12 +339,11 @@ async def chat_completions( # type: ignore[override]
330339 logger .debug ("{} invocation with args: {}" , self .__class__ .__name__ , locals ())
331340 headers = self ._headers ()
332341 params = self ._params (CHAT_COMPLETIONS )
333- body = self ._body (CHAT_COMPLETIONS )
334342 messages = (
335343 content if raw_content else self ._create_chat_messages (content = content )
336344 )
337345 payload = self ._completions_payload (
338- body = body ,
346+ endpoint_type = CHAT_COMPLETIONS ,
339347 orig_kwargs = kwargs ,
340348 max_output_tokens = output_token_count ,
341349 messages = messages ,
@@ -411,7 +419,7 @@ def _params(self, endpoint_type: EndpointType) -> dict[str, str]:
411419
412420 return self .extra_query
413421
414- def _body (self , endpoint_type : EndpointType ) -> dict [str , str ]:
422+ def _extra_body (self , endpoint_type : EndpointType ) -> dict [str , Any ]:
415423 if self .extra_body is None :
416424 return {}
417425
@@ -426,12 +434,12 @@ def _body(self, endpoint_type: EndpointType) -> dict[str, str]:
426434
427435 def _completions_payload (
428436 self ,
429- body : Optional [ dict ] ,
437+ endpoint_type : EndpointType ,
430438 orig_kwargs : Optional [dict ],
431439 max_output_tokens : Optional [int ],
432440 ** kwargs ,
433441 ) -> dict :
434- payload = body or {}
442+ payload = self . _extra_body ( endpoint_type )
435443 payload .update (orig_kwargs or {})
436444 payload .update (kwargs )
437445 payload ["model" ] = self .model
@@ -455,6 +463,10 @@ def _completions_payload(
455463 payload ["stop" ] = None
456464 payload ["ignore_eos" ] = True
457465
466+ if self .remove_from_body :
467+ for key in self .remove_from_body :
468+ payload .pop (key , None )
469+
458470 return payload
459471
460472 @staticmethod
0 commit comments