88
99
1010@pytest .fixture
11- def litellm_client_cls ():
12- with unittest .mock .patch .object (strands .models .litellm .litellm , "LiteLLM " ) as mock_client_cls :
13- yield mock_client_cls
11+ def litellm_acompletion ():
12+ with unittest .mock .patch .object (strands .models .litellm .litellm , "acompletion " ) as mock_acompletion :
13+ yield mock_acompletion
1414
1515
1616@pytest .fixture
17- def litellm_client ( litellm_client_cls ):
18- return litellm_client_cls . return_value
17+ def api_key ( ):
18+ return "a1"
1919
2020
2121@pytest .fixture
@@ -24,10 +24,10 @@ def model_id():
2424
2525
2626@pytest .fixture
27- def model (litellm_client , model_id ):
28- _ = litellm_client
27+ def model (litellm_acompletion , api_key , model_id ):
28+ _ = litellm_acompletion
2929
30- return LiteLLMModel (model_id = model_id )
30+ return LiteLLMModel (client_args = { "api_key" : api_key }, model_id = model_id )
3131
3232
3333@pytest .fixture
@@ -49,17 +49,6 @@ class TestOutputModel(pydantic.BaseModel):
4949 return TestOutputModel
5050
5151
52- def test__init__ (litellm_client_cls , model_id ):
53- model = LiteLLMModel ({"api_key" : "k1" }, model_id = model_id , params = {"max_tokens" : 1 })
54-
55- tru_config = model .get_config ()
56- exp_config = {"model_id" : "m1" , "params" : {"max_tokens" : 1 }}
57-
58- assert tru_config == exp_config
59-
60- litellm_client_cls .assert_called_once_with (api_key = "k1" )
61-
62-
6352def test_update_config (model , model_id ):
6453 model .update_config (model_id = model_id )
6554
@@ -116,7 +105,7 @@ def test_format_request_message_content(content, exp_result):
116105
117106
118107@pytest .mark .asyncio
119- async def test_stream (litellm_client , model , alist ):
108+ async def test_stream (litellm_acompletion , api_key , model_id , model , agenerator , alist ):
120109 mock_tool_call_1_part_1 = unittest .mock .Mock (index = 0 )
121110 mock_tool_call_2_part_1 = unittest .mock .Mock (index = 1 )
122111 mock_delta_1 = unittest .mock .Mock (
@@ -148,8 +137,8 @@ async def test_stream(litellm_client, model, alist):
148137 mock_event_5 = unittest .mock .Mock (choices = [unittest .mock .Mock (finish_reason = "tool_calls" , delta = mock_delta_5 )])
149138 mock_event_6 = unittest .mock .Mock ()
150139
151- litellm_client . chat . completions . create . return_value = iter (
152- [mock_event_1 , mock_event_2 , mock_event_3 , mock_event_4 , mock_event_5 , mock_event_6 ]
140+ litellm_acompletion . side_effect = unittest . mock . AsyncMock (
141+ return_value = agenerator ( [mock_event_1 , mock_event_2 , mock_event_3 , mock_event_4 , mock_event_5 , mock_event_6 ])
153142 )
154143
155144 messages = [{"role" : "user" , "content" : [{"type" : "text" , "text" : "calculate 2+2" }]}]
@@ -196,18 +185,20 @@ async def test_stream(litellm_client, model, alist):
196185 ]
197186
198187 assert tru_events == exp_events
188+
199189 expected_request = {
200- "model" : "m1" ,
190+ "api_key" : api_key ,
191+ "model" : model_id ,
201192 "messages" : [{"role" : "user" , "content" : [{"text" : "calculate 2+2" , "type" : "text" }]}],
202193 "stream" : True ,
203194 "stream_options" : {"include_usage" : True },
204195 "tools" : [],
205196 }
206- litellm_client . chat . completions . create .assert_called_once_with (** expected_request )
197+ litellm_acompletion .assert_called_once_with (** expected_request )
207198
208199
209200@pytest .mark .asyncio
210- async def test_structured_output (litellm_client , model , test_output_model_cls , alist ):
201+ async def test_structured_output (litellm_acompletion , model , test_output_model_cls , alist ):
211202 messages = [{"role" : "user" , "content" : [{"text" : "Generate a person" }]}]
212203
213204 mock_choice = unittest .mock .Mock ()
@@ -216,7 +207,7 @@ async def test_structured_output(litellm_client, model, test_output_model_cls, a
216207 mock_response = unittest .mock .Mock ()
217208 mock_response .choices = [mock_choice ]
218209
219- litellm_client . chat . completions . create . return_value = mock_response
210+ litellm_acompletion . side_effect = unittest . mock . AsyncMock ( return_value = mock_response )
220211
221212 with unittest .mock .patch .object (strands .models .litellm , "supports_response_schema" , return_value = True ):
222213 stream = model .structured_output (test_output_model_cls , messages )
0 commit comments