|
1 | 1 | import unittest.mock |
| 2 | +from unittest.mock import MagicMock |
2 | 3 |
|
3 | 4 | import pytest |
4 | 5 |
|
@@ -39,7 +40,6 @@ async def test_executor_stream_yields_result( |
39 | 40 |
|
40 | 41 | tru_events = await alist(stream) |
41 | 42 | exp_events = [ |
42 | | - ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), |
43 | 43 | ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), |
44 | 44 | ] |
45 | 45 | assert tru_events == exp_events |
@@ -67,6 +67,76 @@ async def test_executor_stream_yields_result( |
67 | 67 | assert tru_hook_events == exp_hook_events |
68 | 68 |
|
69 | 69 |
|
| 70 | +@pytest.mark.asyncio |
| 71 | +async def test_executor_stream_wraps_results( |
| 72 | + executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator |
| 73 | +): |
| 74 | + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} |
| 75 | + stream = executor._stream(agent, tool_use, tool_results, invocation_state) |
| 76 | + |
| 77 | + weather_tool.stream = MagicMock() |
| 78 | + weather_tool.stream.return_value = agenerator( |
| 79 | + ["value 1", {"nested": True}, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}] |
| 80 | + ) |
| 81 | + |
| 82 | + tru_events = await alist(stream) |
| 83 | + exp_events = [ |
| 84 | + ToolStreamEvent(tool_use, "value 1"), |
| 85 | + ToolStreamEvent(tool_use, {"nested": True}), |
| 86 | + ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), |
| 87 | + ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), |
| 88 | + ] |
| 89 | + assert tru_events == exp_events |
| 90 | + |
| 91 | + |
| 92 | +@pytest.mark.asyncio |
| 93 | +async def test_executor_stream_passes_through_typed_events( |
| 94 | + executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator |
| 95 | +): |
| 96 | + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} |
| 97 | + stream = executor._stream(agent, tool_use, tool_results, invocation_state) |
| 98 | + |
| 99 | + weather_tool.stream = MagicMock() |
| 100 | + event_1 = ToolStreamEvent(tool_use, "value 1") |
| 101 | + event_2 = ToolStreamEvent(tool_use, {"nested": True}) |
| 102 | + event_3 = ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}) |
| 103 | + weather_tool.stream.return_value = agenerator( |
| 104 | + [ |
| 105 | + event_1, |
| 106 | + event_2, |
| 107 | + event_3, |
| 108 | + ] |
| 109 | + ) |
| 110 | + |
| 111 | + tru_events = await alist(stream) |
| 112 | + assert tru_events[0] is event_1 |
| 113 | + assert tru_events[1] is event_2 |
| 114 | + |
| 115 | + # ToolResults are not passed through directly, they're unwrapped then wraped again |
| 116 | + assert tru_events[2] == event_3 |
| 117 | + |
| 118 | + |
| 119 | +@pytest.mark.asyncio |
| 120 | +async def test_executor_stream_wraps_stream_events_if_no_result( |
| 121 | + executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator |
| 122 | +): |
| 123 | + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} |
| 124 | + stream = executor._stream(agent, tool_use, tool_results, invocation_state) |
| 125 | + |
| 126 | + weather_tool.stream = MagicMock() |
| 127 | + last_event = ToolStreamEvent(tool_use, "value 1") |
| 128 | + # Only ToolResultEvent can be the last value; all others are wrapped in ToolResultEvent |
| 129 | + weather_tool.stream.return_value = agenerator( |
| 130 | + [ |
| 131 | + last_event, |
| 132 | + ] |
| 133 | + ) |
| 134 | + |
| 135 | + tru_events = await alist(stream) |
| 136 | + exp_events = [last_event, ToolResultEvent(last_event)] |
| 137 | + assert tru_events == exp_events |
| 138 | + |
| 139 | + |
70 | 140 | @pytest.mark.asyncio |
71 | 141 | async def test_executor_stream_yields_tool_error( |
72 | 142 | executor, agent, tool_results, invocation_state, hook_events, exception_tool, alist |
@@ -129,7 +199,6 @@ async def test_executor_stream_with_trace( |
129 | 199 |
|
130 | 200 | tru_events = await alist(stream) |
131 | 201 | exp_events = [ |
132 | | - ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), |
133 | 202 | ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), |
134 | 203 | ] |
135 | 204 | assert tru_events == exp_events |
|
0 commit comments