1717)
1818from guidellm .config import settings
1919
# Public API of this module.
__all__ = [
    "CHAT_COMPLETIONS",
    "CHAT_COMPLETIONS_PATH",
    "MODELS",
    "TEXT_COMPLETIONS",
    "TEXT_COMPLETIONS_PATH",
    "OpenAIHTTPBackend",
]
2128
2229
# REST paths for the OpenAI-compatible completion endpoints.
TEXT_COMPLETIONS_PATH = "/v1/completions"
CHAT_COMPLETIONS_PATH = "/v1/chat/completions"

# Discriminator for per-endpoint extra query parameters (see
# OpenAIHTTPBackend.extra_query): each constant doubles as the key a
# caller may use inside ``extra_query`` to scope params to one endpoint.
EndpointType = Literal["chat_completions", "models", "text_completions"]
CHAT_COMPLETIONS: EndpointType = "chat_completions"
MODELS: EndpointType = "models"
TEXT_COMPLETIONS: EndpointType = "text_completions"
37+
2638
2739@Backend .register ("openai_http" )
2840class OpenAIHTTPBackend (Backend ):
@@ -53,6 +65,11 @@ class OpenAIHTTPBackend(Backend):
5365 If not provided, the default value from settings is used.
5466 :param max_output_tokens: The maximum number of tokens to request for completions.
5567 If not provided, the default maximum tokens provided from settings is used.
68+ :param extra_query: Query parameters to include in requests to the OpenAI server.
69+ If "chat_completions", "models", or "text_completions" are included as keys,
70+ the values of these keys will be used as the parameters for the respective
71+ endpoint.
72+ If not provided, no extra query parameters are added.
5673 """
5774
5875 def __init__ (
@@ -66,6 +83,7 @@ def __init__(
6683 http2 : Optional [bool ] = True ,
6784 follow_redirects : Optional [bool ] = None ,
6885 max_output_tokens : Optional [int ] = None ,
86+ extra_query : Optional [dict ] = None ,
6987 ):
7088 super ().__init__ (type_ = "openai_http" )
7189 self ._target = target or settings .openai .base_url
@@ -101,6 +119,7 @@ def __init__(
101119 if max_output_tokens is not None
102120 else settings .openai .max_output_tokens
103121 )
122+ self .extra_query = extra_query
104123 self ._async_client : Optional [httpx .AsyncClient ] = None
105124
106125 @property
@@ -174,7 +193,10 @@ async def available_models(self) -> list[str]:
174193 """
175194 target = f"{ self .target } /v1/models"
176195 headers = self ._headers ()
177- response = await self ._get_async_client ().get (target , headers = headers )
196+ params = self ._params (MODELS )
197+ response = await self ._get_async_client ().get (
198+ target , headers = headers , params = params
199+ )
178200 response .raise_for_status ()
179201
180202 models = []
@@ -219,6 +241,7 @@ async def text_completions( # type: ignore[override]
219241 )
220242
221243 headers = self ._headers ()
244+ params = self ._params (TEXT_COMPLETIONS )
222245 payload = self ._completions_payload (
223246 orig_kwargs = kwargs ,
224247 max_output_tokens = output_token_count ,
@@ -232,14 +255,16 @@ async def text_completions( # type: ignore[override]
232255 request_prompt_tokens = prompt_token_count ,
233256 request_output_tokens = output_token_count ,
234257 headers = headers ,
258+ params = params ,
235259 payload = payload ,
236260 ):
237261 yield resp
238262 except Exception as ex :
239263 logger .error (
240- "{} request with headers: {} and payload: {} failed: {}" ,
264+ "{} request with headers: {} and params: {} and payload: {} failed: {}" ,
241265 self .__class__ .__name__ ,
242266 headers ,
267+ params ,
243268 payload ,
244269 ex ,
245270 )
@@ -291,6 +316,7 @@ async def chat_completions( # type: ignore[override]
291316 """
292317 logger .debug ("{} invocation with args: {}" , self .__class__ .__name__ , locals ())
293318 headers = self ._headers ()
319+ params = self ._params (CHAT_COMPLETIONS )
294320 messages = (
295321 content if raw_content else self ._create_chat_messages (content = content )
296322 )
@@ -307,14 +333,16 @@ async def chat_completions( # type: ignore[override]
307333 request_prompt_tokens = prompt_token_count ,
308334 request_output_tokens = output_token_count ,
309335 headers = headers ,
336+ params = params ,
310337 payload = payload ,
311338 ):
312339 yield resp
313340 except Exception as ex :
314341 logger .error (
315- "{} request with headers: {} and payload: {} failed: {}" ,
342+ "{} request with headers: {} and params: {} and payload: {} failed: {}" ,
316343 self .__class__ .__name__ ,
317344 headers ,
345+ params ,
318346 payload ,
319347 ex ,
320348 )
@@ -355,6 +383,19 @@ def _headers(self) -> dict[str, str]:
355383
356384 return headers
357385
386+ def _params (self , endpoint_type : EndpointType ) -> dict [str , str ]:
387+ if self .extra_query is None :
388+ return {}
389+
390+ if (
391+ CHAT_COMPLETIONS in self .extra_query
392+ or MODELS in self .extra_query
393+ or TEXT_COMPLETIONS in self .extra_query
394+ ):
395+ return self .extra_query .get (endpoint_type , {})
396+
397+ return self .extra_query
398+
358399 def _completions_payload (
359400 self , orig_kwargs : Optional [dict ], max_output_tokens : Optional [int ], ** kwargs
360401 ) -> dict :
@@ -451,8 +492,9 @@ async def _iterative_completions_request(
451492 request_id : Optional [str ],
452493 request_prompt_tokens : Optional [int ],
453494 request_output_tokens : Optional [int ],
454- headers : dict ,
455- payload : dict ,
495+ headers : dict [str , str ],
496+ params : dict [str , str ],
497+ payload : dict [str , Any ],
456498 ) -> AsyncGenerator [Union [StreamingTextResponse , ResponseSummary ], None ]:
457499 if type_ == "text_completions" :
458500 target = f"{ self .target } { TEXT_COMPLETIONS_PATH } "
@@ -463,14 +505,16 @@ async def _iterative_completions_request(
463505
464506 logger .info (
465507 "{} making request: {} to target: {} using http2: {} following "
466- "redirects: {} for timeout: {} with headers: {} and payload: {}" ,
508+ "redirects: {} for timeout: {} with headers: {} and params: {} and "
509+ "payload: {}" ,
467510 self .__class__ .__name__ ,
468511 request_id ,
469512 target ,
470513 self .http2 ,
471514 self .follow_redirects ,
472515 self .timeout ,
473516 headers ,
517+ params ,
474518 payload ,
475519 )
476520
@@ -498,7 +542,7 @@ async def _iterative_completions_request(
498542 start_time = time .time ()
499543
500544 async with self ._get_async_client ().stream (
501- "POST" , target , headers = headers , json = payload
545+ "POST" , target , headers = headers , params = params , json = payload
502546 ) as stream :
503547 stream .raise_for_status ()
504548
@@ -542,10 +586,12 @@ async def _iterative_completions_request(
542586 response_output_count = usage ["output" ]
543587
544588 logger .info (
545- "{} request: {} with headers: {} and payload: {} completed with: {}" ,
589+ "{} request: {} with headers: {} and params: {} and payload: {} completed "
590+ "with: {}" ,
546591 self .__class__ .__name__ ,
547592 request_id ,
548593 headers ,
594+ params ,
549595 payload ,
550596 response_value ,
551597 )
@@ -555,6 +601,7 @@ async def _iterative_completions_request(
555601 request_args = RequestArgs (
556602 target = target ,
557603 headers = headers ,
604+ params = params ,
558605 payload = payload ,
559606 timeout = self .timeout ,
560607 http2 = self .http2 ,
0 commit comments