
Commit 001aa93

Shang Liu authored
feat: add support for Bedrock/Anthropic ToolChoice to structured_output (#720)
Adds ToolChoice support for structured output so that some providers can force tool calls.

Co-authored-by: Mackenzie Zastrow <zastrowm@users.noreply.github.com>
Co-authored-by: Shang Liu <sshangl@amazon.com>
1 parent 9213bc5 commit 001aa93

22 files changed · +678 −47 lines changed

src/strands/models/_config_validation.py renamed to src/strands/models/_validation.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -5,6 +5,8 @@
 
 from typing_extensions import get_type_hints
 
+from ..types.tools import ToolChoice
+
 
 def validate_config_keys(config_dict: Mapping[str, Any], config_class: Type) -> None:
     """Validate that config keys match the TypedDict fields.
@@ -25,3 +27,16 @@ def validate_config_keys(config_dict: Mapping[str, Any], config_class: Type) ->
             f"\nSee https://github.com/strands-agents/sdk-python/issues/815",
             stacklevel=4,
         )
+
+
+def warn_on_tool_choice_not_supported(tool_choice: ToolChoice | None) -> None:
+    """Emits a warning if a tool choice is provided but not supported by the provider.
+
+    Args:
+        tool_choice: the tool_choice provided to the provider
+    """
+    if tool_choice:
+        warnings.warn(
+            "A ToolChoice was provided to this provider but is not supported and will be ignored",
+            stacklevel=4,
+        )
```
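
The helper only fires on a truthy `tool_choice`, so providers can call it unconditionally at the top of `stream()`. A minimal sketch of that behavior (the `catch_warnings` harness is just for demonstration):

```python
import warnings

from strands.models._validation import warn_on_tool_choice_not_supported

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_on_tool_choice_not_supported({"any": {}})  # truthy: emits the warning
    warn_on_tool_choice_not_supported(None)         # None: silently does nothing

assert len(caught) == 1
print(caught[0].message)
# A ToolChoice was provided to this provider but is not supported and will be ignored
```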

src/strands/models/anthropic.py

Lines changed: 33 additions & 5 deletions
```diff
@@ -18,8 +18,8 @@
 from ..types.content import ContentBlock, Messages
 from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
 from ..types.streaming import StreamEvent
-from ..types.tools import ToolSpec
-from ._config_validation import validate_config_keys
+from ..types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec
+from ._validation import validate_config_keys
 from .model import Model
 
 logger = logging.getLogger(__name__)
@@ -195,14 +195,19 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]:
         return formatted_messages
 
     def format_request(
-        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
+        self,
+        messages: Messages,
+        tool_specs: Optional[list[ToolSpec]] = None,
+        system_prompt: Optional[str] = None,
+        tool_choice: ToolChoice | None = None,
     ) -> dict[str, Any]:
         """Format an Anthropic streaming request.
 
         Args:
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.
 
         Returns:
             An Anthropic streaming request.
@@ -223,10 +228,25 @@ def format_request(
                 }
                 for tool_spec in tool_specs or []
             ],
+            **(self._format_tool_choice(tool_choice)),
             **({"system": system_prompt} if system_prompt else {}),
             **(self.config.get("params") or {}),
         }
 
+    @staticmethod
+    def _format_tool_choice(tool_choice: ToolChoice | None) -> dict:
+        if tool_choice is None:
+            return {}
+
+        if "any" in tool_choice:
+            return {"tool_choice": {"type": "any"}}
+        elif "auto" in tool_choice:
+            return {"tool_choice": {"type": "auto"}}
+        elif "tool" in tool_choice:
+            return {"tool_choice": {"type": "tool", "name": cast(ToolChoiceToolDict, tool_choice)["tool"]["name"]}}
+        else:
+            return {}
+
     def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
         """Format the Anthropic response events into standardized message chunks.
 
@@ -350,6 +370,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: ToolChoice | None = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the Anthropic model.
@@ -358,6 +379,7 @@ async def stream(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.
             **kwargs: Additional keyword arguments for future extensibility.
 
         Yields:
@@ -368,7 +390,7 @@ async def stream(
             ModelThrottledException: If the request is throttled by Anthropic.
         """
         logger.debug("formatting request")
-        request = self.format_request(messages, tool_specs, system_prompt)
+        request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
         logger.debug("request=<%s>", request)
 
         logger.debug("invoking model")
@@ -410,7 +432,13 @@ async def structured_output(
         """
         tool_spec = convert_pydantic_to_tool_spec(output_model)
 
-        response = self.stream(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, **kwargs)
+        response = self.stream(
+            messages=prompt,
+            tool_specs=[tool_spec],
+            system_prompt=system_prompt,
+            tool_choice=cast(ToolChoice, {"any": {}}),
+            **kwargs,
+        )
         async for event in process_stream(response):
             yield event
```
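
The new internal `_format_tool_choice` helper translates the provider-agnostic `ToolChoice` shapes into Anthropic's `tool_choice` field; note that the named-tool case flattens to `{"type": "tool", "name": ...}`. A quick sketch of the mapping (the tool name `get_weather` is illustrative):

```python
from strands.models.anthropic import AnthropicModel

assert AnthropicModel._format_tool_choice(None) == {}
assert AnthropicModel._format_tool_choice({"auto": {}}) == {"tool_choice": {"type": "auto"}}
assert AnthropicModel._format_tool_choice({"any": {}}) == {"tool_choice": {"type": "any"}}
assert AnthropicModel._format_tool_choice({"tool": {"name": "get_weather"}}) == {
    "tool_choice": {"type": "tool", "name": "get_weather"}
}
```

`structured_output` then forces `{"any": {}}`, so the model must answer through the output-model tool rather than in plain text.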

src/strands/models/bedrock.py

Lines changed: 12 additions & 5 deletions
```diff
@@ -23,8 +23,8 @@
     ModelThrottledException,
 )
 from ..types.streaming import CitationsDelta, StreamEvent
-from ..types.tools import ToolResult, ToolSpec
-from ._config_validation import validate_config_keys
+from ..types.tools import ToolChoice, ToolResult, ToolSpec
+from ._validation import validate_config_keys
 from .model import Model
 
 logger = logging.getLogger(__name__)
@@ -196,13 +196,15 @@ def format_request(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: ToolChoice | None = None,
     ) -> dict[str, Any]:
         """Format a Bedrock converse stream request.
 
         Args:
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.
 
         Returns:
             A Bedrock converse stream request.
@@ -225,7 +227,7 @@ def format_request(
                         else []
                     ),
                 ],
-                "toolChoice": {"auto": {}},
+                **({"toolChoice": tool_choice if tool_choice else {"auto": {}}}),
             }
         }
         if tool_specs
@@ -417,6 +419,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: ToolChoice | None = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the Bedrock model.
@@ -428,6 +431,7 @@ async def stream(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.
             **kwargs: Additional keyword arguments for future extensibility.
 
         Yields:
@@ -446,7 +450,7 @@ def callback(event: Optional[StreamEvent] = None) -> None:
         loop = asyncio.get_event_loop()
         queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue()
 
-        thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt)
+        thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt, tool_choice)
         task = asyncio.create_task(thread)
 
         while True:
@@ -464,6 +468,7 @@ def _stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: ToolChoice | None = None,
     ) -> None:
         """Stream conversation with the Bedrock model.
 
@@ -475,14 +480,15 @@ def _stream(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.
 
         Raises:
             ContextWindowOverflowException: If the input exceeds the model's context window.
             ModelThrottledException: If the model service is throttling requests.
         """
         try:
             logger.debug("formatting request")
-            request = self.format_request(messages, tool_specs, system_prompt)
+            request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
             logger.debug("request=<%s>", request)
 
             logger.debug("invoking model")
@@ -739,6 +745,7 @@ async def structured_output(
             messages=prompt,
             tool_specs=[tool_spec],
             system_prompt=system_prompt,
+            tool_choice=cast(ToolChoice, {"any": {}}),
             **kwargs,
         )
         async for event in streaming.process_stream(response):
```
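
Bedrock's Converse API already speaks the same `ToolChoice` shape, so the dict is forwarded verbatim inside `toolConfig`; when the caller passes nothing, the previous hard-coded `{"auto": {}}` default is preserved. A sketch of the resulting request fragment (the model id and tool spec are illustrative, and constructing the client assumes default AWS region/credentials configuration):

```python
from strands.models.bedrock import BedrockModel

# Hypothetical tool spec, for illustration only.
weather_spec = {
    "name": "get_weather",
    "description": "Get the weather for a city.",
    "inputSchema": {"json": {"type": "object", "properties": {"city": {"type": "string"}}}},
}

model = BedrockModel(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0")
request = model.format_request(
    messages=[{"role": "user", "content": [{"text": "Weather in Seattle?"}]}],
    tool_specs=[weather_spec],
    tool_choice={"tool": {"name": "get_weather"}},
)
print(request["toolConfig"]["toolChoice"])  # {'tool': {'name': 'get_weather'}}
```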

src/strands/models/litellm.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -14,8 +14,8 @@
 
 from ..types.content import ContentBlock, Messages
 from ..types.streaming import StreamEvent
-from ..types.tools import ToolSpec
-from ._config_validation import validate_config_keys
+from ..types.tools import ToolChoice, ToolSpec
+from ._validation import validate_config_keys
 from .openai import OpenAIModel
 
 logger = logging.getLogger(__name__)
@@ -114,6 +114,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: ToolChoice | None = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the LiteLLM model.
@@ -122,13 +123,14 @@ async def stream(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.
             **kwargs: Additional keyword arguments for future extensibility.
 
         Yields:
             Formatted message chunks from the model.
         """
         logger.debug("formatting request")
-        request = self.format_request(messages, tool_specs, system_prompt)
+        request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
         logger.debug("request=<%s>", request)
 
         logger.debug("invoking model")
```
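
LiteLLM adds no translation of its own here: the value is handed to the inherited OpenAI-style `format_request`, which now also accepts `tool_choice` (presumably updated in one of the diffs not shown on this page). A usage sketch, assuming provider credentials are configured at runtime (model id and tool spec are illustrative):

```python
import asyncio

from strands.models.litellm import LiteLLMModel

# Hypothetical tool spec, for illustration only.
weather_spec = {
    "name": "get_weather",
    "description": "Get the weather for a city.",
    "inputSchema": {"json": {"type": "object", "properties": {"city": {"type": "string"}}}},
}

async def main() -> None:
    model = LiteLLMModel(model_id="gpt-4o")
    # {"any": {}} forces the model to call *some* tool rather than answer in text.
    async for event in model.stream(
        messages=[{"role": "user", "content": [{"text": "Weather in Paris?"}]}],
        tool_specs=[weather_spec],
        tool_choice={"any": {}},
    ):
        print(event)

asyncio.run(main())
```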

src/strands/models/llamaapi.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -18,8 +18,8 @@
 from ..types.content import ContentBlock, Messages
 from ..types.exceptions import ModelThrottledException
 from ..types.streaming import StreamEvent, Usage
-from ..types.tools import ToolResult, ToolSpec, ToolUse
-from ._config_validation import validate_config_keys
+from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
+from ._validation import validate_config_keys, warn_on_tool_choice_not_supported
 from .model import Model
 
 logger = logging.getLogger(__name__)
@@ -330,6 +330,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: ToolChoice | None = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the LlamaAPI model.
@@ -338,6 +339,8 @@ async def stream(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for
+                interface consistency but is currently ignored for this model provider.**
             **kwargs: Additional keyword arguments for future extensibility.
 
         Yields:
@@ -346,6 +349,8 @@ async def stream(
         Raises:
             ModelThrottledException: When the model service is throttling requests from the client.
         """
+        warn_on_tool_choice_not_supported(tool_choice)
+
         logger.debug("formatting request")
         request = self.format_request(messages, tool_specs, system_prompt)
         logger.debug("request=<%s>", request)
```
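
For providers in this ignored-with-a-warning bucket (LlamaAPI here, plus Mistral and Ollama below), the request still succeeds; a caller who would rather fail fast can promote the warning to an exception with the standard `warnings` machinery. A sketch:

```python
import warnings

# Turn the "ignored tool choice" warning into an exception for this process;
# the message pattern matches the text emitted by warn_on_tool_choice_not_supported.
warnings.filterwarnings(
    "error",
    message="A ToolChoice was provided to this provider but is not supported.*",
)
```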

src/strands/models/mistral.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -15,8 +15,8 @@
 from ..types.content import ContentBlock, Messages
 from ..types.exceptions import ModelThrottledException
 from ..types.streaming import StopReason, StreamEvent
-from ..types.tools import ToolResult, ToolSpec, ToolUse
-from ._config_validation import validate_config_keys
+from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
+from ._validation import validate_config_keys, warn_on_tool_choice_not_supported
 from .model import Model
 
 logger = logging.getLogger(__name__)
@@ -397,6 +397,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: ToolChoice | None = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the Mistral model.
@@ -405,6 +406,8 @@ async def stream(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for
+                interface consistency but is currently ignored for this model provider.**
             **kwargs: Additional keyword arguments for future extensibility.
 
         Yields:
@@ -413,6 +416,8 @@ async def stream(
         Raises:
             ModelThrottledException: When the model service is throttling requests.
         """
+        warn_on_tool_choice_not_supported(tool_choice)
+
         logger.debug("formatting request")
         request = self.format_request(messages, tool_specs, system_prompt)
         logger.debug("request=<%s>", request)
```

src/strands/models/model.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -8,7 +8,7 @@
 
 from ..types.content import Messages
 from ..types.streaming import StreamEvent
-from ..types.tools import ToolSpec
+from ..types.tools import ToolChoice, ToolSpec
 
 logger = logging.getLogger(__name__)
 
@@ -70,6 +70,7 @@ def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: ToolChoice | None = None,
         **kwargs: Any,
     ) -> AsyncIterable[StreamEvent]:
         """Stream conversation with the model.
@@ -84,6 +85,7 @@ def stream(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.
             **kwargs: Additional keyword arguments for future extensibility.
 
         Yields:
```
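
Because `tool_choice` lands on the abstract `Model.stream` interface, every provider, including third-party subclasses, must now accept the keyword even when it cannot honor it. A minimal sketch of a conforming custom provider (other abstract methods such as `get_config`, `update_config`, and `structured_output` omitted for brevity, so this class is illustrative rather than instantiable):

```python
from typing import Any, AsyncIterable, Optional

from strands.models._validation import warn_on_tool_choice_not_supported
from strands.models.model import Model
from strands.types.content import Messages
from strands.types.streaming import StreamEvent
from strands.types.tools import ToolChoice, ToolSpec

class EchoModel(Model):
    async def stream(
        self,
        messages: Messages,
        tool_specs: Optional[list[ToolSpec]] = None,
        system_prompt: Optional[str] = None,
        tool_choice: ToolChoice | None = None,
        **kwargs: Any,
    ) -> AsyncIterable[StreamEvent]:
        # A provider that cannot honor tool_choice should warn, not fail.
        warn_on_tool_choice_not_supported(tool_choice)
        yield {"contentBlockDelta": {"delta": {"text": "hello"}}}
```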

src/strands/models/ollama.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -13,8 +13,8 @@
 
 from ..types.content import ContentBlock, Messages
 from ..types.streaming import StopReason, StreamEvent
-from ..types.tools import ToolSpec
-from ._config_validation import validate_config_keys
+from ..types.tools import ToolChoice, ToolSpec
+from ._validation import validate_config_keys, warn_on_tool_choice_not_supported
 from .model import Model
 
 logger = logging.getLogger(__name__)
@@ -287,6 +287,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: ToolChoice | None = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the Ollama model.
@@ -295,11 +296,15 @@ async def stream(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for
+                interface consistency but is currently ignored for this model provider.**
             **kwargs: Additional keyword arguments for future extensibility.
 
         Yields:
             Formatted message chunks from the model.
         """
+        warn_on_tool_choice_not_supported(tool_choice)
+
         logger.debug("formatting request")
         request = self.format_request(messages, tool_specs, system_prompt)
         logger.debug("request=<%s>", request)
```
