|
3 | 3 | from unittest.mock import ANY, MagicMock, call |
4 | 4 |
|
5 | 5 | import pytest |
| 6 | +from pydantic import BaseModel |
6 | 7 |
|
7 | 8 | import strands |
8 | 9 | from strands import Agent |
9 | 10 | from strands.agent import AgentResult |
| 11 | +from strands.models import BedrockModel |
10 | 12 | from strands.types._events import TypedEvent |
11 | 13 | from strands.types.exceptions import ModelThrottledException |
12 | 14 | from tests.fixtures.mocked_model_provider import MockedModelProvider |
@@ -518,3 +520,132 @@ async def test_event_loop_cycle_text_response_throttling_early_end( |
518 | 520 | # Ensure that all events coming out of the agent are *not* typed events |
519 | 521 | typed_events = [event for event in tru_events if isinstance(event, TypedEvent)] |
520 | 522 | assert typed_events == [] |
| 523 | + |
| 524 | + |
| 525 | +@pytest.mark.asyncio |
| 526 | +async def test_structured_output(agenerator): |
| 527 | + # we use bedrock here as it uses the tool implementation |
| 528 | + model = BedrockModel() |
| 529 | + model.stream = MagicMock() |
| 530 | + model.stream.return_value = agenerator( |
| 531 | + [ |
| 532 | + { |
| 533 | + "contentBlockStart": { |
| 534 | + "start": {"toolUse": {"toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", "name": "Person"}}, |
| 535 | + "contentBlockIndex": 0, |
| 536 | + } |
| 537 | + }, |
| 538 | + {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 0}}, |
| 539 | + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"na'}}, "contentBlockIndex": 0}}, |
| 540 | + {"contentBlockDelta": {"delta": {"toolUse": {"input": 'me"'}}, "contentBlockIndex": 0}}, |
| 541 | + {"contentBlockDelta": {"delta": {"toolUse": {"input": ': "J'}}, "contentBlockIndex": 0}}, |
| 542 | + {"contentBlockDelta": {"delta": {"toolUse": {"input": 'ohn"'}}, "contentBlockIndex": 0}}, |
| 543 | + {"contentBlockDelta": {"delta": {"toolUse": {"input": ', "age": 3'}}, "contentBlockIndex": 0}}, |
| 544 | + {"contentBlockDelta": {"delta": {"toolUse": {"input": "1}"}}, "contentBlockIndex": 0}}, |
| 545 | + {"contentBlockStop": {"contentBlockIndex": 0}}, |
| 546 | + {"messageStop": {"stopReason": "tool_use"}}, |
| 547 | + { |
| 548 | + "metadata": { |
| 549 | + "usage": {"inputTokens": 407, "outputTokens": 53, "totalTokens": 460}, |
| 550 | + "metrics": {"latencyMs": 1572}, |
| 551 | + } |
| 552 | + }, |
| 553 | + ] |
| 554 | + ) |
| 555 | + |
| 556 | + mock_callback = unittest.mock.Mock() |
| 557 | + agent = Agent(model=model, callback_handler=mock_callback) |
| 558 | + |
| 559 | + class Person(BaseModel): |
| 560 | + name: str |
| 561 | + age: float |
| 562 | + |
| 563 | + await agent.structured_output_async(Person, "John is 31") |
| 564 | + |
| 565 | + exp_events = [ |
| 566 | + { |
| 567 | + "event": { |
| 568 | + "contentBlockStart": { |
| 569 | + "start": {"toolUse": {"toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", "name": "Person"}}, |
| 570 | + "contentBlockIndex": 0, |
| 571 | + } |
| 572 | + } |
| 573 | + }, |
| 574 | + {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 0}}}, |
| 575 | + { |
| 576 | + "delta": {"toolUse": {"input": ""}}, |
| 577 | + "current_tool_use": { |
| 578 | + "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", |
| 579 | + "name": "Person", |
| 580 | + "input": {"name": "John", "age": 31}, |
| 581 | + }, |
| 582 | + }, |
| 583 | + {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"na'}}, "contentBlockIndex": 0}}}, |
| 584 | + { |
| 585 | + "delta": {"toolUse": {"input": '{"na'}}, |
| 586 | + "current_tool_use": { |
| 587 | + "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", |
| 588 | + "name": "Person", |
| 589 | + "input": {"name": "John", "age": 31}, |
| 590 | + }, |
| 591 | + }, |
| 592 | + {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": 'me"'}}, "contentBlockIndex": 0}}}, |
| 593 | + { |
| 594 | + "delta": {"toolUse": {"input": 'me"'}}, |
| 595 | + "current_tool_use": { |
| 596 | + "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", |
| 597 | + "name": "Person", |
| 598 | + "input": {"name": "John", "age": 31}, |
| 599 | + }, |
| 600 | + }, |
| 601 | + {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": ': "J'}}, "contentBlockIndex": 0}}}, |
| 602 | + { |
| 603 | + "delta": {"toolUse": {"input": ': "J'}}, |
| 604 | + "current_tool_use": { |
| 605 | + "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", |
| 606 | + "name": "Person", |
| 607 | + "input": {"name": "John", "age": 31}, |
| 608 | + }, |
| 609 | + }, |
| 610 | + {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": 'ohn"'}}, "contentBlockIndex": 0}}}, |
| 611 | + { |
| 612 | + "delta": {"toolUse": {"input": 'ohn"'}}, |
| 613 | + "current_tool_use": { |
| 614 | + "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", |
| 615 | + "name": "Person", |
| 616 | + "input": {"name": "John", "age": 31}, |
| 617 | + }, |
| 618 | + }, |
| 619 | + {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": ', "age": 3'}}, "contentBlockIndex": 0}}}, |
| 620 | + { |
| 621 | + "delta": {"toolUse": {"input": ', "age": 3'}}, |
| 622 | + "current_tool_use": { |
| 623 | + "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", |
| 624 | + "name": "Person", |
| 625 | + "input": {"name": "John", "age": 31}, |
| 626 | + }, |
| 627 | + }, |
| 628 | + {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": "1}"}}, "contentBlockIndex": 0}}}, |
| 629 | + { |
| 630 | + "delta": {"toolUse": {"input": "1}"}}, |
| 631 | + "current_tool_use": { |
| 632 | + "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", |
| 633 | + "name": "Person", |
| 634 | + "input": {"name": "John", "age": 31}, |
| 635 | + }, |
| 636 | + }, |
| 637 | + {"event": {"contentBlockStop": {"contentBlockIndex": 0}}}, |
| 638 | + {"event": {"messageStop": {"stopReason": "tool_use"}}}, |
| 639 | + { |
| 640 | + "event": { |
| 641 | + "metadata": { |
| 642 | + "usage": {"inputTokens": 407, "outputTokens": 53, "totalTokens": 460}, |
| 643 | + "metrics": {"latencyMs": 1572}, |
| 644 | + } |
| 645 | + } |
| 646 | + }, |
| 647 | + ] |
| 648 | + |
| 649 | + exp_calls = [call(**event) for event in exp_events] |
| 650 | + act_calls = mock_callback.call_args_list |
| 651 | + assert act_calls == exp_calls |
0 commit comments