|
6 | 6 | import strands |
7 | 7 | import strands.event_loop |
8 | 8 | from strands.types._events import ModelStopReason, TypedEvent |
9 | | -from strands.types.content import Message |
| 9 | +from strands.types.content import Message, Messages |
10 | 10 | from strands.types.streaming import ( |
11 | 11 | ContentBlockDeltaEvent, |
12 | 12 | ContentBlockStartEvent, |
@@ -54,6 +54,59 @@ def test_remove_blank_messages_content_text(messages, exp_result): |
54 | 54 | assert tru_result == exp_result |
55 | 55 |
|
56 | 56 |
|
| 57 | +@pytest.mark.parametrize( |
| 58 | + ("messages", "exp_result"), |
| 59 | + [ |
| 60 | + pytest.param( |
| 61 | + [ |
| 62 | + {"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}, {"toolUse": {"name": "a_name"}}]}, |
| 63 | + {"role": "assistant", "content": [{"text": ""}, {"toolUse": {"name": "a_name"}}]}, |
| 64 | + {"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}, |
| 65 | + {"role": "assistant", "content": []}, |
| 66 | + {"role": "assistant"}, |
| 67 | + {"role": "user", "content": [{"text": " \n"}]}, |
| 68 | + ], |
| 69 | + [ |
| 70 | + {"role": "assistant", "content": [{"text": "a"}, {"toolUse": {"name": "a_name"}}]}, |
| 71 | + {"role": "assistant", "content": [{"toolUse": {"name": "a_name"}}]}, |
| 72 | + {"role": "assistant", "content": [{"text": "a"}, {"text": "[blank text]"}]}, |
| 73 | + {"role": "assistant", "content": [{"text": "[blank text]"}]}, |
| 74 | + {"role": "assistant"}, |
| 75 | + {"role": "user", "content": [{"text": " \n"}]}, |
| 76 | + ], |
| 77 | + id="blank messages", |
| 78 | + ), |
| 79 | + pytest.param( |
| 80 | + [], |
| 81 | + [], |
| 82 | + id="empty messages", |
| 83 | + ), |
| 84 | + pytest.param( |
| 85 | + [ |
| 86 | + {"role": "assistant", "content": [{"toolUse": {"name": "invalid tool"}}]}, |
| 87 | + ], |
| 88 | + [ |
| 89 | + {"role": "assistant", "content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}]}, |
| 90 | + ], |
| 91 | + id="invalid tool name", |
| 92 | + ), |
| 93 | + pytest.param( |
| 94 | + [ |
| 95 | + {"role": "assistant", "content": [{"toolUse": {}}]}, |
| 96 | + ], |
| 97 | + [ |
| 98 | + {"role": "assistant", "content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}]}, |
| 99 | + ], |
| 100 | + id="missing tool name", |
| 101 | + ), |
| 102 | + ], |
| 103 | +) |
| 104 | +def test_normalize_blank_messages_content_text(messages, exp_result): |
| 105 | + tru_result = strands.event_loop.streaming._normalize_messages(messages) |
| 106 | + |
| 107 | + assert tru_result == exp_result |
| 108 | + |
| 109 | + |
57 | 110 | def test_handle_message_start(): |
58 | 111 | event: MessageStartEvent = {"role": "test"} |
59 | 112 |
|
@@ -797,3 +850,43 @@ async def test_stream_messages(agenerator, alist): |
797 | 850 | # Ensure that we're getting typed events coming out of process_stream |
798 | 851 | non_typed_events = [event for event in tru_events if not isinstance(event, TypedEvent)] |
799 | 852 | assert non_typed_events == [] |
| 853 | + |
| 854 | + |
| 855 | +@pytest.mark.asyncio |
| 856 | +async def test_stream_messages_normalizes_messages(agenerator, alist): |
| 857 | + mock_model = unittest.mock.MagicMock() |
| 858 | + mock_model.stream.return_value = agenerator( |
| 859 | + [ |
| 860 | + {"contentBlockDelta": {"delta": {"text": "test"}}}, |
| 861 | + {"contentBlockStop": {}}, |
| 862 | + ] |
| 863 | + ) |
| 864 | + |
| 865 | + messages: Messages = [ |
| 866 | + # blank text |
| 867 | + {"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}, {"toolUse": {"name": "a_name"}}]}, |
| 868 | + {"role": "assistant", "content": [{"text": ""}, {"toolUse": {"name": "a_name"}}]}, |
| 869 | + {"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}, |
| 870 | + # Invalid names |
| 871 | + {"role": "assistant", "content": [{"toolUse": {"name": "invalid name"}}]}, |
| 872 | + {"role": "assistant", "content": [{"toolUse": {}}]}, |
| 873 | + ] |
| 874 | + |
| 875 | + await alist( |
| 876 | + strands.event_loop.streaming.stream_messages( |
| 877 | + mock_model, |
| 878 | + system_prompt="test prompt", |
| 879 | + messages=messages, |
| 880 | + tool_specs=None, |
| 881 | + ) |
| 882 | + ) |
| 883 | + |
| 884 | + assert mock_model.stream.call_args[0][0] == [ |
| 885 | + # blank text |
| 886 | + {"content": [{"text": "a"}, {"toolUse": {"name": "a_name"}}], "role": "assistant"}, |
| 887 | + {"content": [{"toolUse": {"name": "a_name"}}], "role": "assistant"}, |
| 888 | + {"content": [{"text": "a"}, {"text": "[blank text]"}], "role": "assistant"}, |
| 889 | + # Invalid names |
| 890 | + {"content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}], "role": "assistant"}, |
| 891 | + {"content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}], "role": "assistant"}, |
| 892 | + ] |
0 commit comments