Skip to content

Commit 89d261e

Browse files
authored
models - move abstract class (#409)
1 parent ca4a567 commit 89d261e

File tree

21 files changed

+587
-676
lines changed

21 files changed

+587
-676
lines changed

src/strands/agent/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@
3030
MessageAddedEvent,
3131
)
3232
from ..models.bedrock import BedrockModel
33+
from ..models.model import Model
3334
from ..telemetry.metrics import EventLoopMetrics
3435
from ..telemetry.tracer import get_tracer
3536
from ..tools.registry import ToolRegistry
3637
from ..tools.watcher import ToolWatcher
3738
from ..types.content import ContentBlock, Message, Messages
3839
from ..types.exceptions import ContextWindowOverflowException
39-
from ..types.models import Model
4040
from ..types.tools import ToolResult, ToolUse
4141
from ..types.traces import AttributeValue
4242
from .agent_result import AgentResult

src/strands/event_loop/streaming.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import logging
55
from typing import Any, AsyncGenerator, AsyncIterable, Optional
66

7+
from ..models.model import Model
78
from ..types.content import ContentBlock, Message, Messages
8-
from ..types.models import Model
99
from ..types.streaming import (
1010
ContentBlockDeltaEvent,
1111
ContentBlockStart,

src/strands/models/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
This package includes an abstract base Model class along with concrete implementations for specific providers.
44
"""
55

6-
from . import bedrock
6+
from . import bedrock, model
77
from .bedrock import BedrockModel
8+
from .model import Model
89

9-
__all__ = ["bedrock", "BedrockModel"]
10+
__all__ = ["bedrock", "model", "BedrockModel", "Model"]

src/strands/models/anthropic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
from ..tools import convert_pydantic_to_tool_spec
1818
from ..types.content import ContentBlock, Messages
1919
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
20-
from ..types.models import Model
2120
from ..types.streaming import StreamEvent
2221
from ..types.tools import ToolSpec
22+
from .model import Model
2323

2424
logger = logging.getLogger(__name__)
2525

@@ -361,7 +361,7 @@ async def stream(
361361
"""
362362
logger.debug("formatting request")
363363
request = self.format_request(messages, tool_specs, system_prompt)
364-
logger.debug("formatted request=<%s>", request)
364+
logger.debug("request=<%s>", request)
365365

366366
logger.debug("invoking model")
367367
try:

src/strands/models/bedrock.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
from pydantic import BaseModel
1717
from typing_extensions import TypedDict, Unpack, override
1818

19-
from ..event_loop.streaming import process_stream
19+
from ..event_loop import streaming
2020
from ..tools import convert_pydantic_to_tool_spec
2121
from ..types.content import Messages
2222
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
23-
from ..types.models import Model
2423
from ..types.streaming import StreamEvent
2524
from ..types.tools import ToolSpec
25+
from .model import Model
2626

2727
logger = logging.getLogger(__name__)
2828

@@ -374,7 +374,7 @@ def _stream(
374374
"""
375375
logger.debug("formatting request")
376376
request = self.format_request(messages, tool_specs, system_prompt)
377-
logger.debug("formatted request=<%s>", request)
377+
logger.debug("request=<%s>", request)
378378

379379
logger.debug("invoking model")
380380
streaming = self.config.get("streaming", True)
@@ -577,7 +577,7 @@ async def structured_output(
577577
tool_spec = convert_pydantic_to_tool_spec(output_model)
578578

579579
response = self.stream(messages=prompt, tool_specs=[tool_spec])
580-
async for event in process_stream(response, prompt):
580+
async for event in streaming.process_stream(response, prompt):
581581
yield event
582582

583583
stop_reason, messages, _, _ = event["stop"]

src/strands/models/litellm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
from typing_extensions import Unpack, override
1414

1515
from ..types.content import ContentBlock, Messages
16-
from ..types.models.openai import OpenAIModel
1716
from ..types.streaming import StreamEvent
1817
from ..types.tools import ToolSpec
18+
from .openai import OpenAIModel
1919

2020
logger = logging.getLogger(__name__)
2121

@@ -119,7 +119,7 @@ async def stream(
119119
"""
120120
logger.debug("formatting request")
121121
request = self.format_request(messages, tool_specs, system_prompt)
122-
logger.debug("formatted request=<%s>", request)
122+
logger.debug("request=<%s>", request)
123123

124124
logger.debug("invoking model")
125125
response = await litellm.acompletion(**self.client_args, **request)

src/strands/models/llamaapi.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717

1818
from ..types.content import ContentBlock, Messages
1919
from ..types.exceptions import ModelThrottledException
20-
from ..types.models import Model
2120
from ..types.streaming import StreamEvent, Usage
2221
from ..types.tools import ToolResult, ToolSpec, ToolUse
22+
from .model import Model
2323

2424
logger = logging.getLogger(__name__)
2525

@@ -340,7 +340,7 @@ async def stream(
340340
"""
341341
logger.debug("formatting request")
342342
request = self.format_request(messages, tool_specs, system_prompt)
343-
logger.debug("formatted request=<%s>", request)
343+
logger.debug("request=<%s>", request)
344344

345345
logger.debug("invoking model")
346346
try:

src/strands/models/mistral.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
from ..types.content import ContentBlock, Messages
1616
from ..types.exceptions import ModelThrottledException
17-
from ..types.models import Model
1817
from ..types.streaming import StopReason, StreamEvent
1918
from ..types.tools import ToolResult, ToolSpec, ToolUse
19+
from .model import Model
2020

2121
logger = logging.getLogger(__name__)
2222

@@ -409,7 +409,7 @@ async def stream(
409409
"""
410410
logger.debug("formatting request")
411411
request = self.format_request(messages, tool_specs, system_prompt)
412-
logger.debug("formatted request=<%s>", request)
412+
logger.debug("request=<%s>", request)
413413

414414
logger.debug("invoking model")
415415
try:

src/strands/types/models/model.py renamed to src/strands/models/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
1-
"""Model-related type definitions for the SDK."""
1+
"""Abstract base class for Agent model providers."""
22

33
import abc
44
import logging
55
from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypeVar, Union
66

77
from pydantic import BaseModel
88

9-
from ..content import Messages
10-
from ..streaming import StreamEvent
11-
from ..tools import ToolSpec
9+
from ..types.content import Messages
10+
from ..types.streaming import StreamEvent
11+
from ..types.tools import ToolSpec
1212

1313
logger = logging.getLogger(__name__)
1414

1515
T = TypeVar("T", bound=BaseModel)
1616

1717

1818
class Model(abc.ABC):
19-
"""Abstract base class for AI model implementations.
19+
"""Abstract base class for Agent model providers.
2020
2121
This class defines the interface for all model implementations in the Strands Agents SDK. It provides a
2222
standardized way to configure and process requests for different AI model providers.

src/strands/models/ollama.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
from typing_extensions import TypedDict, Unpack, override
1313

1414
from ..types.content import ContentBlock, Messages
15-
from ..types.models import Model
1615
from ..types.streaming import StopReason, StreamEvent
1716
from ..types.tools import ToolSpec
17+
from .model import Model
1818

1919
logger = logging.getLogger(__name__)
2020

@@ -296,7 +296,7 @@ async def stream(
296296
"""
297297
logger.debug("formatting request")
298298
request = self.format_request(messages, tool_specs, system_prompt)
299-
logger.debug("formatted request=<%s>", request)
299+
logger.debug("request=<%s>", request)
300300

301301
logger.debug("invoking model")
302302
tool_requested = False

0 commit comments

Comments
 (0)