66from http import HTTPStatus
77
88import annotated_types
9- import httpx
10- import httpx_sse
9+ import niquests
1110import pydantic
1211import typing_extensions
1312
2019 OutOfTokensOrSymbolsError ,
2120 UserMessage ,
2221)
23- from any_llm_client .http import get_http_client_from_kwargs , make_http_request , make_streaming_http_request
22+ from any_llm_client .http import HttpClient , HttpStatusError
2423from any_llm_client .retry import RequestRetryConfig
24+ from any_llm_client .sse import parse_sse_events
2525
2626
2727OPENAI_AUTH_TOKEN_ENV_NAME : typing .Final = "ANY_LLM_CLIENT_OPENAI_AUTH_TOKEN"
@@ -99,31 +99,34 @@ def _make_user_assistant_alternate_messages(
9999 yield ChatCompletionsMessage (role = current_message_role , content = "\n \n " .join (current_message_content_chunks ))
100100
101101
def _handle_status_error(error: HttpStatusError) -> typing.NoReturn:
    """Translate an HTTP status failure into the appropriate LLM error.

    vLLM reports context-window overflow as a 400 whose body mentions
    message length; that case is mapped to ``OutOfTokensOrSymbolsError``,
    and every other status failure becomes a generic ``LLMError``.
    """
    is_out_of_tokens = (
        error.status_code == HTTPStatus.BAD_REQUEST
        and b"Please reduce the length of the messages" in error.content
    )  # vLLM-specific overflow marker
    if is_out_of_tokens:
        raise OutOfTokensOrSymbolsError(response_content=error.content)
    raise LLMError(response_content=error.content)
106108
107109
108110@dataclasses .dataclass (slots = True , init = False )
109111class OpenAIClient (LLMClient ):
110112 config : OpenAIConfig
111- httpx_client : httpx . AsyncClient
113+ http_client : HttpClient
112114 request_retry : RequestRetryConfig
113115
114116 def __init__ (
115117 self ,
116118 config : OpenAIConfig ,
117119 * ,
118120 request_retry : RequestRetryConfig | None = None ,
119- ** httpx_kwargs : typing .Any , # noqa: ANN401
121+ ** niquests_kwargs : typing .Any , # noqa: ANN401
120122 ) -> None :
121123 self .config = config
122- self .request_retry = request_retry or RequestRetryConfig ()
123- self .httpx_client = get_http_client_from_kwargs (httpx_kwargs )
124+ self .http_client = HttpClient (
125+ request_retry = request_retry or RequestRetryConfig (), niquests_kwargs = niquests_kwargs
126+ )
124127
125- def _build_request (self , payload : dict [str , typing .Any ]) -> httpx .Request :
126- return self . httpx_client . build_request (
128+ def _build_request (self , payload : dict [str , typing .Any ]) -> niquests .Request :
129+ return niquests . Request (
127130 method = "POST" ,
128131 url = str (self .config .url ),
129132 json = payload ,
@@ -152,24 +155,17 @@ async def request_llm_message(
152155 ** extra or {},
153156 ).model_dump (mode = "json" )
154157 try :
155- response : typing .Final = await make_http_request (
156- httpx_client = self .httpx_client ,
157- request_retry = self .request_retry ,
158- build_request = lambda : self ._build_request (payload ),
159- )
160- except httpx .HTTPStatusError as exception :
161- _handle_status_error (status_code = exception .response .status_code , content = exception .response .content )
162- try :
163- return ChatCompletionsNotStreamingResponse .model_validate_json (response .content ).choices [0 ].message .content
164- finally :
165- await response .aclose ()
158+ response : typing .Final = await self .http_client .request (self ._build_request (payload ))
159+ except HttpStatusError as exception :
160+ _handle_status_error (exception )
161+ return ChatCompletionsNotStreamingResponse .model_validate_json (response ).choices [0 ].message .content
166162
167- async def _iter_partial_responses (self , response : httpx . Response ) -> typing .AsyncIterable [str ]:
163+ async def _iter_partial_responses (self , response : typing . AsyncIterable [ bytes ] ) -> typing .AsyncIterable [str ]:
168164 text_chunks : typing .Final = []
169- async for event in httpx_sse . EventSource (response ). aiter_sse ( ):
170- if event .data == "[DONE]" :
165+ async for one_event in parse_sse_events (response ):
166+ if one_event .data == "[DONE]" :
171167 break
172- validated_response = ChatCompletionsStreamingEvent .model_validate_json (event .data )
168+ validated_response = ChatCompletionsStreamingEvent .model_validate_json (one_event .data )
173169 if not (one_chunk := validated_response .choices [0 ].delta .content ):
174170 continue
175171 text_chunks .append (one_chunk )
@@ -187,19 +183,13 @@ async def stream_llm_partial_messages(
187183 ** extra or {},
188184 ).model_dump (mode = "json" )
189185 try :
190- async with make_streaming_http_request (
191- httpx_client = self .httpx_client ,
192- request_retry = self .request_retry ,
193- build_request = lambda : self ._build_request (payload ),
194- ) as response :
186+ async with self .http_client .stream (request = self ._build_request (payload )) as response :
195187 yield self ._iter_partial_responses (response )
196- except httpx .HTTPStatusError as exception :
197- content : typing .Final = await exception .response .aread ()
198- await exception .response .aclose ()
199- _handle_status_error (status_code = exception .response .status_code , content = content )
188+ except HttpStatusError as exception :
189+ _handle_status_error (exception )
200190
201191 async def __aenter__ (self ) -> typing_extensions .Self :
202- await self .httpx_client .__aenter__ ()
192+ await self .http_client .__aenter__ ()
203193 return self
204194
205195 async def __aexit__ (
@@ -208,4 +198,4 @@ async def __aexit__(
208198 exc_value : BaseException | None ,
209199 traceback : types .TracebackType | None ,
210200 ) -> None :
211- await self .httpx_client .__aexit__ (exc_type = exc_type , exc_value = exc_value , traceback = traceback )
201+ await self .http_client .__aexit__ (exc_type = exc_type , exc_value = exc_value , traceback = traceback )