Skip to content

Commit a36255d

Browse files
authored
fix: Invoke callback handler for structured_output (#857)
In the switch to typed_events, the case of structured_output invoking the callback handler was missed, resulting in issue #831; this restores the old behavior/fixes backwards compatibility Co-authored-by: Mackenzie Zastrow <zastrowm@users.noreply.github.com>
1 parent 406458d commit a36255d

File tree

2 files changed

+136
-2
lines changed

2 files changed

+136
-2
lines changed

src/strands/agent/agent.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -514,8 +514,11 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
514514
)
515515
events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
516516
async for event in events:
517-
if "callback" in event:
518-
self.callback_handler(**cast(dict, event["callback"]))
517+
if isinstance(event, TypedEvent):
518+
event.prepare(invocation_state={})
519+
if event.is_callback_event:
520+
self.callback_handler(**event.as_dict())
521+
519522
structured_output_span.add_event(
520523
"gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}
521524
)

tests/strands/agent/hooks/test_agent_events.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from unittest.mock import ANY, MagicMock, call
44

55
import pytest
6+
from pydantic import BaseModel
67

78
import strands
89
from strands import Agent
910
from strands.agent import AgentResult
11+
from strands.models import BedrockModel
1012
from strands.types._events import TypedEvent
1113
from strands.types.exceptions import ModelThrottledException
1214
from tests.fixtures.mocked_model_provider import MockedModelProvider
@@ -518,3 +520,132 @@ async def test_event_loop_cycle_text_response_throttling_early_end(
518520
# Ensure that all events coming out of the agent are *not* typed events
519521
typed_events = [event for event in tru_events if isinstance(event, TypedEvent)]
520522
assert typed_events == []
523+
524+
525+
@pytest.mark.asyncio
526+
async def test_structured_output(agenerator):
527+
# we use bedrock here as it uses the tool implementation
528+
model = BedrockModel()
529+
model.stream = MagicMock()
530+
model.stream.return_value = agenerator(
531+
[
532+
{
533+
"contentBlockStart": {
534+
"start": {"toolUse": {"toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", "name": "Person"}},
535+
"contentBlockIndex": 0,
536+
}
537+
},
538+
{"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 0}},
539+
{"contentBlockDelta": {"delta": {"toolUse": {"input": '{"na'}}, "contentBlockIndex": 0}},
540+
{"contentBlockDelta": {"delta": {"toolUse": {"input": 'me"'}}, "contentBlockIndex": 0}},
541+
{"contentBlockDelta": {"delta": {"toolUse": {"input": ': "J'}}, "contentBlockIndex": 0}},
542+
{"contentBlockDelta": {"delta": {"toolUse": {"input": 'ohn"'}}, "contentBlockIndex": 0}},
543+
{"contentBlockDelta": {"delta": {"toolUse": {"input": ', "age": 3'}}, "contentBlockIndex": 0}},
544+
{"contentBlockDelta": {"delta": {"toolUse": {"input": "1}"}}, "contentBlockIndex": 0}},
545+
{"contentBlockStop": {"contentBlockIndex": 0}},
546+
{"messageStop": {"stopReason": "tool_use"}},
547+
{
548+
"metadata": {
549+
"usage": {"inputTokens": 407, "outputTokens": 53, "totalTokens": 460},
550+
"metrics": {"latencyMs": 1572},
551+
}
552+
},
553+
]
554+
)
555+
556+
mock_callback = unittest.mock.Mock()
557+
agent = Agent(model=model, callback_handler=mock_callback)
558+
559+
class Person(BaseModel):
560+
name: str
561+
age: float
562+
563+
await agent.structured_output_async(Person, "John is 31")
564+
565+
exp_events = [
566+
{
567+
"event": {
568+
"contentBlockStart": {
569+
"start": {"toolUse": {"toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", "name": "Person"}},
570+
"contentBlockIndex": 0,
571+
}
572+
}
573+
},
574+
{"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 0}}},
575+
{
576+
"delta": {"toolUse": {"input": ""}},
577+
"current_tool_use": {
578+
"toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ",
579+
"name": "Person",
580+
"input": {"name": "John", "age": 31},
581+
},
582+
},
583+
{"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"na'}}, "contentBlockIndex": 0}}},
584+
{
585+
"delta": {"toolUse": {"input": '{"na'}},
586+
"current_tool_use": {
587+
"toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ",
588+
"name": "Person",
589+
"input": {"name": "John", "age": 31},
590+
},
591+
},
592+
{"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": 'me"'}}, "contentBlockIndex": 0}}},
593+
{
594+
"delta": {"toolUse": {"input": 'me"'}},
595+
"current_tool_use": {
596+
"toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ",
597+
"name": "Person",
598+
"input": {"name": "John", "age": 31},
599+
},
600+
},
601+
{"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": ': "J'}}, "contentBlockIndex": 0}}},
602+
{
603+
"delta": {"toolUse": {"input": ': "J'}},
604+
"current_tool_use": {
605+
"toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ",
606+
"name": "Person",
607+
"input": {"name": "John", "age": 31},
608+
},
609+
},
610+
{"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": 'ohn"'}}, "contentBlockIndex": 0}}},
611+
{
612+
"delta": {"toolUse": {"input": 'ohn"'}},
613+
"current_tool_use": {
614+
"toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ",
615+
"name": "Person",
616+
"input": {"name": "John", "age": 31},
617+
},
618+
},
619+
{"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": ', "age": 3'}}, "contentBlockIndex": 0}}},
620+
{
621+
"delta": {"toolUse": {"input": ', "age": 3'}},
622+
"current_tool_use": {
623+
"toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ",
624+
"name": "Person",
625+
"input": {"name": "John", "age": 31},
626+
},
627+
},
628+
{"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": "1}"}}, "contentBlockIndex": 0}}},
629+
{
630+
"delta": {"toolUse": {"input": "1}"}},
631+
"current_tool_use": {
632+
"toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ",
633+
"name": "Person",
634+
"input": {"name": "John", "age": 31},
635+
},
636+
},
637+
{"event": {"contentBlockStop": {"contentBlockIndex": 0}}},
638+
{"event": {"messageStop": {"stopReason": "tool_use"}}},
639+
{
640+
"event": {
641+
"metadata": {
642+
"usage": {"inputTokens": 407, "outputTokens": 53, "totalTokens": 460},
643+
"metrics": {"latencyMs": 1572},
644+
}
645+
}
646+
},
647+
]
648+
649+
exp_calls = [call(**event) for event in exp_events]
650+
act_calls = mock_callback.call_args_list
651+
assert act_calls == exp_calls

0 commit comments

Comments
 (0)