|
14 | 14 | import logging |
15 | 15 | import random |
16 | 16 | from concurrent.futures import ThreadPoolExecutor |
17 | | -from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast |
| 17 | +from typing import ( |
| 18 | + Any, |
| 19 | + AsyncGenerator, |
| 20 | + AsyncIterator, |
| 21 | + Callable, |
| 22 | + Mapping, |
| 23 | + Optional, |
| 24 | + Type, |
| 25 | + TypeAlias, |
| 26 | + TypeVar, |
| 27 | + Union, |
| 28 | + cast, |
| 29 | +) |
18 | 30 |
|
19 | 31 | from opentelemetry import trace as trace_api |
20 | 32 | from pydantic import BaseModel |
|
55 | 67 | # TypeVar for generic structured output |
56 | 68 | T = TypeVar("T", bound=BaseModel) |
57 | 69 |
|
| 70 | +AgentInput: TypeAlias = str | list[ContentBlock] | Messages | None |
| 71 | + |
58 | 72 |
|
59 | 73 | # Sentinel class and object to distinguish between explicit None and default parameter value |
60 | 74 | class _DefaultCallbackHandlerSentinel: |
@@ -361,7 +375,7 @@ def tool_names(self) -> list[str]: |
361 | 375 | all_tools = self.tool_registry.get_all_tools_config() |
362 | 376 | return list(all_tools.keys()) |
363 | 377 |
|
364 | | - def __call__(self, prompt: str | list[ContentBlock] | Messages | None = None, **kwargs: Any) -> AgentResult: |
| 378 | + def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: |
365 | 379 | """Process a natural language prompt through the agent's event loop. |
366 | 380 |
|
367 | 381 | This method implements the conversational interface with multiple input patterns: |
@@ -394,9 +408,7 @@ def execute() -> AgentResult: |
394 | 408 | future = executor.submit(execute) |
395 | 409 | return future.result() |
396 | 410 |
|
397 | | - async def invoke_async( |
398 | | - self, prompt: str | list[ContentBlock] | Messages | None = None, **kwargs: Any |
399 | | - ) -> AgentResult: |
| 411 | + async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: |
400 | 412 | """Process a natural language prompt through the agent's event loop. |
401 | 413 |
|
402 | 414 | This method implements the conversational interface with multiple input patterns: |
@@ -427,7 +439,7 @@ async def invoke_async( |
427 | 439 |
|
428 | 440 | return cast(AgentResult, event["result"]) |
429 | 441 |
|
430 | | - def structured_output(self, output_model: Type[T], prompt: str | list[ContentBlock] | Messages | None = None) -> T: |
| 442 | + def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> T: |
431 | 443 | """This method allows you to get structured output from the agent. |
432 | 444 |
|
433 | 445 | If you pass in a prompt, it will be used temporarily without adding it to the conversation history. |
@@ -456,9 +468,7 @@ def execute() -> T: |
456 | 468 | future = executor.submit(execute) |
457 | 469 | return future.result() |
458 | 470 |
|
459 | | - async def structured_output_async( |
460 | | - self, output_model: Type[T], prompt: str | list[ContentBlock] | Messages | None = None |
461 | | - ) -> T: |
| 471 | + async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T: |
462 | 472 | """This method allows you to get structured output from the agent. |
463 | 473 |
|
464 | 474 | If you pass in a prompt, it will be used temporarily without adding it to the conversation history. |
@@ -517,7 +527,7 @@ async def structured_output_async( |
517 | 527 |
|
518 | 528 | async def stream_async( |
519 | 529 | self, |
520 | | - prompt: str | list[ContentBlock] | Messages | None = None, |
| 530 | + prompt: AgentInput = None, |
521 | 531 | **kwargs: Any, |
522 | 532 | ) -> AsyncIterator[Any]: |
523 | 533 | """Process a natural language prompt and yield events as an async iterator. |
@@ -657,7 +667,7 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A |
657 | 667 | async for event in events: |
658 | 668 | yield event |
659 | 669 |
|
660 | | - def _convert_prompt_to_messages(self, prompt: str | list[ContentBlock] | Messages | None) -> Messages: |
| 670 | + def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: |
661 | 671 | messages: Messages | None = None |
662 | 672 | if prompt is not None: |
663 | 673 | if isinstance(prompt, str): |
|
0 commit comments