1818from ..types .content import ContentBlock , Messages
1919from ..types .exceptions import ContextWindowOverflowException , ModelThrottledException
2020from ..types .streaming import StreamEvent
21- from ..types .tools import ToolSpec
22- from ._config_validation import validate_config_keys
21+ from ..types .tools import ToolChoice , ToolChoiceToolDict , ToolSpec
22+ from ._validation import validate_config_keys
2323from .model import Model
2424
2525logger = logging .getLogger (__name__ )
@@ -195,14 +195,19 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]:
195195 return formatted_messages
196196
197197 def format_request (
198- self , messages : Messages , tool_specs : Optional [list [ToolSpec ]] = None , system_prompt : Optional [str ] = None
198+ self ,
199+ messages : Messages ,
200+ tool_specs : Optional [list [ToolSpec ]] = None ,
201+ system_prompt : Optional [str ] = None ,
202+ tool_choice : ToolChoice | None = None ,
199203 ) -> dict [str , Any ]:
200204 """Format an Anthropic streaming request.
201205
202206 Args:
203207 messages: List of message objects to be processed by the model.
204208 tool_specs: List of tool specifications to make available to the model.
205209 system_prompt: System prompt to provide context to the model.
210+ tool_choice: Selection strategy for tool invocation.
206211
207212 Returns:
208213 An Anthropic streaming request.
@@ -223,10 +228,25 @@ def format_request(
223228 }
224229 for tool_spec in tool_specs or []
225230 ],
231+ ** (self ._format_tool_choice (tool_choice )),
226232 ** ({"system" : system_prompt } if system_prompt else {}),
227233 ** (self .config .get ("params" ) or {}),
228234 }
229235
236+ @staticmethod
237+ def _format_tool_choice (tool_choice : ToolChoice | None ) -> dict :
238+ if tool_choice is None :
239+ return {}
240+
241+ if "any" in tool_choice :
242+ return {"tool_choice" : {"type" : "any" }}
243+ elif "auto" in tool_choice :
244+ return {"tool_choice" : {"type" : "auto" }}
245+ elif "tool" in tool_choice :
246+ return {"tool_choice" : {"type" : "tool" , "name" : cast (ToolChoiceToolDict , tool_choice )["tool" ]["name" ]}}
247+ else :
248+ return {}
249+
230250 def format_chunk (self , event : dict [str , Any ]) -> StreamEvent :
231251 """Format the Anthropic response events into standardized message chunks.
232252
@@ -350,6 +370,7 @@ async def stream(
350370 messages : Messages ,
351371 tool_specs : Optional [list [ToolSpec ]] = None ,
352372 system_prompt : Optional [str ] = None ,
373+ tool_choice : ToolChoice | None = None ,
353374 ** kwargs : Any ,
354375 ) -> AsyncGenerator [StreamEvent , None ]:
355376 """Stream conversation with the Anthropic model.
@@ -358,6 +379,7 @@ async def stream(
358379 messages: List of message objects to be processed by the model.
359380 tool_specs: List of tool specifications to make available to the model.
360381 system_prompt: System prompt to provide context to the model.
382+ tool_choice: Selection strategy for tool invocation.
361383 **kwargs: Additional keyword arguments for future extensibility.
362384
363385 Yields:
@@ -368,7 +390,7 @@ async def stream(
368390 ModelThrottledException: If the request is throttled by Anthropic.
369391 """
370392 logger .debug ("formatting request" )
371- request = self .format_request (messages , tool_specs , system_prompt )
393+ request = self .format_request (messages , tool_specs , system_prompt , tool_choice )
372394 logger .debug ("request=<%s>" , request )
373395
374396 logger .debug ("invoking model" )
@@ -410,7 +432,13 @@ async def structured_output(
410432 """
411433 tool_spec = convert_pydantic_to_tool_spec (output_model )
412434
413- response = self .stream (messages = prompt , tool_specs = [tool_spec ], system_prompt = system_prompt , ** kwargs )
435+ response = self .stream (
436+ messages = prompt ,
437+ tool_specs = [tool_spec ],
438+ system_prompt = system_prompt ,
439+ tool_choice = cast (ToolChoice , {"any" : {}}),
440+ ** kwargs ,
441+ )
414442 async for event in process_stream (response ):
415443 yield event
416444
0 commit comments