diff --git a/.github/workflows/pr-and-push.yml b/.github/workflows/pr-and-push.yml
index 2b2d026f4..b558943dd 100644
--- a/.github/workflows/pr-and-push.yml
+++ b/.github/workflows/pr-and-push.yml
@@ -3,7 +3,7 @@ name: Pull Request and Push Action
on:
pull_request: # Safer than pull_request_target for untrusted code
branches: [ main ]
- types: [opened, synchronize, reopened, ready_for_review, review_requested, review_request_removed]
+ types: [opened, synchronize, reopened, ready_for_review]
push:
branches: [ main ] # Also run on direct pushes to main
concurrency:
diff --git a/README.md b/README.md
index 58c647f8d..5b545f969 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,9 @@
+
+
+
Documentation
@@ -91,6 +94,17 @@ agent = Agent(tools=[word_count])
response = agent("How many words are in this sentence?")
```
+**Hot Reloading from Directory:**
+Enable automatic tool loading and reloading from the `./tools/` directory:
+
+```python
+from strands import Agent
+
+# Agent will watch ./tools/ directory for changes
+agent = Agent(load_tools_from_directory=True)
+response = agent("Use any tools you find in the tools directory")
+```
+
### MCP Support
Seamlessly integrate Model Context Protocol (MCP) servers:
diff --git a/pyproject.toml b/pyproject.toml
index 765e815ef..586a956af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,7 @@ dependencies = [
"boto3>=1.26.0,<2.0.0",
"botocore>=1.29.0,<2.0.0",
"docstring_parser>=0.15,<1.0",
- "mcp>=1.8.0,<2.0.0",
+ "mcp>=1.11.0,<2.0.0",
"pydantic>=2.0.0,<3.0.0",
"typing-extensions>=4.13.2,<5.0.0",
"watchdog>=6.0.0,<7.0.0",
@@ -89,8 +89,15 @@ writer = [
"writer-sdk>=2.2.0,<3.0.0"
]
+sagemaker = [
+ "boto3>=1.26.0,<2.0.0",
+ "botocore>=1.29.0,<2.0.0",
+ "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0"
+]
+
a2a = [
- "a2a-sdk[sql]>=0.2.16,<1.0.0",
+ "a2a-sdk>=0.3.0,<0.4.0",
+ "a2a-sdk[sql]>=0.3.0,<0.4.0",
"uvicorn>=0.34.2,<1.0.0",
"httpx>=0.28.1,<1.0.0",
"fastapi>=0.115.12,<1.0.0",
@@ -136,7 +143,7 @@ all = [
"opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0",
# a2a
- "a2a-sdk[sql]>=0.2.16,<1.0.0",
+ "a2a-sdk[sql]>=0.3.0,<0.4.0",
"uvicorn>=0.34.2,<1.0.0",
"httpx>=0.28.1,<1.0.0",
"fastapi>=0.115.12,<1.0.0",
@@ -148,7 +155,7 @@ all = [
source = "vcs"
[tool.hatch.envs.hatch-static-analysis]
-features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a"]
+features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker"]
dependencies = [
"mypy>=1.15.0,<2.0.0",
"ruff>=0.11.6,<0.12.0",
@@ -171,7 +178,7 @@ lint-fix = [
]
[tool.hatch.envs.hatch-test]
-features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a"]
+features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker"]
extra-dependencies = [
"moto>=5.1.0,<6.0.0",
"pytest>=8.0.0,<9.0.0",
@@ -187,7 +194,7 @@ extra-args = [
[tool.hatch.envs.dev]
dev-mode = true
-features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer", "a2a"]
+features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer", "a2a", "sagemaker"]
[[tool.hatch.envs.hatch-test.matrix]]
python = ["3.13", "3.12", "3.11", "3.10"]
diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py
index 8756a1022..2c1ee7847 100644
--- a/src/strands/agent/conversation_manager/conversation_manager.py
+++ b/src/strands/agent/conversation_manager/conversation_manager.py
@@ -36,7 +36,7 @@ def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]
Args:
state: Previous state of the conversation manager
Returns:
- Optional list of messages to prepend to the agents messages. By defualt returns None.
+ Optional list of messages to prepend to the agent's messages. By default returns None.
"""
if state.get("__name__") != self.__class__.__name__:
raise ValueError("Invalid conversation manager state.")
diff --git a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py
new file mode 100644
index 000000000..ab6fb4abe
--- /dev/null
+++ b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py
@@ -0,0 +1,71 @@
+"""Message recovery utilities for handling max token limit scenarios.
+
+This module provides functionality to recover and clean up incomplete messages that occur
+when model responses are truncated due to maximum token limits being reached. It specifically
+handles cases where tool use blocks are incomplete or malformed due to truncation.
+"""
+
+import logging
+
+from ..types.content import ContentBlock, Message
+from ..types.tools import ToolUse
+
+logger = logging.getLogger(__name__)
+
+
+def recover_message_on_max_tokens_reached(message: Message) -> Message:
+ """Recover and clean up messages when max token limits are reached.
+
+ When a model response is truncated due to maximum token limits, all tool use blocks
+ should be replaced with informative error messages since they may be incomplete or
+ unreliable. This function inspects the message content and:
+
+ 1. Identifies all tool use blocks (regardless of validity)
+ 2. Replaces all tool uses with informative error messages
+ 3. Preserves all non-tool content blocks (text, images, etc.)
+ 4. Returns a cleaned message suitable for conversation history
+
+ This recovery mechanism ensures that the conversation can continue gracefully even when
+ model responses are truncated, providing clear feedback about what happened and preventing
+ potentially incomplete or corrupted tool executions.
+
+ Args:
+ message: The potentially incomplete message from the model that was truncated
+ due to max token limits.
+
+ Returns:
+ A cleaned Message with all tool uses replaced by explanatory text content.
+ The returned message maintains the same role as the input message.
+
+ Example:
+ If a message contains any tool use (complete or incomplete):
+ ```
+ {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"}}
+ ```
+
+ It will be replaced with:
+ ```
+ {"text": "The selected tool calculator's tool use was incomplete due to maximum token limits being reached."}
+ ```
+ """
+ logger.info("handling max_tokens stop reason - replacing all tool uses with error messages")
+
+ valid_content: list[ContentBlock] = []
+ for content in message["content"] or []:
+ tool_use: ToolUse | None = content.get("toolUse")
+ if not tool_use:
+ valid_content.append(content)
+ continue
+
+ # Replace all tool uses with error messages when max_tokens is reached
+ display_name = tool_use.get("name") or ""
+ logger.warning("tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name)
+
+ valid_content.append(
+ {
+ "text": f"The selected tool {display_name}'s tool use was incomplete due "
+ f"to maximum token limits being reached."
+ }
+ )
+
+ return {"content": valid_content, "role": message["role"]}
diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py
index ffcb6a5c9..b36f73155 100644
--- a/src/strands/event_loop/event_loop.py
+++ b/src/strands/event_loop/event_loop.py
@@ -28,9 +28,15 @@
from ..telemetry.tracer import get_tracer
from ..tools.executor import run_tools, validate_and_prepare_tools
from ..types.content import Message
-from ..types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException
+from ..types.exceptions import (
+ ContextWindowOverflowException,
+ EventLoopException,
+ MaxTokensReachedException,
+ ModelThrottledException,
+)
from ..types.streaming import Metrics, StopReason
from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse
+from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached
from .streaming import stream_messages
if TYPE_CHECKING:
@@ -151,6 +157,9 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
)
)
+ if stop_reason == "max_tokens":
+ message = recover_message_on_max_tokens_reached(message)
+
if model_invoke_span:
tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason)
break # Success! Break out of retry loop
@@ -200,6 +209,22 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
agent.event_loop_metrics.update_usage(usage)
agent.event_loop_metrics.update_metrics(metrics)
+ if stop_reason == "max_tokens":
+ """
+ Handle max_tokens limit reached by the model.
+
+ When the model reaches its maximum token limit, this represents a potentially unrecoverable
+ state where the model's response was truncated. By default, Strands fails hard with an
+ MaxTokensReachedException to maintain consistency with other failure types.
+ """
+ raise MaxTokensReachedException(
+ message=(
+ "Agent has reached an unrecoverable state due to max_tokens limit. "
+ "For more information see: "
+ "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception"
+ )
+ )
+
# If the model is requesting to use tools
if stop_reason == "tool_use":
# Handle tool execution
@@ -231,7 +256,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
# Don't yield or log the exception - we already did it when we
# raised the exception and we don't need that duplication.
raise
- except ContextWindowOverflowException as e:
+ except (ContextWindowOverflowException, MaxTokensReachedException) as e:
+ # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException
if cycle_span:
tracer.end_span_with_error(cycle_span, str(e), e)
raise e
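
With max_tokens now surfaced as its own exception rather than being wrapped in an EventLoopException, callers can handle truncation explicitly. A minimal sketch, assuming default model credentials are configured:

```python
from strands import Agent
from strands.types.exceptions import MaxTokensReachedException

agent = Agent()

try:
    agent("Write a very long essay")
except MaxTokensReachedException as exc:
    # The truncated message was already cleaned up by recover_message_on_max_tokens_reached,
    # so the caller can retry with a larger max_tokens setting or report the failure.
    print(f"Generation was cut off: {exc}")
```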
diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py
index eb72becfd..975fca3e9 100644
--- a/src/strands/models/anthropic.py
+++ b/src/strands/models/anthropic.py
@@ -414,7 +414,7 @@ async def structured_output(
stop_reason, messages, _, _ = event["stop"]
if stop_reason != "tool_use":
- raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")
+ raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".')
content = messages["content"]
output_response: dict[str, Any] | None = None
diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index 679f1ea3d..4ea1453a4 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -17,10 +17,10 @@
from ..event_loop import streaming
from ..tools import convert_pydantic_to_tool_spec
-from ..types.content import Messages
+from ..types.content import ContentBlock, Message, Messages
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StreamEvent
-from ..types.tools import ToolSpec
+from ..types.tools import ToolResult, ToolSpec
from .model import Model
logger = logging.getLogger(__name__)
@@ -181,7 +181,7 @@ def format_request(
"""
return {
"modelId": self.config["model_id"],
- "messages": messages,
+ "messages": self._format_bedrock_messages(messages),
"system": [
*([{"text": system_prompt}] if system_prompt else []),
*([{"cachePoint": {"type": self.config["cache_prompt"]}}] if self.config.get("cache_prompt") else []),
@@ -246,6 +246,53 @@ def format_request(
),
}
+ def _format_bedrock_messages(self, messages: Messages) -> Messages:
+ """Format messages for Bedrock API compatibility.
+
+ This function ensures messages conform to Bedrock's expected format by:
+ - Cleaning tool result content blocks by removing additional fields that may be
+ useful for retaining information in hooks but would cause Bedrock validation
+ exceptions, since Bedrock rejects unexpected fields in tool result blocks
+ - Ensuring all message content blocks are properly formatted for the Bedrock API
+
+ Args:
+ messages: List of messages to format
+
+ Returns:
+ Messages formatted for Bedrock API compatibility
+
+ Note:
+ Bedrock will throw validation exceptions when presented with additional
+ unexpected fields in tool result blocks.
+ https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html
+ """
+ cleaned_messages = []
+
+ for message in messages:
+ cleaned_content: list[ContentBlock] = []
+
+ for content_block in message["content"]:
+ if "toolResult" in content_block:
+ # Create a new content block with only the cleaned toolResult
+ tool_result: ToolResult = content_block["toolResult"]
+
+ # Keep only the required fields for Bedrock
+ cleaned_tool_result = ToolResult(
+ content=tool_result["content"], toolUseId=tool_result["toolUseId"], status=tool_result["status"]
+ )
+
+ cleaned_block: ContentBlock = {"toolResult": cleaned_tool_result}
+ cleaned_content.append(cleaned_block)
+ else:
+ # Keep other content blocks as-is
+ cleaned_content.append(content_block)
+
+ # Create new message with cleaned content
+ cleaned_message: Message = Message(content=cleaned_content, role=message["role"])
+ cleaned_messages.append(cleaned_message)
+
+ return cleaned_messages
+
def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
"""Check if guardrail data contains any blocked policies.
@@ -584,7 +631,7 @@ async def structured_output(
stop_reason, messages, _, _ = event["stop"]
if stop_reason != "tool_use":
- raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")
+ raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".')
content = messages["content"]
output_response: dict[str, Any] | None = None
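
The cleaning step in `_format_bedrock_messages` matters when hooks enrich tool results with extra metadata. An illustrative before/after of a single content block (the extra field name is hypothetical):

```python
# Before: a toolResult block that a hook enriched with an extra field.
enriched = {
    "toolResult": {
        "toolUseId": "tool-1",
        "status": "success",
        "content": [{"text": "42"}],
        "hookMetadata": {"latency_ms": 12},  # would trigger a Bedrock validation exception
    }
}

# After _format_bedrock_messages, only the fields Bedrock accepts remain:
cleaned = {
    "toolResult": {
        "toolUseId": "tool-1",
        "status": "success",
        "content": [{"text": "42"}],
    }
}
```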
diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py
new file mode 100644
index 000000000..9cfe27d9e
--- /dev/null
+++ b/src/strands/models/sagemaker.py
@@ -0,0 +1,598 @@
+"""Amazon SageMaker model provider."""
+
+import json
+import logging
+import os
+from dataclasses import dataclass
+from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union, cast
+
+import boto3
+from botocore.config import Config as BotocoreConfig
+from mypy_boto3_sagemaker_runtime import SageMakerRuntimeClient
+from pydantic import BaseModel
+from typing_extensions import Unpack, override
+
+from ..types.content import ContentBlock, Messages
+from ..types.streaming import StreamEvent
+from ..types.tools import ToolResult, ToolSpec
+from .openai import OpenAIModel
+
+T = TypeVar("T", bound=BaseModel)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class UsageMetadata:
+ """Usage metadata for the model.
+
+ Attributes:
+ total_tokens: Total number of tokens used in the request
+ completion_tokens: Number of tokens used in the completion
+ prompt_tokens: Number of tokens used in the prompt
+ prompt_tokens_details: Additional information about the prompt tokens (optional)
+ """
+
+ total_tokens: int
+ completion_tokens: int
+ prompt_tokens: int
+ prompt_tokens_details: Optional[int] = 0
+
+
+@dataclass
+class FunctionCall:
+ """Function call for the model.
+
+ Attributes:
+ name: Name of the function to call
+ arguments: Arguments to pass to the function
+ """
+
+ name: Union[str, dict[Any, Any]]
+ arguments: Union[str, dict[Any, Any]]
+
+ def __init__(self, **kwargs: dict[str, str]):
+ """Initialize function call.
+
+ Args:
+ **kwargs: Keyword arguments for the function call.
+ """
+ self.name = kwargs.get("name", "")
+ self.arguments = kwargs.get("arguments", "")
+
+
+@dataclass
+class ToolCall:
+ """Tool call for the model object.
+
+ Attributes:
+ id: Tool call ID
+ type: Tool call type
+ function: Tool call function
+ """
+
+ id: str
+ type: Literal["function"]
+ function: FunctionCall
+
+ def __init__(self, **kwargs: dict):
+ """Initialize tool call object.
+
+ Args:
+ **kwargs: Keyword arguments for the tool call.
+ """
+ self.id = str(kwargs.get("id", ""))
+ self.type = "function"
+ self.function = FunctionCall(**kwargs.get("function", {"name": "", "arguments": ""}))
+
+
+class SageMakerAIModel(OpenAIModel):
+ """Amazon SageMaker model provider implementation."""
+
+ client: SageMakerRuntimeClient # type: ignore[assignment]
+
+ class SageMakerAIPayloadSchema(TypedDict, total=False):
+ """Payload schema for the Amazon SageMaker AI model.
+
+ Attributes:
+ max_tokens: Maximum number of tokens to generate in the completion
+ stream: Whether to stream the response
+ temperature: Sampling temperature to use for the model (optional)
+ top_p: Nucleus sampling parameter (optional)
+ top_k: Top-k sampling parameter (optional)
+ stop: List of stop sequences to use for the model (optional)
+ tool_results_as_user_messages: Convert tool result to user messages (optional)
+ additional_args: Additional request parameters, as supported by https://bit.ly/djl-lmi-request-schema
+ """
+
+ max_tokens: int
+ stream: bool
+ temperature: Optional[float]
+ top_p: Optional[float]
+ top_k: Optional[int]
+ stop: Optional[list[str]]
+ tool_results_as_user_messages: Optional[bool]
+ additional_args: Optional[dict[str, Any]]
+
+ class SageMakerAIEndpointConfig(TypedDict, total=False):
+ """Configuration options for SageMaker models.
+
+ Attributes:
+ endpoint_name: The name of the SageMaker endpoint to invoke
+ region_name: The AWS region where the endpoint is deployed
+ inference_component_name: The name of the inference component to use (optional)
+ target_model: The model to invoke when the endpoint hosts multiple models (optional)
+ target_variant: The production variant of the endpoint to invoke (optional)
+
+ additional_args: Other request parameters, as supported by https://bit.ly/sagemaker-invoke-endpoint-params
+ """
+
+ endpoint_name: str
+ region_name: str
+ inference_component_name: Union[str, None]
+ target_model: Optional[str]
+ target_variant: Optional[str]
+ additional_args: Optional[dict[str, Any]]
+
+ def __init__(
+ self,
+ endpoint_config: SageMakerAIEndpointConfig,
+ payload_config: SageMakerAIPayloadSchema,
+ boto_session: Optional[boto3.Session] = None,
+ boto_client_config: Optional[BotocoreConfig] = None,
+ ):
+ """Initialize provider instance.
+
+ Args:
+ endpoint_config: Endpoint configuration for SageMaker.
+ payload_config: Payload configuration for the model.
+ boto_session: Boto Session to use when calling the SageMaker Runtime.
+ boto_client_config: Configuration to use when creating the SageMaker-Runtime Boto Client.
+ """
+ payload_config.setdefault("stream", True)
+ payload_config.setdefault("tool_results_as_user_messages", False)
+ self.endpoint_config = dict(endpoint_config)
+ self.payload_config = dict(payload_config)
+ logger.debug(
+ "endpoint_config=<%s> payload_config=<%s> | initializing", self.endpoint_config, self.payload_config
+ )
+
+ region = self.endpoint_config.get("region_name") or os.getenv("AWS_REGION") or "us-west-2"
+ session = boto_session or boto3.Session(region_name=str(region))
+
+ # Add strands-agents to the request user agent
+ if boto_client_config:
+ existing_user_agent = getattr(boto_client_config, "user_agent_extra", None)
+
+ # Append 'strands-agents' to existing user_agent_extra or set it if not present
+ new_user_agent = f"{existing_user_agent} strands-agents" if existing_user_agent else "strands-agents"
+
+ client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent))
+ else:
+ client_config = BotocoreConfig(user_agent_extra="strands-agents")
+
+ self.client = session.client(
+ service_name="sagemaker-runtime",
+ config=client_config,
+ )
+
+ @override
+ def update_config(self, **endpoint_config: Unpack[SageMakerAIEndpointConfig]) -> None: # type: ignore[override]
+ """Update the Amazon SageMaker model configuration with the provided arguments.
+
+ Args:
+ **endpoint_config: Configuration overrides.
+ """
+ self.endpoint_config.update(endpoint_config)
+
+ @override
+ def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: ignore[override]
+ """Get the Amazon SageMaker model configuration.
+
+ Returns:
+ The Amazon SageMaker model configuration.
+ """
+ return cast(SageMakerAIModel.SageMakerAIEndpointConfig, self.endpoint_config)
+
+ @override
+ def format_request(
+ self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
+ ) -> dict[str, Any]:
+ """Format an Amazon SageMaker chat streaming request.
+
+ Args:
+ messages: List of message objects to be processed by the model.
+ tool_specs: List of tool specifications to make available to the model.
+ system_prompt: System prompt to provide context to the model.
+
+ Returns:
+ An Amazon SageMaker chat streaming request.
+ """
+ formatted_messages = self.format_request_messages(messages, system_prompt)
+
+ payload = {
+ "messages": formatted_messages,
+ "tools": [
+ {
+ "type": "function",
+ "function": {
+ "name": tool_spec["name"],
+ "description": tool_spec["description"],
+ "parameters": tool_spec["inputSchema"]["json"],
+ },
+ }
+ for tool_spec in tool_specs or []
+ ],
+ # Add payload configuration parameters
+ **{
+ k: v
+ for k, v in self.payload_config.items()
+ if k not in ["additional_args", "tool_results_as_user_messages"]
+ },
+ }
+
+ # Remove tools and tool_choice if tools = []
+ if not payload["tools"]:
+ payload.pop("tools")
+ payload.pop("tool_choice", None)
+ else:
+ # Ensure the model can use tools when available
+ payload["tool_choice"] = "auto"
+
+ for idx, message in enumerate(payload["messages"]): # type: ignore
+ # Assistant message must have either content or tool_calls, but not both
+ if message.get("role", "") == "assistant" and message.get("tool_calls", []) != []:
+ message.pop("content", None)
+ if message.get("role") == "tool" and self.payload_config.get("tool_results_as_user_messages", False):
+ # Convert tool message to user message and write it back so the payload reflects the change
+ tool_call_id = message.get("tool_call_id", "ABCDEF")
+ content = message.get("content", "")
+ message = {"role": "user", "content": f"Tool call ID '{tool_call_id}' returned: {content}"}
+ payload["messages"][idx] = message
+ # Cannot have both reasoning_text and text - if "text", content becomes an array of content["text"]
+ for c in message.get("content", []):
+ if "text" in c:
+ message["content"] = [c]
+ break
+ # Cast message content to string for TGI compatibility
+ # message["content"] = str(message.get("content", ""))
+
+ logger.info("payload=<%s>", json.dumps(payload, indent=2))
+ # Format the request according to the SageMaker Runtime API requirements
+ request = {
+ "EndpointName": self.endpoint_config["endpoint_name"],
+ "Body": json.dumps(payload),
+ "ContentType": "application/json",
+ "Accept": "application/json",
+ }
+
+ # Add optional SageMaker parameters if provided
+ if self.endpoint_config.get("inference_component_name"):
+ request["InferenceComponentName"] = self.endpoint_config["inference_component_name"]
+ if self.endpoint_config.get("target_model"):
+ request["TargetModel"] = self.endpoint_config["target_model"]
+ if self.endpoint_config.get("target_variant"):
+ request["TargetVariant"] = self.endpoint_config["target_variant"]
+
+ # Add additional args if provided
+ if self.endpoint_config.get("additional_args"):
+ request.update(self.endpoint_config["additional_args"])
+
+ return request
+
+ @override
+ async def stream(
+ self,
+ messages: Messages,
+ tool_specs: Optional[list[ToolSpec]] = None,
+ system_prompt: Optional[str] = None,
+ **kwargs: Any,
+ ) -> AsyncGenerator[StreamEvent, None]:
+ """Stream conversation with the SageMaker model.
+
+ Args:
+ messages: List of message objects to be processed by the model.
+ tool_specs: List of tool specifications to make available to the model.
+ system_prompt: System prompt to provide context to the model.
+ **kwargs: Additional keyword arguments for future extensibility.
+
+ Yields:
+ Formatted message chunks from the model.
+ """
+ logger.debug("formatting request")
+ request = self.format_request(messages, tool_specs, system_prompt)
+ logger.debug("formatted request=<%s>", request)
+
+ logger.debug("invoking model")
+ try:
+ if self.payload_config.get("stream", True):
+ response = self.client.invoke_endpoint_with_response_stream(**request)
+
+ # Message start
+ yield self.format_chunk({"chunk_type": "message_start"})
+
+ # Parse the content
+ finish_reason = ""
+ partial_content = ""
+ tool_calls: dict[int, list[Any]] = {}
+ has_text_content = False
+ text_content_started = False
+ reasoning_content_started = False
+
+ for event in response["Body"]:
+ chunk = event["PayloadPart"]["Bytes"].decode("utf-8")
+ partial_content += chunk[6:] if chunk.startswith("data: ") else chunk # TGI fix
+ logger.info("chunk=<%s>", partial_content)
+ try:
+ content = json.loads(partial_content)
+ partial_content = ""
+ choice = content["choices"][0]
+ logger.info("choice=<%s>", json.dumps(choice, indent=2))
+
+ # Handle text content
+ if choice["delta"].get("content", None):
+ if not text_content_started:
+ yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
+ text_content_started = True
+ has_text_content = True
+ yield self.format_chunk(
+ {
+ "chunk_type": "content_delta",
+ "data_type": "text",
+ "data": choice["delta"]["content"],
+ }
+ )
+
+ # Handle reasoning content
+ if choice["delta"].get("reasoning_content", None):
+ if not reasoning_content_started:
+ yield self.format_chunk(
+ {"chunk_type": "content_start", "data_type": "reasoning_content"}
+ )
+ reasoning_content_started = True
+ yield self.format_chunk(
+ {
+ "chunk_type": "content_delta",
+ "data_type": "reasoning_content",
+ "data": choice["delta"]["reasoning_content"],
+ }
+ )
+
+ # Handle tool calls
+ generated_tool_calls = choice["delta"].get("tool_calls", [])
+ if not isinstance(generated_tool_calls, list):
+ generated_tool_calls = [generated_tool_calls]
+ for tool_call in generated_tool_calls:
+ tool_calls.setdefault(tool_call["index"], []).append(tool_call)
+
+ if choice["finish_reason"] is not None:
+ finish_reason = choice["finish_reason"]
+ break
+
+ if choice.get("usage", None):
+ yield self.format_chunk(
+ {"chunk_type": "metadata", "data": UsageMetadata(**choice["usage"])}
+ )
+
+ except json.JSONDecodeError:
+ # Continue accumulating content until we have valid JSON
+ continue
+
+ # Close reasoning content if it was started
+ if reasoning_content_started:
+ yield self.format_chunk({"chunk_type": "content_stop", "data_type": "reasoning_content"})
+
+ # Close text content if it was started
+ if text_content_started:
+ yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
+
+ # Handle tool calling
+ logger.info("tool_calls=<%s>", json.dumps(tool_calls, indent=2))
+ for tool_deltas in tool_calls.values():
+ if not tool_deltas[0]["function"].get("name", None):
+ raise Exception("The model did not provide a tool name.")
+ yield self.format_chunk(
+ {"chunk_type": "content_start", "data_type": "tool", "data": ToolCall(**tool_deltas[0])}
+ )
+ for tool_delta in tool_deltas:
+ yield self.format_chunk(
+ {"chunk_type": "content_delta", "data_type": "tool", "data": ToolCall(**tool_delta)}
+ )
+ yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
+
+ # If no content was generated at all, ensure we have empty text content
+ if not has_text_content and not tool_calls:
+ yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
+ yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
+
+ # Message close
+ yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason})
+
+ else:
+ # Not all SageMaker AI models support streaming!
+ response = self.client.invoke_endpoint(**request) # type: ignore[assignment]
+ final_response_json = json.loads(response["Body"].read().decode("utf-8")) # type: ignore[attr-defined]
+ logger.info("response=<%s>", json.dumps(final_response_json, indent=2))
+
+ # Obtain the key elements from the response
+ message = final_response_json["choices"][0]["message"]
+ message_stop_reason = final_response_json["choices"][0]["finish_reason"]
+
+ # Message start
+ yield self.format_chunk({"chunk_type": "message_start"})
+
+ # Handle text
+ if message.get("content", ""):
+ yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
+ yield self.format_chunk(
+ {"chunk_type": "content_delta", "data_type": "text", "data": message["content"]}
+ )
+ yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
+
+ # Handle reasoning content
+ if message.get("reasoning_content", None):
+ yield self.format_chunk({"chunk_type": "content_start", "data_type": "reasoning_content"})
+ yield self.format_chunk(
+ {
+ "chunk_type": "content_delta",
+ "data_type": "reasoning_content",
+ "data": message["reasoning_content"],
+ }
+ )
+ yield self.format_chunk({"chunk_type": "content_stop", "data_type": "reasoning_content"})
+
+ # Handle the tool calling, if any
+ if message.get("tool_calls", None) or message_stop_reason == "tool_calls":
+ if not isinstance(message["tool_calls"], list):
+ message["tool_calls"] = [message["tool_calls"]]
+ for tool_call in message["tool_calls"]:
+ # if arguments of tool_call is not str, cast it
+ if not isinstance(tool_call["function"]["arguments"], str):
+ tool_call["function"]["arguments"] = json.dumps(tool_call["function"]["arguments"])
+ yield self.format_chunk(
+ {"chunk_type": "content_start", "data_type": "tool", "data": ToolCall(**tool_call)}
+ )
+ yield self.format_chunk(
+ {"chunk_type": "content_delta", "data_type": "tool", "data": ToolCall(**tool_call)}
+ )
+ yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
+ message_stop_reason = "tool_calls"
+
+ # Message close
+ yield self.format_chunk({"chunk_type": "message_stop", "data": message_stop_reason})
+ # Handle usage metadata
+ if final_response_json.get("usage", None):
+ yield self.format_chunk(
+ {"chunk_type": "metadata", "data": UsageMetadata(**final_response_json.get("usage", None))}
+ )
+ except (
+ self.client.exceptions.InternalFailure,
+ self.client.exceptions.ServiceUnavailable,
+ self.client.exceptions.ValidationError,
+ self.client.exceptions.ModelError,
+ self.client.exceptions.InternalDependencyException,
+ self.client.exceptions.ModelNotReadyException,
+ ) as e:
+ logger.error("SageMaker error: %s", str(e))
+ raise e
+
+ logger.debug("finished streaming response from model")
+
+ @override
+ @classmethod
+ def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]:
+ """Format a SageMaker compatible tool message.
+
+ Args:
+ tool_result: Tool result collected from a tool execution.
+
+ Returns:
+ SageMaker compatible tool message with content as a string.
+ """
+ # Convert content blocks to a simple string for SageMaker compatibility
+ content_parts = []
+ for content in tool_result["content"]:
+ if "json" in content:
+ content_parts.append(json.dumps(content["json"]))
+ elif "text" in content:
+ content_parts.append(content["text"])
+ else:
+ # Handle other content types by converting to string
+ content_parts.append(str(content))
+
+ content_string = " ".join(content_parts)
+
+ return {
+ "role": "tool",
+ "tool_call_id": tool_result["toolUseId"],
+ "content": content_string, # String instead of list
+ }
+
+ @override
+ @classmethod
+ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
+ """Format a content block.
+
+ Args:
+ content: Message content.
+
+ Returns:
+ Formatted content block.
+
+ Raises:
+ TypeError: If the content block type cannot be converted to a SageMaker-compatible format.
+ """
+ # if "text" in content and not isinstance(content["text"], str):
+ # return {"type": "text", "text": str(content["text"])}
+
+ if "reasoningContent" in content and content["reasoningContent"]:
+ return {
+ "signature": content["reasoningContent"].get("reasoningText", {}).get("signature", ""),
+ "thinking": content["reasoningContent"].get("reasoningText", {}).get("text", ""),
+ "type": "thinking",
+ }
+ elif not content.get("reasoningContent", None):
+ content.pop("reasoningContent", None)
+
+ if "video" in content:
+ return {
+ "type": "video_url",
+ "video_url": {
+ "detail": "auto",
+ "url": content["video"]["source"]["bytes"],
+ },
+ }
+
+ return super().format_request_message_content(content)
+
+ @override
+ async def structured_output(
+ self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
+ ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
+ """Get structured output from the model.
+
+ Args:
+ output_model: The output model to use for the agent.
+ prompt: The prompt messages to use for the agent.
+ system_prompt: System prompt to provide context to the model.
+ **kwargs: Additional keyword arguments for future extensibility.
+
+ Yields:
+ Model events with the last being the structured output.
+ """
+ # Format the request for structured output
+ request = self.format_request(prompt, system_prompt=system_prompt)
+
+ # Parse the payload to add response format
+ payload = json.loads(request["Body"])
+ payload["response_format"] = {
+ "type": "json_schema",
+ "json_schema": {"name": output_model.__name__, "schema": output_model.model_json_schema(), "strict": True},
+ }
+ request["Body"] = json.dumps(payload)
+
+ try:
+ # Use non-streaming mode for structured output
+ response = self.client.invoke_endpoint(**request)
+ final_response_json = json.loads(response["Body"].read().decode("utf-8"))
+
+ # Extract the structured content
+ message = final_response_json["choices"][0]["message"]
+
+ if message.get("content"):
+ try:
+ # Parse the JSON content and create the output model instance
+ content_data = json.loads(message["content"])
+ parsed_output = output_model(**content_data)
+ yield {"output": parsed_output}
+ except (json.JSONDecodeError, TypeError, ValueError) as e:
+ raise ValueError(f"Failed to parse structured output: {e}") from e
+ else:
+ raise ValueError("No content found in SageMaker response")
+
+ except (
+ self.client.exceptions.InternalFailure,
+ self.client.exceptions.ServiceUnavailable,
+ self.client.exceptions.ValidationError,
+ self.client.exceptions.ModelError,
+ self.client.exceptions.InternalDependencyException,
+ self.client.exceptions.ModelNotReadyException,
+ ) as e:
+ logger.error("SageMaker structured output error: %s", str(e))
+ raise ValueError(f"SageMaker structured output error: {str(e)}") from e
diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py
index 00eb4764f..5bf9cbfe9 100644
--- a/src/strands/multiagent/a2a/executor.py
+++ b/src/strands/multiagent/a2a/executor.py
@@ -4,7 +4,7 @@
to be used as an executor in the A2A protocol. It handles the execution of agent
requests and the conversion of Strands Agent streamed responses to A2A events.
-The A2A AgentExecutor ensures clients recieve responses for synchronous and
+The A2A AgentExecutor ensures clients receive responses for synchronous and
streamed requests to the A2AServer.
"""
@@ -61,7 +61,7 @@ async def execute(
task = new_task(context.message) # type: ignore
await event_queue.enqueue_event(task)
- updater = TaskUpdater(event_queue, task.id, task.contextId)
+ updater = TaskUpdater(event_queue, task.id, task.context_id)
try:
await self._execute_streaming(context, updater)
diff --git a/src/strands/multiagent/a2a/server.py b/src/strands/multiagent/a2a/server.py
index de891499d..fa7b6b887 100644
--- a/src/strands/multiagent/a2a/server.py
+++ b/src/strands/multiagent/a2a/server.py
@@ -6,6 +6,7 @@
import logging
from typing import Any, Literal
+from urllib.parse import urlparse
import uvicorn
from a2a.server.apps import A2AFastAPIApplication, A2AStarletteApplication
@@ -31,6 +32,8 @@ def __init__(
# AgentCard
host: str = "0.0.0.0",
port: int = 9000,
+ http_url: str | None = None,
+ serve_at_root: bool = False,
version: str = "0.0.1",
skills: list[AgentSkill] | None = None,
):
@@ -40,13 +43,34 @@ def __init__(
agent: The Strands Agent to wrap with A2A compatibility.
host: The hostname or IP address to bind the A2A server to. Defaults to "0.0.0.0".
port: The port to bind the A2A server to. Defaults to 9000.
+ http_url: The public HTTP URL where this agent will be accessible. If provided,
+ this overrides the generated URL from host/port and enables automatic
+ path-based mounting for load balancer scenarios.
+ Example: "http://my-alb.amazonaws.com/agent1"
+ serve_at_root: If True, forces the server to serve at root path regardless of
+ http_url path component. Use this when your load balancer strips path prefixes.
+ Defaults to False.
version: The version of the agent. Defaults to "0.0.1".
skills: The list of capabilities or functions the agent can perform.
"""
self.host = host
self.port = port
- self.http_url = f"http://{self.host}:{self.port}/"
self.version = version
+
+ if http_url:
+ # Parse the provided URL to extract components for mounting
+ self.public_base_url, self.mount_path = self._parse_public_url(http_url)
+ self.http_url = http_url.rstrip("/") + "/"
+
+ # Override mount path if serve_at_root is requested
+ if serve_at_root:
+ self.mount_path = ""
+ else:
+ # Fall back to constructing the URL from host and port
+ self.public_base_url = f"http://{host}:{port}"
+ self.http_url = f"{self.public_base_url}/"
+ self.mount_path = ""
+
self.strands_agent = agent
self.name = self.strands_agent.name
self.description = self.strands_agent.description
@@ -58,6 +82,25 @@ def __init__(
self._agent_skills = skills
logger.info("Strands' integration with A2A is experimental. Be aware of frequent breaking changes.")
+ def _parse_public_url(self, url: str) -> tuple[str, str]:
+ """Parse the public URL into base URL and mount path components.
+
+ Args:
+ url: The full public URL (e.g., "http://my-alb.amazonaws.com/agent1")
+
+ Returns:
+ tuple: (base_url, mount_path) where base_url is the scheme+netloc
+ and mount_path is the path component
+
+ Example:
+ _parse_public_url("http://my-alb.amazonaws.com/agent1")
+ Returns: ("http://my-alb.amazonaws.com", "/agent1")
+ """
+ parsed = urlparse(url.rstrip("/"))
+ base_url = f"{parsed.scheme}://{parsed.netloc}"
+ mount_path = parsed.path if parsed.path != "/" else ""
+ return base_url, mount_path
+
@property
def public_agent_card(self) -> AgentCard:
"""Get the public AgentCard for this agent.
@@ -119,24 +162,42 @@ def agent_skills(self, skills: list[AgentSkill]) -> None:
def to_starlette_app(self) -> Starlette:
"""Create a Starlette application for serving this agent via HTTP.
- This method creates a Starlette application that can be used to serve
- the agent via HTTP using the A2A protocol.
+ Automatically handles path-based mounting if a mount path was derived
+ from the http_url parameter.
Returns:
Starlette: A Starlette application configured to serve this agent.
"""
- return A2AStarletteApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build()
+ a2a_app = A2AStarletteApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build()
+
+ if self.mount_path:
+ # Create parent app and mount the A2A app at the specified path
+ parent_app = Starlette()
+ parent_app.mount(self.mount_path, a2a_app)
+ logger.info("Mounting A2A server at path: %s", self.mount_path)
+ return parent_app
+
+ return a2a_app
def to_fastapi_app(self) -> FastAPI:
"""Create a FastAPI application for serving this agent via HTTP.
- This method creates a FastAPI application that can be used to serve
- the agent via HTTP using the A2A protocol.
+ Automatically handles path-based mounting if a mount path was derived
+ from the http_url parameter.
Returns:
FastAPI: A FastAPI application configured to serve this agent.
"""
- return A2AFastAPIApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build()
+ a2a_app = A2AFastAPIApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build()
+
+ if self.mount_path:
+ # Create parent app and mount the A2A app at the specified path
+ parent_app = FastAPI()
+ parent_app.mount(self.mount_path, a2a_app)
+ logger.info("Mounting A2A server at path: %s", self.mount_path)
+ return parent_app
+
+ return a2a_app
def serve(
self,
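
A sketch of the load balancer scenario the new parameters target, assuming the A2AServer class exported from `strands.multiagent.a2a`; the ALB URL is illustrative:

```python
from strands import Agent
from strands.multiagent.a2a import A2AServer

agent = Agent(name="calculator-agent", description="Performs calculations")

# Clients reach the agent at http://my-alb.amazonaws.com/agent1, but the load
# balancer strips the /agent1 prefix, so the app itself serves at root.
server = A2AServer(
    agent=agent,
    http_url="http://my-alb.amazonaws.com/agent1",
    serve_at_root=True,
)
server.serve()
```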
diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py
index b32cb00e6..fec2f0761 100644
--- a/src/strands/session/file_session_manager.py
+++ b/src/strands/session/file_session_manager.py
@@ -23,6 +23,7 @@ class FileSessionManager(RepositorySessionManager, SessionRepository):
"""File-based session manager for local filesystem storage.
Creates the following filesystem structure for the session storage:
+ ```bash
/<storage_dir>/
└── session_<session_id>/
├── session.json # Session metadata
@@ -32,7 +33,7 @@ class FileSessionManager(RepositorySessionManager, SessionRepository):
└── messages/
├── message_<message_id>.json
└── message_<message_id>.json
-
+ ```
"""
def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs: Any):
diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py
index 487335ac9..75058b251 100644
--- a/src/strands/session/repository_session_manager.py
+++ b/src/strands/session/repository_session_manager.py
@@ -32,7 +32,7 @@ def __init__(self, session_id: str, session_repository: SessionRepository, **kwa
Args:
session_id: ID to use for the session. A new session with this id will be created if it does
- not exist in the reposiory yet
+ not exist in the repository yet
session_repository: Underlying session repository to use to store the sessions state.
**kwargs: Additional keyword arguments for future extensibility.
@@ -133,15 +133,13 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None:
agent.state = AgentState(session_agent.state)
# Restore the conversation manager to its previous state, and get the optional prepend messages
- prepend_messsages = agent.conversation_manager.restore_from_session(
- session_agent.conversation_manager_state
- )
+ prepend_messages = agent.conversation_manager.restore_from_session(session_agent.conversation_manager_state)
- if prepend_messsages is None:
- prepend_messsages = []
+ if prepend_messages is None:
+ prepend_messages = []
# List the messages currently in the session, using an offset of the messages previously removed
- # by the converstaion manager.
+ # by the conversation manager.
session_messages = self.session_repository.list_messages(
session_id=self.session_id,
agent_id=agent.agent_id,
@@ -150,5 +148,5 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None:
if len(session_messages) > 0:
self._latest_agent_message[agent.agent_id] = session_messages[-1]
- # Resore the agents messages array including the optional prepend messages
- agent.messages = prepend_messsages + [session_message.to_message() for session_message in session_messages]
+ # Restore the agents messages array including the optional prepend messages
+ agent.messages = prepend_messages + [session_message.to_message() for session_message in session_messages]
diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py
index 8f8423828..0cc0a68c1 100644
--- a/src/strands/session/s3_session_manager.py
+++ b/src/strands/session/s3_session_manager.py
@@ -24,6 +24,7 @@ class S3SessionManager(RepositorySessionManager, SessionRepository):
"""S3-based session manager for cloud storage.
Creates the following filesystem structure for the session storage:
+ ```bash
/<bucket>/<prefix>/
└── session_<session_id>/
├── session.json # Session metadata
@@ -33,7 +34,7 @@ class S3SessionManager(RepositorySessionManager, SessionRepository):
└── messages/
├── message_<message_id>.json
└── message_<message_id>.json
-
+ ```
"""
def __init__(
diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py
index eebffef29..802865189 100644
--- a/src/strands/telemetry/tracer.py
+++ b/src/strands/telemetry/tracer.py
@@ -273,7 +273,7 @@ def end_model_invoke_span(
self._end_span(span, attributes, error)
- def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None, **kwargs: Any) -> Optional[Span]:
+ def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None, **kwargs: Any) -> Span:
"""Start a new span for a tool call.
Args:
diff --git a/src/strands/tools/executor.py b/src/strands/tools/executor.py
index 1214fa608..d90f9a5aa 100644
--- a/src/strands/tools/executor.py
+++ b/src/strands/tools/executor.py
@@ -5,7 +5,7 @@
import time
from typing import Any, Optional, cast
-from opentelemetry import trace
+from opentelemetry import trace as trace_api
from ..telemetry.metrics import EventLoopMetrics, Trace
from ..telemetry.tracer import get_tracer
@@ -23,7 +23,7 @@ async def run_tools(
invalid_tool_use_ids: list[str],
tool_results: list[ToolResult],
cycle_trace: Trace,
- parent_span: Optional[trace.Span] = None,
+ parent_span: Optional[trace_api.Span] = None,
) -> ToolGenerator:
"""Execute tools concurrently.
@@ -53,24 +53,23 @@ async def work(
tool_name = tool_use["name"]
tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
tool_start_time = time.time()
+ with trace_api.use_span(tool_call_span):
+ try:
+ async for event in handler(tool_use):
+ worker_queue.put_nowait((worker_id, event))
+ await worker_event.wait()
+ worker_event.clear()
+
+ result = cast(ToolResult, event)
+ finally:
+ worker_queue.put_nowait((worker_id, stop_event))
+
+ tool_success = result.get("status") == "success"
+ tool_duration = time.time() - tool_start_time
+ message = Message(role="user", content=[{"toolResult": result}])
+ event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message)
+ cycle_trace.add_child(tool_trace)
- try:
- async for event in handler(tool_use):
- worker_queue.put_nowait((worker_id, event))
- await worker_event.wait()
- worker_event.clear()
-
- result = cast(ToolResult, event)
- finally:
- worker_queue.put_nowait((worker_id, stop_event))
-
- tool_success = result.get("status") == "success"
- tool_duration = time.time() - tool_start_time
- message = Message(role="user", content=[{"toolResult": result}])
- event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message)
- cycle_trace.add_child(tool_trace)
-
- if tool_call_span:
tracer.end_tool_call_span(tool_call_span, result)
return result
diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py
index 4cf4e1f85..8c21baa4a 100644
--- a/src/strands/tools/mcp/mcp_client.py
+++ b/src/strands/tools/mcp/mcp_client.py
@@ -20,15 +20,16 @@
from mcp import ClientSession, ListToolsResult
from mcp.types import CallToolResult as MCPCallToolResult
+from mcp.types import GetPromptResult, ListPromptsResult
from mcp.types import ImageContent as MCPImageContent
from mcp.types import TextContent as MCPTextContent
from ...types import PaginatedList
from ...types.exceptions import MCPClientInitializationError
from ...types.media import ImageFormat
-from ...types.tools import ToolResult, ToolResultContent, ToolResultStatus
+from ...types.tools import ToolResultContent, ToolResultStatus
from .mcp_agent_tool import MCPAgentTool
-from .mcp_types import MCPTransport
+from .mcp_types import MCPToolResult, MCPTransport
logger = logging.getLogger(__name__)
@@ -57,7 +58,8 @@ class MCPClient:
It handles the creation, initialization, and cleanup of MCP connections.
The connection runs in a background thread to avoid blocking the main application thread
- while maintaining communication with the MCP service.
+ while maintaining communication with the MCP service. When structured content is available
+ from MCP tools, it is returned in the structuredContent field of the MCPToolResult.
"""
def __init__(self, transport_callable: Callable[[], MCPTransport]):
@@ -164,17 +166,67 @@ async def _list_tools_async() -> ListToolsResult:
self._log_debug_with_thread("successfully adapted %d MCP tools", len(mcp_tools))
return PaginatedList[MCPAgentTool](mcp_tools, token=list_tools_response.nextCursor)
+ def list_prompts_sync(self, pagination_token: Optional[str] = None) -> ListPromptsResult:
+ """Synchronously retrieves the list of available prompts from the MCP server.
+
+ This method calls the asynchronous list_prompts method on the MCP session
+ and returns the raw ListPromptsResult with pagination support.
+
+ Args:
+ pagination_token: Optional token for pagination
+
+ Returns:
+ ListPromptsResult: The raw MCP response containing prompts and pagination info
+ """
+ self._log_debug_with_thread("listing MCP prompts synchronously")
+ if not self._is_session_active():
+ raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
+
+ async def _list_prompts_async() -> ListPromptsResult:
+ return await self._background_thread_session.list_prompts(cursor=pagination_token)
+
+ list_prompts_result: ListPromptsResult = self._invoke_on_background_thread(_list_prompts_async()).result()
+ self._log_debug_with_thread("received %d prompts from MCP server", len(list_prompts_result.prompts))
+ for prompt in list_prompts_result.prompts:
+ self._log_debug_with_thread(prompt.name)
+
+ return list_prompts_result
+
+ def get_prompt_sync(self, prompt_id: str, args: dict[str, Any]) -> GetPromptResult:
+ """Synchronously retrieves a prompt from the MCP server.
+
+ Args:
+ prompt_id: The ID of the prompt to retrieve
+ args: Arguments to pass to the prompt
+
+ Returns:
+ GetPromptResult: The prompt response from the MCP server
+ """
+ self._log_debug_with_thread("getting MCP prompt synchronously")
+ if not self._is_session_active():
+ raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
+
+ async def _get_prompt_async() -> GetPromptResult:
+ return await self._background_thread_session.get_prompt(prompt_id, arguments=args)
+
+ get_prompt_result: GetPromptResult = self._invoke_on_background_thread(_get_prompt_async()).result()
+ self._log_debug_with_thread("received prompt from MCP server")
+
+ return get_prompt_result
+
def call_tool_sync(
self,
tool_use_id: str,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
- ) -> ToolResult:
+ ) -> MCPToolResult:
"""Synchronously calls a tool on the MCP server.
This method calls the asynchronous call_tool method on the MCP session
- and converts the result to the ToolResult format.
+ and converts the result to the MCPToolResult format. If the MCP tool returns
+ structured content, it is included in the structuredContent field of the
+ returned MCPToolResult.
Args:
tool_use_id: Unique identifier for this tool use
@@ -183,7 +235,7 @@ def call_tool_sync(
read_timeout_seconds: Optional timeout for the tool call
Returns:
- ToolResult: The result of the tool call
+ MCPToolResult: The result of the tool call
"""
self._log_debug_with_thread("calling MCP tool '%s' synchronously with tool_use_id=%s", name, tool_use_id)
if not self._is_session_active():
@@ -205,11 +257,11 @@ async def call_tool_async(
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
- ) -> ToolResult:
+ ) -> MCPToolResult:
"""Asynchronously calls a tool on the MCP server.
This method calls the asynchronous call_tool method on the MCP session
- and converts the result to the ToolResult format.
+ and converts the result to the MCPToolResult format.
Args:
tool_use_id: Unique identifier for this tool use
@@ -218,7 +270,7 @@ async def call_tool_async(
read_timeout_seconds: Optional timeout for the tool call
Returns:
- ToolResult: The result of the tool call
+ MCPToolResult: The result of the tool call
"""
self._log_debug_with_thread("calling MCP tool '%s' asynchronously with tool_use_id=%s", name, tool_use_id)
if not self._is_session_active():
@@ -235,15 +287,27 @@ async def _call_tool_async() -> MCPCallToolResult:
logger.exception("tool execution failed")
return self._handle_tool_execution_error(tool_use_id, e)
- def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> ToolResult:
+ def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> MCPToolResult:
"""Create error ToolResult with consistent logging."""
- return ToolResult(
+ return MCPToolResult(
status="error",
toolUseId=tool_use_id,
content=[{"text": f"Tool execution failed: {str(exception)}"}],
)
- def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolResult) -> ToolResult:
+ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolResult) -> MCPToolResult:
+ """Maps MCP tool result to the agent's MCPToolResult format.
+
+ This method processes the content from the MCP tool call result and converts it to the format
+ expected by the framework.
+
+ Args:
+ tool_use_id: Unique identifier for this tool use
+ call_tool_result: The result from the MCP tool call
+
+ Returns:
+ MCPToolResult: The converted tool result
+ """
self._log_debug_with_thread("received tool result with %d content items", len(call_tool_result.content))
mapped_content = [
@@ -254,7 +318,15 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes
status: ToolResultStatus = "error" if call_tool_result.isError else "success"
self._log_debug_with_thread("tool execution completed with status: %s", status)
- return ToolResult(status=status, toolUseId=tool_use_id, content=mapped_content)
+ result = MCPToolResult(
+ status=status,
+ toolUseId=tool_use_id,
+ content=mapped_content,
+ )
+ if call_tool_result.structuredContent:
+ result["structuredContent"] = call_tool_result.structuredContent
+
+ return result
async def _async_background_thread(self) -> None:
"""Asynchronous method that runs in the background thread to manage the MCP connection.
diff --git a/src/strands/tools/mcp/mcp_types.py b/src/strands/tools/mcp/mcp_types.py
index 30defc585..5fafed5dc 100644
--- a/src/strands/tools/mcp/mcp_types.py
+++ b/src/strands/tools/mcp/mcp_types.py
@@ -1,11 +1,15 @@
"""Type definitions for MCP integration."""
from contextlib import AbstractAsyncContextManager
+from typing import Any, Dict
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.client.streamable_http import GetSessionIdCallback
from mcp.shared.memory import MessageStream
from mcp.shared.message import SessionMessage
+from typing_extensions import NotRequired
+
+from strands.types.tools import ToolResult
"""
MCPTransport defines the interface for MCP transport implementations. This abstracts
@@ -41,3 +45,19 @@ async def my_transport_implementation():
MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage], GetSessionIdCallback
]
MCPTransport = AbstractAsyncContextManager[MessageStream | _MessageStreamWithGetSessionIdCallback]
+
+
+class MCPToolResult(ToolResult):
+ """Result of an MCP tool execution.
+
+ Extends the base ToolResult with MCP-specific structured content support.
+ The structuredContent field contains optional JSON data returned by MCP tools
+ that provides structured results beyond the standard text/image/document content.
+
+ Attributes:
+ structuredContent: Optional JSON object containing structured data returned
+ by the MCP tool. This allows MCP tools to return complex data structures
+ that can be processed programmatically by agents or other tools.
+ """
+
+ structuredContent: NotRequired[Dict[str, Any]]
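
Reading the optional structured payload from a direct tool call; the server command, tool name, and arguments are placeholders:

```python
from mcp import StdioServerParameters, stdio_client
from strands.tools.mcp import MCPClient

client = MCPClient(lambda: stdio_client(StdioServerParameters(command="my-mcp-server")))

with client:
    result = client.call_tool_sync(
        tool_use_id="tool-use-1",
        name="get_weather",
        arguments={"city": "Seattle"},
    )

    print(result["status"], result["content"])
    # structuredContent is present only when the MCP tool returned structured output.
    if "structuredContent" in result:
        print(result["structuredContent"])
```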
diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py
index 9d835d28e..fd395ae77 100644
--- a/src/strands/tools/registry.py
+++ b/src/strands/tools/registry.py
@@ -11,7 +11,7 @@
from importlib import import_module, util
from os.path import expanduser
from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Optional
from typing_extensions import TypedDict, cast
@@ -54,7 +54,7 @@ def process_tools(self, tools: List[Any]) -> List[str]:
"""
tool_names = []
- for tool in tools:
+ def add_tool(tool: Any) -> None:
# Case 1: String file path
if isinstance(tool, str):
# Extract tool name from path
@@ -97,9 +97,16 @@ def process_tools(self, tools: List[Any]) -> List[str]:
elif isinstance(tool, AgentTool):
self.register_tool(tool)
tool_names.append(tool.tool_name)
+ # Case 6: Nested iterable (list, tuple, etc.) - add each sub-tool
+ elif isinstance(tool, Iterable) and not isinstance(tool, (str, bytes, bytearray)):
+ for t in tool:
+ add_tool(t)
else:
logger.warning("tool=<%s> | unrecognized tool specification", tool)
+ for a_tool in tools:
+ add_tool(a_tool)
+
return tool_names
def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None:
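
Because process_tools now recurses into nested iterables, grouped tool lists can be passed to an agent as-is. A small sketch using the `@tool` decorator:

```python
from strands import Agent, tool


@tool
def add(a: int, b: int) -> int:
    """Add two numbers."""
    return a + b


@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers."""
    return a * b


math_tools = [add, multiply]

# The nested list is flattened by the registry.
agent = Agent(tools=[math_tools])
assert sorted(agent.tool_names) == ["add", "multiply"]
```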
diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py
index 4bd3fd88e..90f2b8d7f 100644
--- a/src/strands/types/exceptions.py
+++ b/src/strands/types/exceptions.py
@@ -18,6 +18,23 @@ def __init__(self, original_exception: Exception, request_state: Any = None) ->
super().__init__(str(original_exception))
+class MaxTokensReachedException(Exception):
+ """Exception raised when the model reaches its maximum token generation limit.
+
+ This exception is raised when the model stops generating tokens because it has reached the maximum number of
+ tokens allowed for output generation. This can occur when the model's max_tokens parameter is set too low for
+ the complexity of the response, or when the model naturally reaches its configured output limit during generation.
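+
+    Example:
+        An illustrative sketch (the ``agent`` call shown is hypothetical)::
+
+            try:
+                agent("Write a very long report...")
+            except MaxTokensReachedException:
+                # Retry with a larger max_tokens value or a shorter request.
+                ...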
+ """
+
+ def __init__(self, message: str):
+ """Initialize the exception with an error message and the incomplete message object.
+
+ Args:
+ message: The error message describing the token limit issue
+ """
+ super().__init__(message)
+
+
class ContextWindowOverflowException(Exception):
"""Exception raised when the context window is exceeded.
diff --git a/src/strands/types/session.py b/src/strands/types/session.py
index 259ab1171..e51816f74 100644
--- a/src/strands/types/session.py
+++ b/src/strands/types/session.py
@@ -125,7 +125,7 @@ def from_agent(cls, agent: "Agent") -> "SessionAgent":
@classmethod
def from_dict(cls, env: dict[str, Any]) -> "SessionAgent":
- """Initialize a SessionAgent from a dictionary, ignoring keys that are not calss parameters."""
+ """Initialize a SessionAgent from a dictionary, ignoring keys that are not class parameters."""
return cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters})
def to_dict(self) -> dict[str, Any]:
@@ -144,7 +144,7 @@ class Session:
@classmethod
def from_dict(cls, env: dict[str, Any]) -> "Session":
- """Initialize a Session from a dictionary, ignoring keys that are not calss parameters."""
+ """Initialize a Session from a dictionary, ignoring keys that are not class parameters."""
return cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters})
def to_dict(self) -> dict[str, Any]:
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index d6471a09a..4e310dace 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -231,6 +231,25 @@ def test_agent__init__with_string_model_id():
assert agent.model.config["model_id"] == "nonsense"
+def test_agent__init__nested_tools_flattening(tool_decorated, tool_module, tool_imported, tool_registry):
+ _ = tool_registry
+ # Nested structure: [tool_decorated, [tool_module, [tool_imported]]]
+ agent = Agent(tools=[tool_decorated, [tool_module, [tool_imported]]])
+ tru_tool_names = sorted(agent.tool_names)
+ exp_tool_names = ["tool_decorated", "tool_imported", "tool_module"]
+ assert tru_tool_names == exp_tool_names
+
+
+def test_agent__init__deeply_nested_tools(tool_decorated, tool_module, tool_imported, tool_registry):
+ _ = tool_registry
+ # Deeply nested structure
+ nested_tools = [[[[tool_decorated]], [[tool_module]], tool_imported]]
+ agent = Agent(tools=nested_tools)
+ tru_tool_names = sorted(agent.tool_names)
+ exp_tool_names = ["tool_decorated", "tool_imported", "tool_module"]
+ assert tru_tool_names == exp_tool_names
+
+
def test_agent__call__(
mock_model,
system_prompt,
diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py
index 1ac2f8258..191ab51ba 100644
--- a/tests/strands/event_loop/test_event_loop.py
+++ b/tests/strands/event_loop/test_event_loop.py
@@ -19,7 +19,12 @@
)
from strands.telemetry.metrics import EventLoopMetrics
from strands.tools.registry import ToolRegistry
-from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException
+from strands.types.exceptions import (
+ ContextWindowOverflowException,
+ EventLoopException,
+ MaxTokensReachedException,
+ ModelThrottledException,
+)
from tests.fixtures.mock_hook_provider import MockHookProvider
@@ -300,8 +305,10 @@ async def test_event_loop_cycle_text_response_error(
await alist(stream)
+@patch("strands.event_loop.event_loop.recover_message_on_max_tokens_reached")
@pytest.mark.asyncio
async def test_event_loop_cycle_tool_result(
+ mock_recover_message,
agent,
model,
system_prompt,
@@ -334,6 +341,9 @@ async def test_event_loop_cycle_tool_result(
assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state
+ # Verify that recover_message_on_max_tokens_reached was NOT called for tool_use stop reason
+ mock_recover_message.assert_not_called()
+
model.stream.assert_called_with(
[
{"role": "user", "content": [{"text": "Hello"}]},
@@ -556,6 +566,53 @@ async def test_event_loop_tracing_with_model_error(
mock_tracer.end_span_with_error.assert_called_once_with(model_span, "Input too long", model.stream.side_effect)
+@pytest.mark.asyncio
+async def test_event_loop_cycle_max_tokens_exception(
+ agent,
+ model,
+ agenerator,
+ alist,
+):
+ """Test that max_tokens stop reason calls _recover_message_on_max_tokens_reached then MaxTokensReachedException."""
+
+ model.stream.side_effect = [
+ agenerator(
+ [
+ {
+ "contentBlockStart": {
+ "start": {
+ "toolUse": {
+ "toolUseId": "t1",
+ "name": "asdf",
+ "input": {}, # empty
+ },
+ },
+ },
+ },
+ {"contentBlockStop": {}},
+ {"messageStop": {"stopReason": "max_tokens"}},
+ ]
+ ),
+ ]
+
+ # Call event_loop_cycle, expecting it to raise MaxTokensReachedException
+ expected_message = (
+ "Agent has reached an unrecoverable state due to max_tokens limit. "
+ "For more information see: "
+ "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception"
+ )
+ with pytest.raises(MaxTokensReachedException, match=expected_message):
+ stream = strands.event_loop.event_loop.event_loop_cycle(
+ agent=agent,
+ invocation_state={},
+ )
+ await alist(stream)
+
+    # Verify the recovered message was appended to the agent's message history
+ assert len(agent.messages) == 2
+ assert "tool use was incomplete due" in agent.messages[1]["content"][0]["text"]
+
+
@patch("strands.event_loop.event_loop.get_tracer")
@pytest.mark.asyncio
async def test_event_loop_tracing_with_tool_execution(
diff --git a/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py
new file mode 100644
index 000000000..402e90966
--- /dev/null
+++ b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py
@@ -0,0 +1,269 @@
+"""Tests for token limit recovery utility."""
+
+from strands.event_loop._recover_message_on_max_tokens_reached import (
+ recover_message_on_max_tokens_reached,
+)
+from strands.types.content import Message
+
+
+def test_recover_message_on_max_tokens_reached_with_incomplete_tool_use():
+ """Test recovery when incomplete tool use is present in the message."""
+ incomplete_message: Message = {
+ "role": "assistant",
+ "content": [
+ {"text": "I'll help you with that."},
+ {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Missing toolUseId
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(incomplete_message)
+
+ # Check the corrected message content
+ assert result["role"] == "assistant"
+ assert len(result["content"]) == 2
+
+ # First content block should be preserved
+ assert result["content"][0] == {"text": "I'll help you with that."}
+
+ # Second content block should be replaced with error message
+ assert "text" in result["content"][1]
+ assert "calculator" in result["content"][1]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][1]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_with_missing_tool_name():
+ """Test recovery when tool use has no name."""
+ incomplete_message: Message = {
+ "role": "assistant",
+ "content": [
+ {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Missing name
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(incomplete_message)
+
+ # Check the corrected message content
+ assert result["role"] == "assistant"
+ assert len(result["content"]) == 1
+
+    # Content should be replaced with an error message
+ assert "text" in result["content"][0]
+ assert "" in result["content"][0]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][0]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_with_missing_input():
+ """Test recovery when tool use has no input."""
+ incomplete_message: Message = {
+ "role": "assistant",
+ "content": [
+ {"toolUse": {"name": "calculator", "toolUseId": "123"}}, # Missing input
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(incomplete_message)
+
+ # Check the corrected message content
+ assert result["role"] == "assistant"
+ assert len(result["content"]) == 1
+
+ # Content should be replaced with error message
+ assert "text" in result["content"][0]
+ assert "calculator" in result["content"][0]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][0]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_with_missing_tool_use_id():
+ """Test recovery when tool use has no toolUseId."""
+ incomplete_message: Message = {
+ "role": "assistant",
+ "content": [
+ {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}}}, # Missing toolUseId
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(incomplete_message)
+
+ # Check the corrected message content
+ assert result["role"] == "assistant"
+ assert len(result["content"]) == 1
+
+ # Content should be replaced with error message
+ assert "text" in result["content"][0]
+ assert "calculator" in result["content"][0]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][0]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_with_valid_tool_use():
+ """Test that even valid tool uses are replaced with error messages."""
+ complete_message: Message = {
+ "role": "assistant",
+ "content": [
+ {"text": "I'll help you with that."},
+ {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"}}, # Valid
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(complete_message)
+
+ # Should replace even valid tool uses with error messages
+ assert result["role"] == "assistant"
+ assert len(result["content"]) == 2
+ assert result["content"][0] == {"text": "I'll help you with that."}
+
+ # Valid tool use should also be replaced with error message
+ assert "text" in result["content"][1]
+ assert "calculator" in result["content"][1]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][1]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_with_empty_content():
+ """Test handling of message with empty content."""
+ empty_message: Message = {"role": "assistant", "content": []}
+
+ result = recover_message_on_max_tokens_reached(empty_message)
+
+ # Should return message with empty content preserved
+ assert result["role"] == "assistant"
+ assert result["content"] == []
+
+
+def test_recover_message_on_max_tokens_reached_with_none_content():
+ """Test handling of message with None content."""
+ none_content_message: Message = {"role": "assistant", "content": None}
+
+ result = recover_message_on_max_tokens_reached(none_content_message)
+
+ # Should return message with empty content
+ assert result["role"] == "assistant"
+ assert result["content"] == []
+
+
+def test_recover_message_on_max_tokens_reached_with_mixed_content():
+ """Test recovery with mix of valid content and incomplete tool use."""
+ incomplete_message: Message = {
+ "role": "assistant",
+ "content": [
+ {"text": "Let me calculate this for you."},
+ {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Incomplete
+ {"text": "And then I'll explain the result."},
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(incomplete_message)
+
+ # Check the corrected message content
+ assert result["role"] == "assistant"
+ assert len(result["content"]) == 3
+
+ # First and third content blocks should be preserved
+ assert result["content"][0] == {"text": "Let me calculate this for you."}
+ assert result["content"][2] == {"text": "And then I'll explain the result."}
+
+ # Second content block should be replaced with error message
+ assert "text" in result["content"][1]
+ assert "calculator" in result["content"][1]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][1]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_preserves_non_tool_content():
+ """Test that non-tool content is preserved as-is."""
+ incomplete_message: Message = {
+ "role": "assistant",
+ "content": [
+ {"text": "Here's some text."},
+ {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}},
+ {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Incomplete
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(incomplete_message)
+
+ # Check the corrected message content
+ assert result["role"] == "assistant"
+ assert len(result["content"]) == 3
+
+ # First two content blocks should be preserved exactly
+ assert result["content"][0] == {"text": "Here's some text."}
+ assert result["content"][1] == {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}}
+
+ # Third content block should be replaced with error message
+ assert "text" in result["content"][2]
+ assert "" in result["content"][2]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][2]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_multiple_incomplete_tools():
+ """Test recovery with multiple incomplete tool uses."""
+ incomplete_message: Message = {
+ "role": "assistant",
+ "content": [
+ {"toolUse": {"name": "calculator", "input": {}}}, # Missing toolUseId
+ {"text": "Some text in between."},
+ {"toolUse": {"name": "", "input": {}, "toolUseId": "456"}}, # Missing name
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(incomplete_message)
+
+ # Check the corrected message content
+ assert result["role"] == "assistant"
+ assert len(result["content"]) == 3
+
+ # First tool use should be replaced
+ assert "text" in result["content"][0]
+ assert "calculator" in result["content"][0]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][0]["text"]
+
+ # Text content should be preserved
+ assert result["content"][1] == {"text": "Some text in between."}
+
+    # Second tool use should be replaced with an error message
+ assert "text" in result["content"][2]
+ assert "" in result["content"][2]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][2]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_preserves_user_role():
+ """Test that the function preserves the original message role."""
+ incomplete_message: Message = {
+ "role": "user",
+ "content": [
+ {"toolUse": {"name": "calculator", "input": {}}}, # Missing toolUseId
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(incomplete_message)
+
+ # Should preserve the original role
+ assert result["role"] == "user"
+ assert len(result["content"]) == 1
+ assert "text" in result["content"][0]
+ assert "calculator" in result["content"][0]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_with_content_without_tool_use():
+ """Test handling of content blocks that don't have toolUse key."""
+ message: Message = {
+ "role": "assistant",
+ "content": [
+ {"text": "Regular text content."},
+ {"someOtherKey": "someValue"}, # Content without toolUse
+ {"toolUse": {"name": "calculator"}}, # Incomplete tool use
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(message)
+
+ # Check the corrected message content
+ assert result["role"] == "assistant"
+ assert len(result["content"]) == 3
+
+ # First two content blocks should be preserved
+ assert result["content"][0] == {"text": "Regular text content."}
+ assert result["content"][1] == {"someOtherKey": "someValue"}
+
+ # Third content block should be replaced with error message
+ assert "text" in result["content"][2]
+ assert "calculator" in result["content"][2]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][2]["text"]
diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py
index 47e028cb9..0a2846adf 100644
--- a/tests/strands/models/test_bedrock.py
+++ b/tests/strands/models/test_bedrock.py
@@ -13,6 +13,7 @@
from strands.models import BedrockModel
from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION
from strands.types.exceptions import ModelThrottledException
+from strands.types.tools import ToolSpec
@pytest.fixture
@@ -51,7 +52,7 @@ def model(bedrock_client, model_id):
@pytest.fixture
def messages():
- return [{"role": "user", "content": {"text": "test"}}]
+ return [{"role": "user", "content": [{"text": "test"}]}]
@pytest.fixture
@@ -90,8 +91,12 @@ def inference_config():
@pytest.fixture
-def tool_spec():
- return {"t1": 1}
+def tool_spec() -> ToolSpec:
+ return {
+ "description": "description",
+ "name": "name",
+ "inputSchema": {"key": "val"},
+ }
@pytest.fixture
@@ -750,7 +755,7 @@ async def test_stream_output_no_guardrail_redact(
@pytest.mark.asyncio
-async def test_stream_with_streaming_false(bedrock_client, alist):
+async def test_stream_with_streaming_false(bedrock_client, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
@@ -759,8 +764,7 @@ async def test_stream_with_streaming_false(bedrock_client, alist):
# Create model and call stream
model = BedrockModel(model_id="test-model", streaming=False)
- request = {"modelId": "test-model"}
- response = model.stream(request)
+ response = model.stream(messages)
tru_events = await alist(response)
exp_events = [
@@ -776,7 +780,7 @@ async def test_stream_with_streaming_false(bedrock_client, alist):
@pytest.mark.asyncio
-async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist):
+async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {
@@ -790,8 +794,7 @@ async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist):
# Create model and call stream
model = BedrockModel(model_id="test-model", streaming=False)
- request = {"modelId": "test-model"}
- response = model.stream(request)
+ response = model.stream(messages)
tru_events = await alist(response)
exp_events = [
@@ -808,7 +811,7 @@ async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist):
@pytest.mark.asyncio
-async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist):
+async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {
@@ -828,8 +831,7 @@ async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist):
# Create model and call stream
model = BedrockModel(model_id="test-model", streaming=False)
- request = {"modelId": "test-model"}
- response = model.stream(request)
+ response = model.stream(messages)
tru_events = await alist(response)
exp_events = [
@@ -847,7 +849,7 @@ async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist):
@pytest.mark.asyncio
-async def test_stream_and_reasoning_no_signature(bedrock_client, alist):
+async def test_stream_and_reasoning_no_signature(bedrock_client, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {
@@ -867,8 +869,7 @@ async def test_stream_and_reasoning_no_signature(bedrock_client, alist):
# Create model and call stream
model = BedrockModel(model_id="test-model", streaming=False)
- request = {"modelId": "test-model"}
- response = model.stream(request)
+ response = model.stream(messages)
tru_events = await alist(response)
exp_events = [
@@ -884,7 +885,7 @@ async def test_stream_and_reasoning_no_signature(bedrock_client, alist):
@pytest.mark.asyncio
-async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client, alist):
+async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
@@ -895,8 +896,7 @@ async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client
# Create model and call stream
model = BedrockModel(model_id="test-model", streaming=False)
- request = {"modelId": "test-model"}
- response = model.stream(request)
+ response = model.stream(messages)
tru_events = await alist(response)
exp_events = [
@@ -919,7 +919,7 @@ async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client
@pytest.mark.asyncio
-async def test_stream_input_guardrails(bedrock_client, alist):
+async def test_stream_input_guardrails(bedrock_client, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
@@ -937,8 +937,7 @@ async def test_stream_input_guardrails(bedrock_client, alist):
# Create model and call stream
model = BedrockModel(model_id="test-model", streaming=False)
- request = {"modelId": "test-model"}
- response = model.stream(request)
+ response = model.stream(messages)
tru_events = await alist(response)
exp_events = [
@@ -970,7 +969,7 @@ async def test_stream_input_guardrails(bedrock_client, alist):
@pytest.mark.asyncio
-async def test_stream_output_guardrails(bedrock_client, alist):
+async def test_stream_output_guardrails(bedrock_client, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
@@ -989,8 +988,7 @@ async def test_stream_output_guardrails(bedrock_client, alist):
}
model = BedrockModel(model_id="test-model", streaming=False)
- request = {"modelId": "test-model"}
- response = model.stream(request)
+ response = model.stream(messages)
tru_events = await alist(response)
exp_events = [
@@ -1024,7 +1022,7 @@ async def test_stream_output_guardrails(bedrock_client, alist):
@pytest.mark.asyncio
-async def test_stream_output_guardrails_redacts_output(bedrock_client, alist):
+async def test_stream_output_guardrails_redacts_output(bedrock_client, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
@@ -1043,8 +1041,7 @@ async def test_stream_output_guardrails_redacts_output(bedrock_client, alist):
}
model = BedrockModel(model_id="test-model", streaming=False)
- request = {"modelId": "test-model"}
- response = model.stream(request)
+ response = model.stream(messages)
tru_events = await alist(response)
exp_events = [
@@ -1101,7 +1098,7 @@ async def test_structured_output(bedrock_client, model, test_output_model_cls, a
@pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)")
@pytest.mark.asyncio
-async def test_add_note_on_client_error(bedrock_client, model, alist):
+async def test_add_note_on_client_error(bedrock_client, model, alist, messages):
"""Test that add_note is called on ClientError with region and model ID information."""
# Mock the client error response
error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}}
@@ -1109,13 +1106,13 @@ async def test_add_note_on_client_error(bedrock_client, model, alist):
# Call the stream method which should catch and add notes to the exception
with pytest.raises(ClientError) as err:
- await alist(model.stream({"modelId": "test-model"}))
+ await alist(model.stream(messages))
assert err.value.__notes__ == ["└ Bedrock region: us-west-2", "└ Model id: m1"]
@pytest.mark.asyncio
-async def test_no_add_note_when_not_available(bedrock_client, model, alist):
+async def test_no_add_note_when_not_available(bedrock_client, model, alist, messages):
"""Verify that on any python version (even < 3.11 where add_note is not available, we get the right exception)."""
# Mock the client error response
error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}}
@@ -1123,12 +1120,12 @@ async def test_no_add_note_when_not_available(bedrock_client, model, alist):
# Call the stream method which should catch and add notes to the exception
with pytest.raises(ClientError):
- await alist(model.stream({"modelId": "test-model"}))
+ await alist(model.stream(messages))
@pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)")
@pytest.mark.asyncio
-async def test_add_note_on_access_denied_exception(bedrock_client, model, alist):
+async def test_add_note_on_access_denied_exception(bedrock_client, model, alist, messages):
"""Test that add_note adds documentation link for AccessDeniedException."""
# Mock the client error response for access denied
error_response = {
@@ -1142,7 +1139,7 @@ async def test_add_note_on_access_denied_exception(bedrock_client, model, alist)
# Call the stream method which should catch and add notes to the exception
with pytest.raises(ClientError) as err:
- await alist(model.stream({"modelId": "test-model"}))
+ await alist(model.stream(messages))
assert err.value.__notes__ == [
"└ Bedrock region: us-west-2",
@@ -1154,7 +1151,7 @@ async def test_add_note_on_access_denied_exception(bedrock_client, model, alist)
@pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)")
@pytest.mark.asyncio
-async def test_add_note_on_validation_exception_throughput(bedrock_client, model, alist):
+async def test_add_note_on_validation_exception_throughput(bedrock_client, model, alist, messages):
"""Test that add_note adds documentation link for ValidationException about on-demand throughput."""
# Mock the client error response for validation exception
error_response = {
@@ -1170,7 +1167,7 @@ async def test_add_note_on_validation_exception_throughput(bedrock_client, model
# Call the stream method which should catch and add notes to the exception
with pytest.raises(ClientError) as err:
- await alist(model.stream({"modelId": "test-model"}))
+ await alist(model.stream(messages))
assert err.value.__notes__ == [
"└ Bedrock region: us-west-2",
@@ -1202,3 +1199,32 @@ async def test_stream_logging(bedrock_client, model, messages, caplog, alist):
assert "invoking model" in log_text
assert "got response from model" in log_text
assert "finished streaming response from model" in log_text
+
+
+def test_format_request_cleans_tool_result_content_blocks(model, model_id):
+ """Test that format_request cleans toolResult blocks by removing extra fields."""
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "toolResult": {
+ "content": [{"text": "Tool output"}],
+ "toolUseId": "tool123",
+ "status": "success",
+ "extraField": "should be removed",
+ "mcpMetadata": {"server": "test"},
+ }
+ },
+ ],
+ }
+ ]
+
+ formatted_request = model.format_request(messages)
+
+ # Verify toolResult only contains allowed fields in the formatted request
+ tool_result = formatted_request["messages"][0]["content"][0]["toolResult"]
+ expected = {"content": [{"text": "Tool output"}], "toolUseId": "tool123", "status": "success"}
+ assert tool_result == expected
+ assert "extraField" not in tool_result
+ assert "mcpMetadata" not in tool_result
diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py
new file mode 100644
index 000000000..ba395b2d6
--- /dev/null
+++ b/tests/strands/models/test_sagemaker.py
@@ -0,0 +1,574 @@
+"""Tests for the Amazon SageMaker model provider."""
+
+import json
+import unittest.mock
+from typing import Any, Dict, List
+
+import boto3
+import pytest
+from botocore.config import Config as BotocoreConfig
+
+from strands.models.sagemaker import (
+ FunctionCall,
+ SageMakerAIModel,
+ ToolCall,
+ UsageMetadata,
+)
+from strands.types.content import Messages
+from strands.types.tools import ToolSpec
+
+
+@pytest.fixture
+def boto_session():
+ """Mock boto3 session."""
+ with unittest.mock.patch.object(boto3, "Session") as mock_session:
+ yield mock_session.return_value
+
+
+@pytest.fixture
+def sagemaker_client(boto_session):
+ """Mock SageMaker runtime client."""
+ return boto_session.client.return_value
+
+
+@pytest.fixture
+def endpoint_config() -> Dict[str, Any]:
+ """Default endpoint configuration for tests."""
+ return {
+ "endpoint_name": "test-endpoint",
+ "inference_component_name": "test-component",
+ "region_name": "us-east-1",
+ }
+
+
+@pytest.fixture
+def payload_config() -> Dict[str, Any]:
+ """Default payload configuration for tests."""
+ return {
+ "max_tokens": 1024,
+ "temperature": 0.7,
+ "stream": True,
+ }
+
+
+@pytest.fixture
+def model(boto_session, endpoint_config, payload_config):
+ """SageMaker model instance with mocked boto session."""
+ return SageMakerAIModel(endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session)
+
+
+@pytest.fixture
+def messages() -> Messages:
+ """Sample messages for testing."""
+ return [{"role": "user", "content": [{"text": "What is the capital of France?"}]}]
+
+
+@pytest.fixture
+def tool_specs() -> List[ToolSpec]:
+ """Sample tool specifications for testing."""
+ return [
+ {
+ "name": "get_weather",
+ "description": "Get the weather for a location",
+ "inputSchema": {
+ "json": {
+ "type": "object",
+ "properties": {"location": {"type": "string"}},
+ "required": ["location"],
+ }
+ },
+ }
+ ]
+
+
+@pytest.fixture
+def system_prompt() -> str:
+ """Sample system prompt for testing."""
+ return "You are a helpful assistant."
+
+
+class TestSageMakerAIModel:
+ """Test suite for SageMakerAIModel."""
+
+ def test_init_default(self, boto_session):
+ """Test initialization with default parameters."""
+ endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1"}
+ payload_config = {"max_tokens": 1024}
+ model = SageMakerAIModel(
+ endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session
+ )
+
+ assert model.endpoint_config["endpoint_name"] == "test-endpoint"
+ assert model.payload_config.get("stream", True) is True
+
+ boto_session.client.assert_called_once_with(
+ service_name="sagemaker-runtime",
+ config=unittest.mock.ANY,
+ )
+
+ def test_init_with_all_params(self, boto_session):
+ """Test initialization with all parameters."""
+ endpoint_config = {
+ "endpoint_name": "test-endpoint",
+ "inference_component_name": "test-component",
+ "region_name": "us-west-2",
+ }
+ payload_config = {
+ "stream": False,
+ "max_tokens": 1024,
+ "temperature": 0.7,
+ }
+ client_config = BotocoreConfig(user_agent_extra="test-agent")
+
+ model = SageMakerAIModel(
+ endpoint_config=endpoint_config,
+ payload_config=payload_config,
+ boto_session=boto_session,
+ boto_client_config=client_config,
+ )
+
+ assert model.endpoint_config["endpoint_name"] == "test-endpoint"
+ assert model.endpoint_config["inference_component_name"] == "test-component"
+ assert model.payload_config["stream"] is False
+ assert model.payload_config["max_tokens"] == 1024
+ assert model.payload_config["temperature"] == 0.7
+
+ boto_session.client.assert_called_once_with(
+ service_name="sagemaker-runtime",
+ config=unittest.mock.ANY,
+ )
+
+ def test_init_with_client_config(self, boto_session):
+ """Test initialization with client configuration."""
+ endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1"}
+ payload_config = {"max_tokens": 1024}
+ client_config = BotocoreConfig(user_agent_extra="test-agent")
+
+ SageMakerAIModel(
+ endpoint_config=endpoint_config,
+ payload_config=payload_config,
+ boto_session=boto_session,
+ boto_client_config=client_config,
+ )
+
+ # Verify client was created with a config that includes our user agent
+ boto_session.client.assert_called_once_with(
+ service_name="sagemaker-runtime",
+ config=unittest.mock.ANY,
+ )
+
+ # Get the actual config passed to client
+ actual_config = boto_session.client.call_args[1]["config"]
+ assert "strands-agents" in actual_config.user_agent_extra
+ assert "test-agent" in actual_config.user_agent_extra
+
+ def test_update_config(self, model):
+ """Test updating model configuration."""
+ new_config = {"target_model": "new-model", "target_variant": "new-variant"}
+ model.update_config(**new_config)
+
+ assert model.endpoint_config["target_model"] == "new-model"
+ assert model.endpoint_config["target_variant"] == "new-variant"
+ # Original values should be preserved
+ assert model.endpoint_config["endpoint_name"] == "test-endpoint"
+ assert model.endpoint_config["inference_component_name"] == "test-component"
+
+ def test_get_config(self, model, endpoint_config):
+ """Test getting model configuration."""
+ config = model.get_config()
+ assert config == model.endpoint_config
+ assert isinstance(config, dict)
+
+ # def test_format_request_messages_with_system_prompt(self, model):
+ # """Test formatting request messages with system prompt."""
+ # messages = [{"role": "user", "content": "Hello"}]
+ # system_prompt = "You are a helpful assistant."
+
+ # formatted_messages = model.format_request_messages(messages, system_prompt)
+
+ # assert len(formatted_messages) == 2
+ # assert formatted_messages[0]["role"] == "system"
+ # assert formatted_messages[0]["content"] == system_prompt
+ # assert formatted_messages[1]["role"] == "user"
+ # assert formatted_messages[1]["content"] == "Hello"
+
+ # def test_format_request_messages_with_tool_calls(self, model):
+ # """Test formatting request messages with tool calls."""
+ # messages = [
+ # {"role": "user", "content": "Hello"},
+ # {
+ # "role": "assistant",
+ # "content": None,
+ # "tool_calls": [{"id": "123", "type": "function", "function": {"name": "test", "arguments": "{}"}}],
+ # },
+ # ]
+
+ # formatted_messages = model.format_request_messages(messages, None)
+
+ # assert len(formatted_messages) == 2
+ # assert formatted_messages[0]["role"] == "user"
+ # assert formatted_messages[1]["role"] == "assistant"
+ # assert "content" not in formatted_messages[1]
+ # assert "tool_calls" in formatted_messages[1]
+
+ # def test_format_request(self, model, messages, tool_specs, system_prompt):
+ # """Test formatting a request with all parameters."""
+ # request = model.format_request(messages, tool_specs, system_prompt)
+
+ # assert request["EndpointName"] == "test-endpoint"
+ # assert request["InferenceComponentName"] == "test-component"
+ # assert request["ContentType"] == "application/json"
+ # assert request["Accept"] == "application/json"
+
+ # payload = json.loads(request["Body"])
+ # assert "messages" in payload
+ # assert len(payload["messages"]) > 0
+ # assert "tools" in payload
+ # assert len(payload["tools"]) == 1
+ # assert payload["tools"][0]["type"] == "function"
+ # assert payload["tools"][0]["function"]["name"] == "get_weather"
+ # assert payload["max_tokens"] == 1024
+ # assert payload["temperature"] == 0.7
+ # assert payload["stream"] is True
+
+ # def test_format_request_without_tools(self, model, messages, system_prompt):
+ # """Test formatting a request without tools."""
+ # request = model.format_request(messages, None, system_prompt)
+
+ # payload = json.loads(request["Body"])
+ # assert "tools" in payload
+ # assert payload["tools"] == []
+
+ @pytest.mark.asyncio
+ async def test_stream_with_streaming_enabled(self, sagemaker_client, model, messages):
+ """Test streaming response with streaming enabled."""
+ # Mock the response from SageMaker
+ mock_response = {
+ "Body": [
+ {
+ "PayloadPart": {
+ "Bytes": json.dumps(
+ {
+ "choices": [
+ {
+ "delta": {"content": "Paris is the capital of France."},
+ "finish_reason": None,
+ }
+ ]
+ }
+ ).encode("utf-8")
+ }
+ },
+ {
+ "PayloadPart": {
+ "Bytes": json.dumps(
+ {
+ "choices": [
+ {
+ "delta": {"content": " It is known for the Eiffel Tower."},
+ "finish_reason": "stop",
+ }
+ ]
+ }
+ ).encode("utf-8")
+ }
+ },
+ ]
+ }
+ sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response
+
+ response = [chunk async for chunk in model.stream(messages)]
+
+ assert len(response) >= 5
+ assert response[0] == {"messageStart": {"role": "assistant"}}
+
+ # Find content events
+ content_start = next((e for e in response if "contentBlockStart" in e), None)
+ content_delta = next((e for e in response if "contentBlockDelta" in e), None)
+ content_stop = next((e for e in response if "contentBlockStop" in e), None)
+ message_stop = next((e for e in response if "messageStop" in e), None)
+
+ assert content_start is not None
+ assert content_delta is not None
+ assert content_stop is not None
+ assert message_stop is not None
+ assert message_stop["messageStop"]["stopReason"] == "end_turn"
+
+ sagemaker_client.invoke_endpoint_with_response_stream.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_stream_with_tool_calls(self, sagemaker_client, model, messages):
+ """Test streaming response with tool calls."""
+ # Mock the response from SageMaker with tool calls
+ mock_response = {
+ "Body": [
+ {
+ "PayloadPart": {
+ "Bytes": json.dumps(
+ {
+ "choices": [
+ {
+ "delta": {
+ "content": None,
+ "tool_calls": [
+ {
+ "index": 0,
+ "id": "tool123",
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "arguments": '{"location": "Paris"}',
+ },
+ }
+ ],
+ },
+ "finish_reason": "tool_calls",
+ }
+ ]
+ }
+ ).encode("utf-8")
+ }
+ }
+ ]
+ }
+ sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response
+
+ response = [chunk async for chunk in model.stream(messages)]
+
+ # Verify the response contains tool call events
+ assert len(response) >= 4
+ assert response[0] == {"messageStart": {"role": "assistant"}}
+
+ message_stop = next((e for e in response if "messageStop" in e), None)
+ assert message_stop is not None
+ assert message_stop["messageStop"]["stopReason"] == "tool_use"
+
+ # Find tool call events
+ tool_start = next(
+ (
+ e
+ for e in response
+ if "contentBlockStart" in e and e.get("contentBlockStart", {}).get("start", {}).get("toolUse")
+ ),
+ None,
+ )
+ tool_delta = next(
+ (
+ e
+ for e in response
+ if "contentBlockDelta" in e and e.get("contentBlockDelta", {}).get("delta", {}).get("toolUse")
+ ),
+ None,
+ )
+ tool_stop = next((e for e in response if "contentBlockStop" in e), None)
+
+ assert tool_start is not None
+ assert tool_delta is not None
+ assert tool_stop is not None
+
+ # Verify tool call data
+ tool_use_data = tool_start["contentBlockStart"]["start"]["toolUse"]
+ assert tool_use_data["toolUseId"] == "tool123"
+ assert tool_use_data["name"] == "get_weather"
+
+ @pytest.mark.asyncio
+ async def test_stream_with_partial_json(self, sagemaker_client, model, messages):
+ """Test streaming response with partial JSON chunks."""
+ # Mock the response from SageMaker with split JSON
+ mock_response = {
+ "Body": [
+ {"PayloadPart": {"Bytes": '{"choices": [{"delta": {"content": "Paris is'.encode("utf-8")}},
+ {"PayloadPart": {"Bytes": ' the capital of France."}, "finish_reason": "stop"}]}'.encode("utf-8")}},
+ ]
+ }
+ sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response
+
+ response = [chunk async for chunk in model.stream(messages)]
+
+ assert len(response) == 5
+ assert response[0] == {"messageStart": {"role": "assistant"}}
+
+ # Find content events
+ content_start = next((e for e in response if "contentBlockStart" in e), None)
+ content_delta = next((e for e in response if "contentBlockDelta" in e), None)
+ content_stop = next((e for e in response if "contentBlockStop" in e), None)
+ message_stop = next((e for e in response if "messageStop" in e), None)
+
+ assert content_start is not None
+ assert content_delta is not None
+ assert content_stop is not None
+ assert message_stop is not None
+ assert message_stop["messageStop"]["stopReason"] == "end_turn"
+
+ # Verify content
+ text_delta = content_delta["contentBlockDelta"]["delta"]["text"]
+ assert text_delta == "Paris is the capital of France."
+
+ @pytest.mark.asyncio
+ async def test_stream_non_streaming(self, sagemaker_client, model, messages):
+ """Test non-streaming response."""
+ # Configure model for non-streaming
+ model.payload_config["stream"] = False
+
+ # Mock the response from SageMaker
+ mock_response = {"Body": unittest.mock.MagicMock()}
+ mock_response["Body"].read.return_value = json.dumps(
+ {
+ "choices": [
+ {
+ "message": {"content": "Paris is the capital of France.", "tool_calls": None},
+ "finish_reason": "stop",
+ }
+ ],
+ "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30, "prompt_tokens_details": 0},
+ }
+ ).encode("utf-8")
+
+ sagemaker_client.invoke_endpoint.return_value = mock_response
+
+ response = [chunk async for chunk in model.stream(messages)]
+
+ assert len(response) >= 6
+ assert response[0] == {"messageStart": {"role": "assistant"}}
+
+ # Find content events
+ content_start = next((e for e in response if "contentBlockStart" in e), None)
+ content_delta = next((e for e in response if "contentBlockDelta" in e), None)
+ content_stop = next((e for e in response if "contentBlockStop" in e), None)
+ message_stop = next((e for e in response if "messageStop" in e), None)
+
+ assert content_start is not None
+ assert content_delta is not None
+ assert content_stop is not None
+ assert message_stop is not None
+
+ # Verify content
+ text_delta = content_delta["contentBlockDelta"]["delta"]["text"]
+ assert text_delta == "Paris is the capital of France."
+
+ sagemaker_client.invoke_endpoint.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_stream_non_streaming_with_tool_calls(self, sagemaker_client, model, messages):
+ """Test non-streaming response with tool calls."""
+ # Configure model for non-streaming
+ model.payload_config["stream"] = False
+
+ # Mock the response from SageMaker with tool calls
+ mock_response = {"Body": unittest.mock.MagicMock()}
+ mock_response["Body"].read.return_value = json.dumps(
+ {
+ "choices": [
+ {
+ "message": {
+ "content": None,
+ "tool_calls": [
+ {
+ "id": "tool123",
+ "type": "function",
+ "function": {"name": "get_weather", "arguments": '{"location": "Paris"}'},
+ }
+ ],
+ },
+ "finish_reason": "tool_calls",
+ }
+ ],
+ "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30, "prompt_tokens_details": 0},
+ }
+ ).encode("utf-8")
+
+ sagemaker_client.invoke_endpoint.return_value = mock_response
+
+ response = [chunk async for chunk in model.stream(messages)]
+
+ # Verify basic structure
+ assert len(response) >= 6
+ assert response[0] == {"messageStart": {"role": "assistant"}}
+
+ # Find tool call events
+ tool_start = next(
+ (
+ e
+ for e in response
+ if "contentBlockStart" in e and e.get("contentBlockStart", {}).get("start", {}).get("toolUse")
+ ),
+ None,
+ )
+ tool_delta = next(
+ (
+ e
+ for e in response
+ if "contentBlockDelta" in e and e.get("contentBlockDelta", {}).get("delta", {}).get("toolUse")
+ ),
+ None,
+ )
+ tool_stop = next((e for e in response if "contentBlockStop" in e), None)
+ message_stop = next((e for e in response if "messageStop" in e), None)
+
+ assert tool_start is not None
+ assert tool_delta is not None
+ assert tool_stop is not None
+ assert message_stop is not None
+
+ # Verify tool call data
+ tool_use_data = tool_start["contentBlockStart"]["start"]["toolUse"]
+ assert tool_use_data["toolUseId"] == "tool123"
+ assert tool_use_data["name"] == "get_weather"
+
+ # Verify metadata
+ metadata = next((e for e in response if "metadata" in e), None)
+ assert metadata is not None
+ usage_data = metadata["metadata"]["usage"]
+ assert usage_data["totalTokens"] == 30
+
+
+class TestDataClasses:
+ """Test suite for data classes."""
+
+ def test_usage_metadata(self):
+ """Test UsageMetadata dataclass."""
+ usage = UsageMetadata(total_tokens=100, completion_tokens=30, prompt_tokens=70, prompt_tokens_details=5)
+
+ assert usage.total_tokens == 100
+ assert usage.completion_tokens == 30
+ assert usage.prompt_tokens == 70
+ assert usage.prompt_tokens_details == 5
+
+ def test_function_call(self):
+ """Test FunctionCall dataclass."""
+ func = FunctionCall(name="get_weather", arguments='{"location": "Paris"}')
+
+ assert func.name == "get_weather"
+ assert func.arguments == '{"location": "Paris"}'
+
+ # Test initialization with kwargs
+ func2 = FunctionCall(**{"name": "get_time", "arguments": '{"timezone": "UTC"}'})
+
+ assert func2.name == "get_time"
+ assert func2.arguments == '{"timezone": "UTC"}'
+
+ def test_tool_call(self):
+ """Test ToolCall dataclass."""
+ # Create a tool call using kwargs directly
+ tool = ToolCall(
+ id="tool123", type="function", function={"name": "get_weather", "arguments": '{"location": "Paris"}'}
+ )
+
+ assert tool.id == "tool123"
+ assert tool.type == "function"
+ assert tool.function.name == "get_weather"
+ assert tool.function.arguments == '{"location": "Paris"}'
+
+ # Test initialization with kwargs
+ tool2 = ToolCall(
+ **{
+ "id": "tool456",
+ "type": "function",
+ "function": {"name": "get_time", "arguments": '{"timezone": "UTC"}'},
+ }
+ )
+
+ assert tool2.id == "tool456"
+ assert tool2.type == "function"
+ assert tool2.function.name == "get_time"
+ assert tool2.function.arguments == '{"timezone": "UTC"}'
diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py
index a956cb769..77645fc73 100644
--- a/tests/strands/multiagent/a2a/test_executor.py
+++ b/tests/strands/multiagent/a2a/test_executor.py
@@ -36,7 +36,7 @@ async def mock_stream(user_input):
# Mock the task creation
mock_task = MagicMock()
mock_task.id = "test-task-id"
- mock_task.contextId = "test-context-id"
+ mock_task.context_id = "test-context-id"
mock_request_context.current_task = mock_task
await executor.execute(mock_request_context, mock_event_queue)
@@ -65,7 +65,7 @@ async def mock_stream(user_input):
# Mock the task creation
mock_task = MagicMock()
mock_task.id = "test-task-id"
- mock_task.contextId = "test-context-id"
+ mock_task.context_id = "test-context-id"
mock_request_context.current_task = mock_task
await executor.execute(mock_request_context, mock_event_queue)
@@ -95,7 +95,7 @@ async def mock_stream(user_input):
# Mock the task creation
mock_task = MagicMock()
mock_task.id = "test-task-id"
- mock_task.contextId = "test-context-id"
+ mock_task.context_id = "test-context-id"
mock_request_context.current_task = mock_task
await executor.execute(mock_request_context, mock_event_queue)
@@ -125,7 +125,7 @@ async def mock_stream(user_input):
# Mock the task creation
mock_task = MagicMock()
mock_task.id = "test-task-id"
- mock_task.contextId = "test-context-id"
+ mock_task.context_id = "test-context-id"
mock_request_context.current_task = mock_task
await executor.execute(mock_request_context, mock_event_queue)
@@ -156,7 +156,7 @@ async def mock_stream(user_input):
mock_request_context.current_task = None
with patch("strands.multiagent.a2a.executor.new_task") as mock_new_task:
- mock_new_task.return_value = MagicMock(id="new-task-id", contextId="new-context-id")
+ mock_new_task.return_value = MagicMock(id="new-task-id", context_id="new-context-id")
await executor.execute(mock_request_context, mock_event_queue)
@@ -180,7 +180,7 @@ async def test_execute_streaming_mode_handles_agent_exception(
# Mock the task creation
mock_task = MagicMock()
mock_task.id = "test-task-id"
- mock_task.contextId = "test-context-id"
+ mock_task.context_id = "test-context-id"
mock_request_context.current_task = mock_task
with pytest.raises(ServerError):
@@ -210,7 +210,7 @@ async def test_handle_agent_result_with_none_result(mock_strands_agent, mock_req
# Mock the task creation
mock_task = MagicMock()
mock_task.id = "test-task-id"
- mock_task.contextId = "test-context-id"
+ mock_task.context_id = "test-context-id"
mock_request_context.current_task = mock_task
# Mock TaskUpdater
@@ -235,7 +235,7 @@ async def test_handle_agent_result_with_result_but_no_message(
# Mock the task creation
mock_task = MagicMock()
mock_task.id = "test-task-id"
- mock_task.contextId = "test-context-id"
+ mock_task.context_id = "test-context-id"
mock_request_context.current_task = mock_task
# Mock TaskUpdater
diff --git a/tests/strands/multiagent/a2a/test_server.py b/tests/strands/multiagent/a2a/test_server.py
index 74f470741..a3b47581c 100644
--- a/tests/strands/multiagent/a2a/test_server.py
+++ b/tests/strands/multiagent/a2a/test_server.py
@@ -87,8 +87,8 @@ def test_public_agent_card(mock_strands_agent):
assert card.description == "A test agent for unit testing"
assert card.url == "http://0.0.0.0:9000/"
assert card.version == "0.0.1"
- assert card.defaultInputModes == ["text"]
- assert card.defaultOutputModes == ["text"]
+ assert card.default_input_modes == ["text"]
+ assert card.default_output_modes == ["text"]
assert card.skills == []
assert card.capabilities == a2a_agent.capabilities
@@ -509,3 +509,346 @@ def test_serve_handles_general_exception(mock_run, mock_strands_agent, caplog):
assert "Strands A2A server encountered exception" in caplog.text
assert "Strands A2A server has shutdown" in caplog.text
+
+
+def test_initialization_with_http_url_no_path(mock_strands_agent):
+ """Test initialization with http_url containing no path."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(
+ mock_strands_agent, host="0.0.0.0", port=8080, http_url="http://my-alb.amazonaws.com", skills=[]
+ )
+
+ assert a2a_agent.host == "0.0.0.0"
+ assert a2a_agent.port == 8080
+ assert a2a_agent.http_url == "http://my-alb.amazonaws.com/"
+ assert a2a_agent.public_base_url == "http://my-alb.amazonaws.com"
+ assert a2a_agent.mount_path == ""
+
+
+def test_initialization_with_http_url_with_path(mock_strands_agent):
+ """Test initialization with http_url containing a path for mounting."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(
+ mock_strands_agent, host="0.0.0.0", port=8080, http_url="http://my-alb.amazonaws.com/agent1", skills=[]
+ )
+
+ assert a2a_agent.host == "0.0.0.0"
+ assert a2a_agent.port == 8080
+ assert a2a_agent.http_url == "http://my-alb.amazonaws.com/agent1/"
+ assert a2a_agent.public_base_url == "http://my-alb.amazonaws.com"
+ assert a2a_agent.mount_path == "/agent1"
+
+
+def test_initialization_with_https_url(mock_strands_agent):
+ """Test initialization with HTTPS URL."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(mock_strands_agent, http_url="https://my-alb.amazonaws.com/secure-agent", skills=[])
+
+ assert a2a_agent.http_url == "https://my-alb.amazonaws.com/secure-agent/"
+ assert a2a_agent.public_base_url == "https://my-alb.amazonaws.com"
+ assert a2a_agent.mount_path == "/secure-agent"
+
+
+def test_initialization_with_http_url_with_port(mock_strands_agent):
+ """Test initialization with http_url containing explicit port."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(mock_strands_agent, http_url="http://my-server.com:8080/api/agent", skills=[])
+
+ assert a2a_agent.http_url == "http://my-server.com:8080/api/agent/"
+ assert a2a_agent.public_base_url == "http://my-server.com:8080"
+ assert a2a_agent.mount_path == "/api/agent"
+
+
+def test_parse_public_url_method(mock_strands_agent):
+ """Test the _parse_public_url method directly."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+ a2a_agent = A2AServer(mock_strands_agent, skills=[])
+
+ # Test various URL formats
+ base_url, mount_path = a2a_agent._parse_public_url("http://example.com/path")
+ assert base_url == "http://example.com"
+ assert mount_path == "/path"
+
+ base_url, mount_path = a2a_agent._parse_public_url("https://example.com:443/deep/path")
+ assert base_url == "https://example.com:443"
+ assert mount_path == "/deep/path"
+
+ base_url, mount_path = a2a_agent._parse_public_url("http://example.com/")
+ assert base_url == "http://example.com"
+ assert mount_path == ""
+
+ base_url, mount_path = a2a_agent._parse_public_url("http://example.com")
+ assert base_url == "http://example.com"
+ assert mount_path == ""
+
+
+def test_public_agent_card_with_http_url(mock_strands_agent):
+ """Test that public_agent_card uses the http_url when provided."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(mock_strands_agent, http_url="https://my-alb.amazonaws.com/agent1", skills=[])
+
+ card = a2a_agent.public_agent_card
+
+ assert isinstance(card, AgentCard)
+ assert card.url == "https://my-alb.amazonaws.com/agent1/"
+ assert card.name == "Test Agent"
+ assert card.description == "A test agent for unit testing"
+
+
+def test_to_starlette_app_with_mounting(mock_strands_agent):
+ """Test that to_starlette_app creates mounted app when mount_path exists."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com/agent1", skills=[])
+
+ app = a2a_agent.to_starlette_app()
+
+ assert isinstance(app, Starlette)
+
+
+def test_to_starlette_app_without_mounting(mock_strands_agent):
+ """Test that to_starlette_app creates regular app when no mount_path."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com", skills=[])
+
+ app = a2a_agent.to_starlette_app()
+
+ assert isinstance(app, Starlette)
+
+
+def test_to_fastapi_app_with_mounting(mock_strands_agent):
+ """Test that to_fastapi_app creates mounted app when mount_path exists."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com/agent1", skills=[])
+
+ app = a2a_agent.to_fastapi_app()
+
+ assert isinstance(app, FastAPI)
+
+
+def test_to_fastapi_app_without_mounting(mock_strands_agent):
+ """Test that to_fastapi_app creates regular app when no mount_path."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com", skills=[])
+
+ app = a2a_agent.to_fastapi_app()
+
+ assert isinstance(app, FastAPI)
+
+
+def test_backwards_compatibility_without_http_url(mock_strands_agent):
+ """Test that the old behavior is preserved when http_url is not provided."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(mock_strands_agent, host="localhost", port=9000, skills=[])
+
+ # Should behave exactly like before
+ assert a2a_agent.host == "localhost"
+ assert a2a_agent.port == 9000
+ assert a2a_agent.http_url == "http://localhost:9000/"
+ assert a2a_agent.public_base_url == "http://localhost:9000"
+ assert a2a_agent.mount_path == ""
+
+ # Agent card should use the traditional URL
+ card = a2a_agent.public_agent_card
+ assert card.url == "http://localhost:9000/"
+
+
+def test_mount_path_logging(mock_strands_agent, caplog):
+ """Test that mounting logs the correct message."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com/test-agent", skills=[])
+
+ # Test Starlette app mounting logs
+ caplog.clear()
+ a2a_agent.to_starlette_app()
+ assert "Mounting A2A server at path: /test-agent" in caplog.text
+
+ # Test FastAPI app mounting logs
+ caplog.clear()
+ a2a_agent.to_fastapi_app()
+ assert "Mounting A2A server at path: /test-agent" in caplog.text
+
+
+def test_http_url_trailing_slash_handling(mock_strands_agent):
+ """Test that trailing slashes in http_url are handled correctly."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ # Test with trailing slash
+ a2a_agent1 = A2AServer(mock_strands_agent, http_url="http://example.com/agent1/", skills=[])
+
+ # Test without trailing slash
+ a2a_agent2 = A2AServer(mock_strands_agent, http_url="http://example.com/agent1", skills=[])
+
+ # Both should result in the same normalized URL
+ assert a2a_agent1.http_url == "http://example.com/agent1/"
+ assert a2a_agent2.http_url == "http://example.com/agent1/"
+ assert a2a_agent1.mount_path == "/agent1"
+ assert a2a_agent2.mount_path == "/agent1"
+
+
+def test_serve_at_root_default_behavior(mock_strands_agent):
+ """Test default behavior extracts mount path from http_url."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ server = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", skills=[])
+
+ assert server.mount_path == "/agent1"
+ assert server.http_url == "http://my-alb.com/agent1/"
+
+
+def test_serve_at_root_overrides_mounting(mock_strands_agent):
+ """Test serve_at_root=True overrides automatic path mounting."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ server = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", serve_at_root=True, skills=[])
+
+ assert server.mount_path == "" # Should be empty despite path in URL
+ assert server.http_url == "http://my-alb.com/agent1/" # Public URL unchanged
+
+
+def test_serve_at_root_with_no_path(mock_strands_agent):
+ """Test serve_at_root=True when no path in URL (redundant but valid)."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ server = A2AServer(mock_strands_agent, host="localhost", port=8080, serve_at_root=True, skills=[])
+
+ assert server.mount_path == ""
+ assert server.http_url == "http://localhost:8080/"
+
+
+def test_serve_at_root_complex_path(mock_strands_agent):
+ """Test serve_at_root=True with complex nested paths."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ server = A2AServer(
+ mock_strands_agent, http_url="http://api.example.com/v1/agents/my-agent", serve_at_root=True, skills=[]
+ )
+
+ assert server.mount_path == ""
+ assert server.http_url == "http://api.example.com/v1/agents/my-agent/"
+
+
+def test_serve_at_root_fastapi_mounting_behavior(mock_strands_agent):
+ """Test FastAPI mounting behavior with serve_at_root."""
+ from fastapi.testclient import TestClient
+
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ # Normal mounting
+ server_mounted = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", skills=[])
+ app_mounted = server_mounted.to_fastapi_app()
+ client_mounted = TestClient(app_mounted)
+
+ # Should work at mounted path
+ response = client_mounted.get("/agent1/.well-known/agent.json")
+ assert response.status_code == 200
+
+ # Should not work at root
+ response = client_mounted.get("/.well-known/agent.json")
+ assert response.status_code == 404
+
+
+def test_serve_at_root_fastapi_root_behavior(mock_strands_agent):
+ """Test FastAPI serve_at_root behavior."""
+ from fastapi.testclient import TestClient
+
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ # Serve at root
+ server_root = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", serve_at_root=True, skills=[])
+ app_root = server_root.to_fastapi_app()
+ client_root = TestClient(app_root)
+
+ # Should work at root
+ response = client_root.get("/.well-known/agent.json")
+ assert response.status_code == 200
+
+ # Should not work at mounted path (since we're serving at root)
+ response = client_root.get("/agent1/.well-known/agent.json")
+ assert response.status_code == 404
+
+
+def test_serve_at_root_starlette_behavior(mock_strands_agent):
+ """Test Starlette serve_at_root behavior."""
+ from starlette.testclient import TestClient
+
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ # Normal mounting
+ server_mounted = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", skills=[])
+ app_mounted = server_mounted.to_starlette_app()
+ client_mounted = TestClient(app_mounted)
+
+ # Should work at mounted path
+ response = client_mounted.get("/agent1/.well-known/agent.json")
+ assert response.status_code == 200
+
+ # Serve at root
+ server_root = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", serve_at_root=True, skills=[])
+ app_root = server_root.to_starlette_app()
+ client_root = TestClient(app_root)
+
+ # Should work at root
+ response = client_root.get("/.well-known/agent.json")
+ assert response.status_code == 200
+
+
+def test_serve_at_root_alb_scenarios(mock_strands_agent):
+ """Test common ALB deployment scenarios."""
+ from fastapi.testclient import TestClient
+
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ # ALB with path preservation
+ server_preserved = A2AServer(mock_strands_agent, http_url="http://my-alb.amazonaws.com/agent1", skills=[])
+ app_preserved = server_preserved.to_fastapi_app()
+ client_preserved = TestClient(app_preserved)
+
+ # Container receives /agent1/.well-known/agent.json
+ response = client_preserved.get("/agent1/.well-known/agent.json")
+ assert response.status_code == 200
+ agent_data = response.json()
+ assert agent_data["url"] == "http://my-alb.amazonaws.com/agent1/"
+
+ # ALB with path stripping
+ server_stripped = A2AServer(
+ mock_strands_agent, http_url="http://my-alb.amazonaws.com/agent1", serve_at_root=True, skills=[]
+ )
+ app_stripped = server_stripped.to_fastapi_app()
+ client_stripped = TestClient(app_stripped)
+
+ # Container receives /.well-known/agent.json (path stripped by ALB)
+ response = client_stripped.get("/.well-known/agent.json")
+ assert response.status_code == 200
+ agent_data = response.json()
+ assert agent_data["url"] == "http://my-alb.amazonaws.com/agent1/"
+
+
+def test_serve_at_root_edge_cases(mock_strands_agent):
+ """Test edge cases for serve_at_root parameter."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ # Root path in URL
+ server1 = A2AServer(mock_strands_agent, http_url="http://example.com/", skills=[])
+ assert server1.mount_path == ""
+
+ # serve_at_root should be redundant but not cause issues
+ server2 = A2AServer(mock_strands_agent, http_url="http://example.com/", serve_at_root=True, skills=[])
+ assert server2.mount_path == ""
+
+ # Multiple nested paths
+ server3 = A2AServer(
+ mock_strands_agent, http_url="http://api.example.com/v1/agents/team1/agent1", serve_at_root=True, skills=[]
+ )
+ assert server3.mount_path == ""
+ assert server3.http_url == "http://api.example.com/v1/agents/team1/agent1/"
diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py
index 6a2fdd00c..bd88382cd 100644
--- a/tests/strands/tools/mcp/test_mcp_client.py
+++ b/tests/strands/tools/mcp/test_mcp_client.py
@@ -4,10 +4,12 @@
import pytest
from mcp import ListToolsResult
from mcp.types import CallToolResult as MCPCallToolResult
+from mcp.types import GetPromptResult, ListPromptsResult, Prompt, PromptMessage
from mcp.types import TextContent as MCPTextContent
from mcp.types import Tool as MCPTool
from strands.tools.mcp import MCPClient
+from strands.tools.mcp.mcp_types import MCPToolResult
from strands.types.exceptions import MCPClientInitializationError
@@ -129,6 +131,8 @@ def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_
assert result["toolUseId"] == "test-123"
assert len(result["content"]) == 1
assert result["content"][0]["text"] == "Test message"
+ # No structured content should be present when not provided by MCP
+ assert result.get("structuredContent") is None
def test_call_tool_sync_session_not_active():
@@ -139,6 +143,31 @@ def test_call_tool_sync_session_not_active():
client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})
+def test_call_tool_sync_with_structured_content(mock_transport, mock_session):
+ """Test that call_tool_sync correctly handles structured content."""
+ mock_content = MCPTextContent(type="text", text="Test message")
+ structured_content = {"result": 42, "status": "completed"}
+ mock_session.call_tool.return_value = MCPCallToolResult(
+ isError=False, content=[mock_content], structuredContent=structured_content
+ )
+
+ with MCPClient(mock_transport["transport_callable"]) as client:
+ result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})
+
+ mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None)
+
+ assert result["status"] == "success"
+ assert result["toolUseId"] == "test-123"
+ # Content should only contain the text content, not the structured content
+ assert len(result["content"]) == 1
+ assert result["content"][0]["text"] == "Test message"
+ # Structured content should be in its own field
+ assert "structuredContent" in result
+ assert result["structuredContent"] == structured_content
+ assert result["structuredContent"]["result"] == 42
+ assert result["structuredContent"]["status"] == "completed"
+
+
def test_call_tool_sync_exception(mock_transport, mock_session):
"""Test that call_tool_sync correctly handles exceptions."""
mock_session.call_tool.side_effect = Exception("Test exception")
@@ -312,6 +341,45 @@ def test_enter_with_initialization_exception(mock_transport):
client.start()
+def test_mcp_tool_result_type():
+ """Test that MCPToolResult extends ToolResult correctly."""
+ # Test basic ToolResult functionality
+ result = MCPToolResult(status="success", toolUseId="test-123", content=[{"text": "Test message"}])
+
+ assert result["status"] == "success"
+ assert result["toolUseId"] == "test-123"
+ assert result["content"][0]["text"] == "Test message"
+
+ # Test that structuredContent is optional
+ assert "structuredContent" not in result or result.get("structuredContent") is None
+
+ # Test with structuredContent
+ result_with_structured = MCPToolResult(
+ status="success", toolUseId="test-456", content=[{"text": "Test message"}], structuredContent={"key": "value"}
+ )
+
+ assert result_with_structured["structuredContent"] == {"key": "value"}
+
+
+def test_call_tool_sync_without_structured_content(mock_transport, mock_session):
+ """Test that call_tool_sync works correctly when no structured content is provided."""
+ mock_content = MCPTextContent(type="text", text="Test message")
+ mock_session.call_tool.return_value = MCPCallToolResult(
+ isError=False,
+ content=[mock_content], # No structuredContent
+ )
+
+ with MCPClient(mock_transport["transport_callable"]) as client:
+ result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})
+
+ assert result["status"] == "success"
+ assert result["toolUseId"] == "test-123"
+ assert len(result["content"]) == 1
+ assert result["content"][0]["text"] == "Test message"
+ # structuredContent should be None when not provided by MCP
+ assert result.get("structuredContent") is None
+
+
def test_exception_when_future_not_running():
"""Test exception handling when the future is not running."""
# Create a client.with a mock transport
@@ -337,3 +405,64 @@ def test_exception_when_future_not_running():
# Verify that set_exception was not called since the future was not running
mock_future.set_exception.assert_not_called()
+
+
+# Prompt Tests - Sync Methods
+
+
+def test_list_prompts_sync(mock_transport, mock_session):
+ """Test that list_prompts_sync correctly retrieves prompts."""
+ mock_prompt = Prompt(name="test_prompt", description="A test prompt", id="prompt_1")
+ mock_session.list_prompts.return_value = ListPromptsResult(prompts=[mock_prompt])
+
+ with MCPClient(mock_transport["transport_callable"]) as client:
+ result = client.list_prompts_sync()
+
+ mock_session.list_prompts.assert_called_once_with(cursor=None)
+ assert len(result.prompts) == 1
+ assert result.prompts[0].name == "test_prompt"
+ assert result.nextCursor is None
+
+
+def test_list_prompts_sync_with_pagination_token(mock_transport, mock_session):
+ """Test that list_prompts_sync correctly passes pagination token and returns next cursor."""
+ mock_prompt = Prompt(name="test_prompt", description="A test prompt", id="prompt_1")
+ mock_session.list_prompts.return_value = ListPromptsResult(prompts=[mock_prompt], nextCursor="next_page_token")
+
+ with MCPClient(mock_transport["transport_callable"]) as client:
+ result = client.list_prompts_sync(pagination_token="current_page_token")
+
+ mock_session.list_prompts.assert_called_once_with(cursor="current_page_token")
+ assert len(result.prompts) == 1
+ assert result.prompts[0].name == "test_prompt"
+ assert result.nextCursor == "next_page_token"
+
+
+def test_list_prompts_sync_session_not_active():
+ """Test that list_prompts_sync raises an error when session is not active."""
+ client = MCPClient(MagicMock())
+
+ with pytest.raises(MCPClientInitializationError, match="client session is not running"):
+ client.list_prompts_sync()
+
+
+def test_get_prompt_sync(mock_transport, mock_session):
+ """Test that get_prompt_sync correctly retrieves a prompt."""
+ mock_message = PromptMessage(role="user", content=MCPTextContent(type="text", text="This is a test prompt"))
+ mock_session.get_prompt.return_value = GetPromptResult(messages=[mock_message])
+
+ with MCPClient(mock_transport["transport_callable"]) as client:
+ result = client.get_prompt_sync("test_prompt_id", {"key": "value"})
+
+ mock_session.get_prompt.assert_called_once_with("test_prompt_id", arguments={"key": "value"})
+ assert len(result.messages) == 1
+ assert result.messages[0].role == "user"
+ assert result.messages[0].content.text == "This is a test prompt"
+
+
+def test_get_prompt_sync_session_not_active():
+ """Test that get_prompt_sync raises an error when session is not active."""
+ client = MCPClient(MagicMock())
+
+ with pytest.raises(MCPClientInitializationError, match="client session is not running"):
+ client.get_prompt_sync("test_prompt_id", {})
diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py
index ebcba3fb1..66494c987 100644
--- a/tests/strands/tools/test_registry.py
+++ b/tests/strands/tools/test_registry.py
@@ -93,3 +93,30 @@ def tool_function_4(d):
assert len(tools) == 2
assert all(isinstance(tool, DecoratedFunctionTool) for tool in tools)
+
+
+def test_process_tools_flattens_lists_and_tuples_and_sets():
+ def function() -> str:
+ return "done"
+
+ tool_a = tool(name="tool_a")(function)
+ tool_b = tool(name="tool_b")(function)
+ tool_c = tool(name="tool_c")(function)
+ tool_d = tool(name="tool_d")(function)
+ tool_e = tool(name="tool_e")(function)
+ tool_f = tool(name="tool_f")(function)
+
+ registry = ToolRegistry()
+
+ all_tools = [tool_a, (tool_b, tool_c), [{tool_d, tool_e}, [tool_f]]]
+
+ tru_tool_names = sorted(registry.process_tools(all_tools))
+ exp_tool_names = [
+ "tool_a",
+ "tool_b",
+ "tool_c",
+ "tool_d",
+ "tool_e",
+ "tool_f",
+ ]
+ assert tru_tool_names == exp_tool_names
diff --git a/tests_integ/conftest.py b/tests_integ/conftest.py
index f83f0e299..61c2bf9a1 100644
--- a/tests_integ/conftest.py
+++ b/tests_integ/conftest.py
@@ -1,5 +1,17 @@
+import json
+import logging
+import os
+
+import boto3
import pytest
+logger = logging.getLogger(__name__)
+
+
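+# Pytest hook: runs once at session start, before collection, so provider API keys are in the environment.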
+def pytest_sessionstart(session):
+ _load_api_keys_from_secrets_manager()
+
+
## Data
@@ -28,3 +40,43 @@ async def alist(items):
return [item async for item in items]
return alist
+
+
+## Models
+
+
+def _load_api_keys_from_secrets_manager():
+ """Load API keys as environment variables from AWS Secrets Manager."""
+    if "STRANDS_TEST_API_KEYS_SECRET_NAME" in os.environ:
+        try:
+            secret_name = os.environ["STRANDS_TEST_API_KEYS_SECRET_NAME"]
+            session = boto3.session.Session()
+            client = session.client(service_name="secretsmanager")
+            response = client.get_secret_value(SecretId=secret_name)
+
+ if "SecretString" in response:
+ secret = json.loads(response["SecretString"])
+ for key, value in secret.items():
+ os.environ[f"{key.upper()}_API_KEY"] = str(value)
+
+ except Exception as e:
+            logger.warning("Error retrieving secret: %s", e)
+
+ """
+ Validate that required environment variables are set when running in GitHub Actions.
+ This prevents tests from being unintentionally skipped due to missing credentials.
+ """
+ if os.environ.get("GITHUB_ACTIONS") != "true":
+ logger.warning("Tests running outside GitHub Actions, skipping required provider validation")
+ return
+
+ required_providers = {
+ "ANTHROPIC_API_KEY",
+ "COHERE_API_KEY",
+ "MISTRAL_API_KEY",
+ "OPENAI_API_KEY",
+ "WRITER_API_KEY",
+ }
+ for provider in required_providers:
+ if provider not in os.environ or not os.environ[provider]:
+            raise ValueError(f"Missing required environment variable: {provider}")
diff --git a/tests_integ/echo_server.py b/tests_integ/echo_server.py
index d309607a8..52223792c 100644
--- a/tests_integ/echo_server.py
+++ b/tests_integ/echo_server.py
@@ -2,7 +2,7 @@
Echo Server for MCP Integration Testing
This module implements a simple echo server using the Model Context Protocol (MCP).
-It provides a basic tool that echoes back any input string, which is useful for
+It provides basic tools that echo back input strings and structured content, which are useful for
testing the MCP communication flow and validating that messages are properly
transmitted between the client and server.
@@ -15,6 +15,8 @@
$ python echo_server.py
"""
+from typing import Any, Dict
+
from mcp.server import FastMCP
@@ -22,16 +24,22 @@ def start_echo_server():
"""
Initialize and start the MCP echo server.
- Creates a FastMCP server instance with a single 'echo' tool that returns
- any input string back to the caller. The server uses stdio transport
+ Creates a FastMCP server instance with tools that return
+ input strings and structured content back to the caller. The server uses stdio transport
for communication.
+
"""
mcp = FastMCP("Echo Server")
- @mcp.tool(description="Echos response back to the user")
+ @mcp.tool(description="Echos response back to the user", structured_output=False)
def echo(to_echo: str) -> str:
return to_echo
+    # FastMCP automatically constructs the structured output schema from the method signature
+    @mcp.tool(description="Echoes response back with structured content", structured_output=True)
+ def echo_with_structured_content(to_echo: str) -> Dict[str, Any]:
+ return {"echoed": to_echo}
+
mcp.run(transport="stdio")
diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py
index 543f58480..d2ac148d3 100644
--- a/tests_integ/models/providers.py
+++ b/tests_integ/models/providers.py
@@ -72,11 +72,11 @@ def __init__(self):
bedrock = ProviderInfo(id="bedrock", factory=lambda: BedrockModel())
cohere = ProviderInfo(
id="cohere",
- environment_variable="CO_API_KEY",
+ environment_variable="COHERE_API_KEY",
factory=lambda: OpenAIModel(
client_args={
"base_url": "https://api.cohere.com/compatibility/v1",
- "api_key": os.getenv("CO_API_KEY"),
+ "api_key": os.getenv("COHERE_API_KEY"),
},
model_id="command-a-03-2025",
params={"stream_options": None},
diff --git a/tests_integ/models/conformance.py b/tests_integ/models/test_conformance.py
similarity index 81%
rename from tests_integ/models/conformance.py
rename to tests_integ/models/test_conformance.py
index 262e41e42..d9875bc07 100644
--- a/tests_integ/models/conformance.py
+++ b/tests_integ/models/test_conformance.py
@@ -1,6 +1,6 @@
import pytest
-from strands.types.models import Model
+from strands.models import Model
from tests_integ.models.providers import ProviderInfo, all_providers
@@ -9,7 +9,7 @@ def get_models():
pytest.param(
provider_info,
id=provider_info.id, # Adds the provider name to the test name
- marks=[provider_info.mark], # ignores tests that don't have the requirements
+ marks=provider_info.mark, # ignores tests that don't have the requirements
)
for provider_info in all_providers
]
diff --git a/tests_integ/models/test_model_anthropic.py b/tests_integ/models/test_model_anthropic.py
index 2ee5e7f23..62a95d06d 100644
--- a/tests_integ/models/test_model_anthropic.py
+++ b/tests_integ/models/test_model_anthropic.py
@@ -6,10 +6,17 @@
import strands
from strands import Agent
from strands.models.anthropic import AnthropicModel
-from tests_integ.models import providers
-# these tests only run if we have the anthropic api key
-pytestmark = providers.anthropic.mark
+"""
+These tests only run if we have the anthropic api key
+
+Because of infrequent burst usage, Anthropic tests are unreliable, failing tests with 529s.
+{'type': 'error', 'error': {'details': None, 'type': 'overloaded_error', 'message': 'Overloaded'}}
+https://docs.anthropic.com/en/api/errors#http-errors
+"""
+pytestmark = pytest.skip(
+ "Because of infrequent burst usage, Anthropic tests are unreliable, failing with 529s", allow_module_level=True
+)
@pytest.fixture
diff --git a/tests_integ/models/test_model_cohere.py b/tests_integ/models/test_model_cohere.py
index 996b0f326..33fb1a8c6 100644
--- a/tests_integ/models/test_model_cohere.py
+++ b/tests_integ/models/test_model_cohere.py
@@ -16,7 +16,7 @@ def model():
return OpenAIModel(
client_args={
"base_url": "https://api.cohere.com/compatibility/v1",
- "api_key": os.getenv("CO_API_KEY"),
+ "api_key": os.getenv("COHERE_API_KEY"),
},
model_id="command-a-03-2025",
params={"stream_options": None},
diff --git a/tests_integ/models/test_model_sagemaker.py b/tests_integ/models/test_model_sagemaker.py
new file mode 100644
index 000000000..62362e299
--- /dev/null
+++ b/tests_integ/models/test_model_sagemaker.py
@@ -0,0 +1,76 @@
+import os
+
+import pytest
+
+import strands
+from strands import Agent
+from strands.models.sagemaker import SageMakerAIModel
+
+
+@pytest.fixture
+def model():
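+    # Endpoint name comes from SAGEMAKER_ENDPOINT_NAME; the tests below are skipped when it is unset.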
+ endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig(
+ endpoint_name=os.getenv("SAGEMAKER_ENDPOINT_NAME", ""), region_name="us-east-1"
+ )
+ payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024, temperature=0.7, stream=False)
+ return SageMakerAIModel(endpoint_config=endpoint_config, payload_config=payload_config)
+
+
+@pytest.fixture
+def tools():
+ @strands.tool
+ def tool_time(location: str) -> str:
+ """Get the current time for a location."""
+ return f"The time in {location} is 12:00 PM"
+
+ @strands.tool
+ def tool_weather(location: str) -> str:
+ """Get the current weather for a location."""
+ return f"The weather in {location} is sunny"
+
+ return [tool_time, tool_weather]
+
+
+@pytest.fixture
+def system_prompt():
+ return "You are a helpful assistant that provides concise answers."
+
+
+@pytest.fixture
+def agent(model, tools, system_prompt):
+ return Agent(model=model, tools=tools, system_prompt=system_prompt)
+
+
+@pytest.mark.skipif(
+ "SAGEMAKER_ENDPOINT_NAME" not in os.environ,
+ reason="SAGEMAKER_ENDPOINT_NAME environment variable missing",
+)
+def test_agent_with_tools(agent):
+ result = agent("What is the time and weather in New York?")
+ text = result.message["content"][0]["text"].lower()
+
+ assert "12:00" in text and "sunny" in text
+
+
+@pytest.mark.skipif(
+ "SAGEMAKER_ENDPOINT_NAME" not in os.environ,
+ reason="SAGEMAKER_ENDPOINT_NAME environment variable missing",
+)
+def test_agent_without_tools(model, system_prompt):
+ agent = Agent(model=model, system_prompt=system_prompt)
+ result = agent("Hello, how are you?")
+
+ assert result.message["content"][0]["text"]
+ assert len(result.message["content"][0]["text"]) > 0
+
+
+@pytest.mark.skipif(
+ "SAGEMAKER_ENDPOINT_NAME" not in os.environ,
+ reason="SAGEMAKER_ENDPOINT_NAME environment variable missing",
+)
+@pytest.mark.parametrize("location", ["Tokyo", "London", "Sydney"])
+def test_agent_different_locations(agent, location):
+ result = agent(f"What is the weather in {location}?")
+ text = result.message["content"][0]["text"].lower()
+
+ assert location.lower() in text and "sunny" in text
diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py
new file mode 100644
index 000000000..bf5668349
--- /dev/null
+++ b/tests_integ/test_max_tokens_reached.py
@@ -0,0 +1,48 @@
+import logging
+
+import pytest
+
+from strands import Agent, tool
+from strands.agent import AgentResult
+from strands.models.bedrock import BedrockModel
+from strands.types.exceptions import MaxTokensReachedException
+
+logger = logging.getLogger(__name__)
+
+
+@tool
+def story_tool(story: str) -> str:
+ """
+    Tool that writes a story that is at least 50,000 lines long.
+ """
+ return story
+
+
+def test_max_tokens_reached():
+ """Test that MaxTokensReachedException is raised but the agent can still rerun on the second pass"""
+ model = BedrockModel(max_tokens=100)
+ agent = Agent(model=model, tools=[story_tool])
+
+    # The 100-token cap truncates the tool use, so this call should raise MaxTokensReachedException
+ with pytest.raises(MaxTokensReachedException):
+ agent("Tell me a story!")
+
+ # Validate that at least one message contains the incomplete tool use error message
+ expected_text = "tool use was incomplete due to maximum token limits being reached"
+ all_text_content = [
+ content_block["text"]
+ for message in agent.messages
+ for content_block in message.get("content", [])
+ if "text" in content_block
+ ]
+
+ assert any(expected_text in text for text in all_text_content), (
+ f"Expected to find message containing '{expected_text}' in agent messages"
+ )
+
+ # Remove tools from agent and re-run with a generic question
+ agent.tool_registry.registry = {}
+ agent.tool_registry.tool_config = {}
+
+ result: AgentResult = agent("What is 3+3")
+ assert result.stop_reason == "end_turn"
diff --git a/tests_integ/test_mcp_client.py b/tests_integ/test_mcp_client.py
index 9163f625d..3de249435 100644
--- a/tests_integ/test_mcp_client.py
+++ b/tests_integ/test_mcp_client.py
@@ -1,4 +1,5 @@
import base64
+import json
import os
import threading
import time
@@ -17,18 +18,17 @@
from strands.types.tools import ToolUse
-def start_calculator_server(transport: Literal["sse", "streamable-http"], port=int):
+def start_comprehensive_mcp_server(transport: Literal["sse", "streamable-http"], port: int):
"""
- Initialize and start an MCP calculator server for integration testing.
+ Initialize and start a comprehensive MCP server for integration testing.
- This function creates a FastMCP server instance that provides a simple
- calculator tool for performing addition operations. The server uses
- Server-Sent Events (SSE) transport for communication, making it accessible
- over HTTP.
+ This function creates a FastMCP server instance that provides tools, prompts,
+ and resources all in one server for comprehensive testing. The server uses
+ Server-Sent Events (SSE) or streamable HTTP transport for communication.
"""
from mcp.server import FastMCP
- mcp = FastMCP("Calculator Server", port=port)
+ mcp = FastMCP("Comprehensive MCP Server", port=port)
@mcp.tool(description="Calculator tool which performs calculations")
def calculator(x: int, y: int) -> int:
@@ -43,6 +43,15 @@ def generate_custom_image() -> MCPImageContent:
except Exception as e:
print("Error while generating custom image: {}".format(e))
+ # Prompts
+ @mcp.prompt(description="A greeting prompt template")
+ def greeting_prompt(name: str = "World") -> str:
+ return f"Hello, {name}! How are you today?"
+
+ @mcp.prompt(description="A math problem prompt template")
+ def math_prompt(operation: str = "addition", difficulty: str = "easy") -> str:
+ return f"Create a {difficulty} {operation} math problem and solve it step by step."
+
mcp.run(transport=transport)
@@ -57,8 +66,9 @@ def test_mcp_client():
{'role': 'assistant', 'content': [{'text': '\n\nThe result of adding 1 and 2 is 3.'}]}
""" # noqa: E501
+ # Start comprehensive server with tools, prompts, and resources
server_thread = threading.Thread(
- target=start_calculator_server, kwargs={"transport": "sse", "port": 8000}, daemon=True
+ target=start_comprehensive_mcp_server, kwargs={"transport": "sse", "port": 8000}, daemon=True
)
server_thread.start()
time.sleep(2) # wait for server to startup completely
@@ -67,8 +77,14 @@ def test_mcp_client():
stdio_mcp_client = MCPClient(
lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"]))
)
+
with sse_mcp_client, stdio_mcp_client:
- agent = Agent(tools=sse_mcp_client.list_tools_sync() + stdio_mcp_client.list_tools_sync())
+ # Test Tools functionality
+ sse_tools = sse_mcp_client.list_tools_sync()
+ stdio_tools = stdio_mcp_client.list_tools_sync()
+ all_tools = sse_tools + stdio_tools
+
+ agent = Agent(tools=all_tools)
agent("add 1 and 2, then echo the result back to me")
tool_use_content_blocks = _messages_to_content_blocks(agent.messages)
@@ -87,6 +103,61 @@ def test_mcp_client():
]
)
+ # Test Prompts functionality
+ prompts_result = sse_mcp_client.list_prompts_sync()
+ assert len(prompts_result.prompts) >= 2 # We expect at least greeting and math prompts
+
+ prompt_names = [prompt.name for prompt in prompts_result.prompts]
+ assert "greeting_prompt" in prompt_names
+ assert "math_prompt" in prompt_names
+
+ # Test get_prompt_sync with greeting prompt
+ greeting_result = sse_mcp_client.get_prompt_sync("greeting_prompt", {"name": "Alice"})
+ assert len(greeting_result.messages) > 0
+ prompt_text = greeting_result.messages[0].content.text
+ assert "Hello, Alice!" in prompt_text
+ assert "How are you today?" in prompt_text
+
+ # Test get_prompt_sync with math prompt
+ math_result = sse_mcp_client.get_prompt_sync(
+ "math_prompt", {"operation": "multiplication", "difficulty": "medium"}
+ )
+ assert len(math_result.messages) > 0
+ math_text = math_result.messages[0].content.text
+ assert "multiplication" in math_text
+ assert "medium" in math_text
+ assert "step by step" in math_text
+
+ # Test pagination support for prompts
+ prompts_with_token = sse_mcp_client.list_prompts_sync(pagination_token=None)
+ assert len(prompts_with_token.prompts) >= 0
+
+ # Test pagination support for tools (existing functionality)
+ tools_with_token = sse_mcp_client.list_tools_sync(pagination_token=None)
+ assert len(tools_with_token) >= 0
+
+ # TODO: Add resources testing when resources are implemented
+ # resources_result = sse_mcp_client.list_resources_sync()
+ # assert len(resources_result.resources) >= 0
+
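+    # Test structured content via the stdio echo server's echo_with_structured_content tool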
+ tool_use_id = "test-structured-content-123"
+ result = stdio_mcp_client.call_tool_sync(
+ tool_use_id=tool_use_id,
+ name="echo_with_structured_content",
+ arguments={"to_echo": "STRUCTURED_DATA_TEST"},
+ )
+
+ # With the new MCPToolResult, structured content is in its own field
+ assert "structuredContent" in result
+ assert result["structuredContent"]["result"] == {"echoed": "STRUCTURED_DATA_TEST"}
+
+ # Verify the result is an MCPToolResult (at runtime it's just a dict, but type-wise it should be MCPToolResult)
+ assert result["status"] == "success"
+ assert result["toolUseId"] == tool_use_id
+
+ assert len(result["content"]) == 1
+ assert json.loads(result["content"][0]["text"]) == {"echoed": "STRUCTURED_DATA_TEST"}
+
def test_can_reuse_mcp_client():
stdio_mcp_client = MCPClient(
@@ -103,13 +174,72 @@ def test_can_reuse_mcp_client():
assert any([block["name"] == "echo" for block in tool_use_content_blocks])
+@pytest.mark.asyncio
+async def test_mcp_client_async_structured_content():
+ """Test that async MCP client calls properly handle structured content.
+
+    This test demonstrates how tools configure structured output: FastMCP automatically
+    constructs the structured output schema from the method signature when structured_output=True
+    is set in the @mcp.tool decorator. The return type annotation defines the structure
+    that appears in the structuredContent field.
+ """
+ stdio_mcp_client = MCPClient(
+ lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"]))
+ )
+
+ with stdio_mcp_client:
+ tool_use_id = "test-async-structured-content-456"
+ result = await stdio_mcp_client.call_tool_async(
+ tool_use_id=tool_use_id,
+ name="echo_with_structured_content",
+ arguments={"to_echo": "ASYNC_STRUCTURED_TEST"},
+ )
+
+ # Verify structured content is in its own field
+ assert "structuredContent" in result
+ # "result" nesting is not part of the MCP Structured Content specification,
+ # but rather a FastMCP implementation detail
+ assert result["structuredContent"]["result"] == {"echoed": "ASYNC_STRUCTURED_TEST"}
+
+ # Verify basic MCPToolResult structure
+ assert result["status"] in ["success", "error"]
+ assert result["toolUseId"] == tool_use_id
+
+ assert len(result["content"]) == 1
+ assert json.loads(result["content"][0]["text"]) == {"echoed": "ASYNC_STRUCTURED_TEST"}
+
+
+def test_mcp_client_without_structured_content():
+ """Test that MCP client works correctly when tools don't return structured content."""
+ stdio_mcp_client = MCPClient(
+ lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"]))
+ )
+
+ with stdio_mcp_client:
+ tool_use_id = "test-no-structured-content-789"
+ result = stdio_mcp_client.call_tool_sync(
+ tool_use_id=tool_use_id,
+ name="echo", # This tool doesn't return structured content
+ arguments={"to_echo": "SIMPLE_ECHO_TEST"},
+ )
+
+ # Verify no structured content when tool doesn't provide it
+ assert result.get("structuredContent") is None
+
+ # Verify basic result structure
+ assert result["status"] == "success"
+ assert result["toolUseId"] == tool_use_id
+ assert result["content"] == [{"text": "SIMPLE_ECHO_TEST"}]
+
+
@pytest.mark.skipif(
condition=os.environ.get("GITHUB_ACTIONS") == "true",
reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue",
)
def test_streamable_http_mcp_client():
+ """Test comprehensive MCP client with streamable HTTP transport."""
server_thread = threading.Thread(
- target=start_calculator_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True
+ target=start_comprehensive_mcp_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True
)
server_thread.start()
time.sleep(2) # wait for server to startup completely
@@ -119,12 +249,22 @@ def transport_callback() -> MCPTransport:
streamable_http_client = MCPClient(transport_callback)
with streamable_http_client:
+ # Test tools
agent = Agent(tools=streamable_http_client.list_tools_sync())
agent("add 1 and 2 using a calculator")
tool_use_content_blocks = _messages_to_content_blocks(agent.messages)
assert any([block["name"] == "calculator" for block in tool_use_content_blocks])
+ # Test prompts
+ prompts_result = streamable_http_client.list_prompts_sync()
+ assert len(prompts_result.prompts) >= 2
+
+ greeting_result = streamable_http_client.get_prompt_sync("greeting_prompt", {"name": "Charlie"})
+ assert len(greeting_result.messages) > 0
+ prompt_text = greeting_result.messages[0].content.text
+ assert "Hello, Charlie!" in prompt_text
+
def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]:
return [block["toolUse"] for message in messages for block in message["content"] if "toolUse" in block]
diff --git a/tests_integ/test_mcp_client_structured_content_with_hooks.py b/tests_integ/test_mcp_client_structured_content_with_hooks.py
new file mode 100644
index 000000000..ca2468c48
--- /dev/null
+++ b/tests_integ/test_mcp_client_structured_content_with_hooks.py
@@ -0,0 +1,65 @@
+"""Integration test demonstrating hooks system with MCP client structured content tool.
+
+This test shows how to use the hooks system to capture and inspect tool invocation
+results, specifically testing the echo_with_structured_content tool from echo_server.
+"""
+
+import json
+
+from mcp import StdioServerParameters, stdio_client
+
+from strands import Agent
+from strands.experimental.hooks import AfterToolInvocationEvent
+from strands.hooks import HookProvider, HookRegistry
+from strands.tools.mcp.mcp_client import MCPClient
+
+
+class StructuredContentHookProvider(HookProvider):
+ """Hook provider that captures structured content tool results."""
+
+ def __init__(self):
+ self.captured_result = None
+
+ def register_hooks(self, registry: HookRegistry) -> None:
+ """Register callback for after tool invocation events."""
+ registry.add_callback(AfterToolInvocationEvent, self.on_after_tool_invocation)
+
+ def on_after_tool_invocation(self, event: AfterToolInvocationEvent) -> None:
+ """Capture structured content tool results."""
+ if event.tool_use["name"] == "echo_with_structured_content":
+ self.captured_result = event.result
+
+
+def test_mcp_client_hooks_structured_content():
+ """Test using hooks to inspect echo_with_structured_content tool result."""
+ # Create hook provider to capture tool result
+ hook_provider = StructuredContentHookProvider()
+
+ # Set up MCP client for echo server
+ stdio_mcp_client = MCPClient(
+ lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"]))
+ )
+
+ with stdio_mcp_client:
+ # Create agent with MCP tools and hook provider
+ agent = Agent(tools=stdio_mcp_client.list_tools_sync(), hooks=[hook_provider])
+
+ # Test structured content functionality
+ test_data = "HOOKS_TEST_DATA"
+ agent(f"Use the echo_with_structured_content tool to echo: {test_data}")
+
+ # Verify hook captured the tool result
+ assert hook_provider.captured_result is not None
+ result = hook_provider.captured_result
+
+ # Verify basic result structure
+ assert result["status"] == "success"
+ assert len(result["content"]) == 1
+
+ # Verify structured content is present and correct
+ assert "structuredContent" in result
+ assert result["structuredContent"]["result"] == {"echoed": test_data}
+
+ # Verify text content matches structured content
+ text_content = json.loads(result["content"][0]["text"])
+ assert text_content == {"echoed": test_data}