Skip to content

Commit 64d61e0

Browse files
authored
feat: add region-aware default model ID for Bedrock (#835)
These changes introduce region-aware default model ID functionality for Bedrock, formatting based on region prefixes, warnings for unsupported regions, and preservation of custom model IDs. Comprehensive test coverage was added, and existing tests were updated. We also maintain compatibility for two key use cases: preserving customer-overridden model IDs and maintaining compatibility with existing DEFAULT_BEDROCK_MODEL_ID usage patterns.
1 parent 7f58ce9 commit 64d61e0

File tree

3 files changed

+152
-9
lines changed

3 files changed

+152
-9
lines changed

src/strands/models/bedrock.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import json
88
import logging
99
import os
10+
import warnings
1011
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, cast
1112

1213
import boto3
@@ -29,7 +30,9 @@
2930

3031
logger = logging.getLogger(__name__)
3132

33+
# See: `BedrockModel._get_default_model_with_warning` for why we need both
3234
DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0"
35+
_DEFAULT_BEDROCK_MODEL_ID = "{}.anthropic.claude-sonnet-4-20250514-v1:0"
3336
DEFAULT_BEDROCK_REGION = "us-west-2"
3437

3538
BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [
@@ -47,6 +50,7 @@
4750

4851
DEFAULT_READ_TIMEOUT = 120
4952

53+
5054
class BedrockModel(Model):
5155
"""AWS Bedrock model provider implementation.
5256
@@ -129,13 +133,16 @@ def __init__(
129133
if region_name and boto_session:
130134
raise ValueError("Cannot specify both `region_name` and `boto_session`.")
131135

132-
self.config = BedrockModel.BedrockConfig(model_id=DEFAULT_BEDROCK_MODEL_ID, include_tool_result_status="auto")
136+
session = boto_session or boto3.Session()
137+
resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION
138+
self.config = BedrockModel.BedrockConfig(
139+
model_id=BedrockModel._get_default_model_with_warning(resolved_region, model_config),
140+
include_tool_result_status="auto",
141+
)
133142
self.update_config(**model_config)
134143

135144
logger.debug("config=<%s> | initializing", self.config)
136145

137-
session = boto_session or boto3.Session()
138-
139146
# Add strands-agents to the request user agent
140147
if boto_client_config:
141148
existing_user_agent = getattr(boto_client_config, "user_agent_extra", None)
@@ -150,8 +157,6 @@ def __init__(
150157
else:
151158
client_config = BotocoreConfig(user_agent_extra="strands-agents", read_timeout=DEFAULT_READ_TIMEOUT)
152159

153-
resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION
154-
155160
self.client = session.client(
156161
service_name="bedrock-runtime",
157162
config=client_config,
@@ -770,3 +775,46 @@ async def structured_output(
770775
raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")
771776

772777
yield {"output": output_model(**output_response)}
778+
779+
@staticmethod
780+
def _get_default_model_with_warning(region_name: str, model_config: Optional[BedrockConfig] = None) -> str:
781+
"""Get the default Bedrock modelId based on region.
782+
783+
If the region is not **known** to support inference then we show a helpful warning
784+
that compliments the exception that Bedrock will throw.
785+
If the customer provided a model_id in their config or they overrode the `DEFAULT_BEDROCK_MODEL_ID`
786+
then we should not process further.
787+
788+
Args:
789+
region_name (str): region for bedrock model
790+
model_config (Optional[dict[str, Any]]): Model Config that caller passes in on init
791+
"""
792+
if DEFAULT_BEDROCK_MODEL_ID != _DEFAULT_BEDROCK_MODEL_ID.format("us"):
793+
return DEFAULT_BEDROCK_MODEL_ID
794+
795+
model_config = model_config or {}
796+
if model_config.get("model_id"):
797+
return model_config["model_id"]
798+
799+
prefix_inference_map = {"ap": "apac"} # some inference endpoints can be a bit different than the region prefix
800+
801+
prefix = "-".join(region_name.split("-")[:-2]).lower() # handles `us-east-1` or `us-gov-east-1`
802+
if prefix not in {"us", "eu", "ap", "us-gov"}:
803+
warnings.warn(
804+
f"""
805+
================== WARNING ==================
806+
807+
This region {region_name} does not support
808+
our default inference endpoint: {_DEFAULT_BEDROCK_MODEL_ID.format(prefix)}.
809+
Update the agent to pass in a 'model_id' like so:
810+
```
811+
Agent(..., model='valid_model_id', ...)
812+
````
813+
Documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
814+
815+
==================================================
816+
""",
817+
stacklevel=2,
818+
)
819+
820+
return _DEFAULT_BEDROCK_MODEL_ID.format(prefix_inference_map.get(prefix, prefix))

tests/strands/agent/test_agent.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
from tests.fixtures.mock_session_repository import MockedSessionRepository
2727
from tests.fixtures.mocked_model_provider import MockedModelProvider
2828

29+
# For unit testing we will use the the us inference
30+
FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID.format("us")
31+
2932

3033
@pytest.fixture
3134
def mock_randint():
@@ -211,7 +214,7 @@ def test_agent__init__with_default_model():
211214
agent = Agent()
212215

213216
assert isinstance(agent.model, BedrockModel)
214-
assert agent.model.config["model_id"] == DEFAULT_BEDROCK_MODEL_ID
217+
assert agent.model.config["model_id"] == FORMATTED_DEFAULT_MODEL_ID
215218

216219

217220
def test_agent__init__with_explicit_model(mock_model):
@@ -891,7 +894,7 @@ def test_agent__del__(agent):
891894
def test_agent_init_with_no_model_or_model_id():
892895
agent = Agent()
893896
assert agent.model is not None
894-
assert agent.model.get_config().get("model_id") == DEFAULT_BEDROCK_MODEL_ID
897+
assert agent.model.get_config().get("model_id") == FORMATTED_DEFAULT_MODEL_ID
895898

896899

897900
def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint, agenerator):

tests/strands/models/test_bedrock.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,17 @@
1111

1212
import strands
1313
from strands.models import BedrockModel
14-
from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION, DEFAULT_READ_TIMEOUT
14+
from strands.models.bedrock import (
15+
_DEFAULT_BEDROCK_MODEL_ID,
16+
DEFAULT_BEDROCK_MODEL_ID,
17+
DEFAULT_BEDROCK_REGION,
18+
DEFAULT_READ_TIMEOUT,
19+
)
1520
from strands.types.exceptions import ModelThrottledException
1621
from strands.types.tools import ToolSpec
1722

23+
FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID.format("us")
24+
1825

1926
@pytest.fixture
2027
def session_cls():
@@ -119,7 +126,7 @@ def test__init__default_model_id(bedrock_client):
119126
model = BedrockModel()
120127

121128
tru_model_id = model.get_config().get("model_id")
122-
exp_model_id = DEFAULT_BEDROCK_MODEL_ID
129+
exp_model_id = FORMATTED_DEFAULT_MODEL_ID
123130

124131
assert tru_model_id == exp_model_id
125132

@@ -1543,3 +1550,88 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings):
15431550
model.format_request(messages, tool_choice=None)
15441551

15451552
assert len(captured_warnings) == 0
1553+
1554+
1555+
def test_get_default_model_with_warning_supported_regions_shows_no_warning(captured_warnings):
1556+
"""Test get_model_prefix_with_warning doesn't warn for supported region prefixes."""
1557+
BedrockModel._get_default_model_with_warning("us-west-2")
1558+
BedrockModel._get_default_model_with_warning("eu-west-2")
1559+
assert len(captured_warnings) == 0
1560+
1561+
1562+
def test_get_default_model_for_supported_eu_region_returns_correct_model_id(captured_warnings):
1563+
model_id = BedrockModel._get_default_model_with_warning("eu-west-1")
1564+
assert model_id == "eu.anthropic.claude-sonnet-4-20250514-v1:0"
1565+
assert len(captured_warnings) == 0
1566+
1567+
1568+
def test_get_default_model_for_supported_us_region_returns_correct_model_id(captured_warnings):
1569+
model_id = BedrockModel._get_default_model_with_warning("us-east-1")
1570+
assert model_id == "us.anthropic.claude-sonnet-4-20250514-v1:0"
1571+
assert len(captured_warnings) == 0
1572+
1573+
1574+
def test_get_default_model_for_supported_gov_region_returns_correct_model_id(captured_warnings):
1575+
model_id = BedrockModel._get_default_model_with_warning("us-gov-west-1")
1576+
assert model_id == "us-gov.anthropic.claude-sonnet-4-20250514-v1:0"
1577+
assert len(captured_warnings) == 0
1578+
1579+
1580+
def test_get_model_prefix_for_ap_region_converts_to_apac_endpoint(captured_warnings):
1581+
"""Test _get_default_model_with_warning warns for APAC regions since 'ap' is not in supported prefixes."""
1582+
model_id = BedrockModel._get_default_model_with_warning("ap-southeast-1")
1583+
assert model_id == "apac.anthropic.claude-sonnet-4-20250514-v1:0"
1584+
1585+
1586+
def test_get_default_model_with_warning_unsupported_region_warns(captured_warnings):
1587+
"""Test _get_default_model_with_warning warns for unsupported regions."""
1588+
BedrockModel._get_default_model_with_warning("ca-central-1")
1589+
assert len(captured_warnings) == 1
1590+
assert "This region ca-central-1 does not support" in str(captured_warnings[0].message)
1591+
assert "our default inference endpoint" in str(captured_warnings[0].message)
1592+
1593+
1594+
def test_get_default_model_with_warning_no_warning_with_custom_model_id(captured_warnings):
1595+
"""Test _get_default_model_with_warning doesn't warn when custom model_id provided."""
1596+
model_config = {"model_id": "custom-model"}
1597+
model_id = BedrockModel._get_default_model_with_warning("ca-central-1", model_config)
1598+
1599+
assert model_id == "custom-model"
1600+
assert len(captured_warnings) == 0
1601+
1602+
1603+
def test_init_with_unsupported_region_warns(session_cls, captured_warnings):
1604+
"""Test BedrockModel initialization warns for unsupported regions."""
1605+
BedrockModel(region_name="ca-central-1")
1606+
1607+
assert len(captured_warnings) == 1
1608+
assert "This region ca-central-1 does not support" in str(captured_warnings[0].message)
1609+
1610+
1611+
def test_init_with_unsupported_region_custom_model_no_warning(session_cls, captured_warnings):
1612+
"""Test BedrockModel initialization doesn't warn when custom model_id provided."""
1613+
BedrockModel(region_name="ca-central-1", model_id="custom-model")
1614+
assert len(captured_warnings) == 0
1615+
1616+
1617+
def test_override_default_model_id_uses_the_overriden_value(captured_warnings):
1618+
with unittest.mock.patch("strands.models.bedrock.DEFAULT_BEDROCK_MODEL_ID", "custom-overridden-model"):
1619+
model_id = BedrockModel._get_default_model_with_warning("us-east-1")
1620+
assert model_id == "custom-overridden-model"
1621+
1622+
1623+
def test_no_override_uses_formatted_default_model_id(captured_warnings):
1624+
model_id = BedrockModel._get_default_model_with_warning("us-east-1")
1625+
assert model_id == "us.anthropic.claude-sonnet-4-20250514-v1:0"
1626+
assert model_id != _DEFAULT_BEDROCK_MODEL_ID
1627+
assert len(captured_warnings) == 0
1628+
1629+
1630+
def test_custom_model_id_not_overridden_by_region_formatting(session_cls):
1631+
"""Test that custom model_id is not overridden by region formatting."""
1632+
custom_model_id = "custom.model.id"
1633+
1634+
model = BedrockModel(model_id=custom_model_id)
1635+
model_id = model.get_config().get("model_id")
1636+
1637+
assert model_id == custom_model_id

0 commit comments

Comments
 (0)