33- Docs: https://aws.amazon.com/bedrock/
44"""
55
6+ import json
67import logging
78import os
8- from typing import Any , Iterable , Literal , Optional , cast
9+ from typing import Any , Iterable , List , Literal , Optional , cast
910
1011import boto3
1112from botocore .config import Config as BotocoreConfig
12- from botocore .exceptions import ClientError , EventStreamError
13+ from botocore .exceptions import ClientError
1314from typing_extensions import TypedDict , Unpack , override
1415
1516from ..types .content import Messages
@@ -61,6 +62,7 @@ class BedrockConfig(TypedDict, total=False):
6162 max_tokens: Maximum number of tokens to generate in the response
6263 model_id: The Bedrock model ID (e.g., "us.anthropic.claude-3-7-sonnet-20250219-v1:0")
6364 stop_sequences: List of sequences that will stop generation when encountered
65+ streaming: Flag to enable/disable streaming. Defaults to True.
6466 temperature: Controls randomness in generation (higher = more random)
6567 top_p: Controls diversity via nucleus sampling (alternative to temperature)
6668 """
@@ -81,6 +83,7 @@ class BedrockConfig(TypedDict, total=False):
8183 max_tokens : Optional [int ]
8284 model_id : str
8385 stop_sequences : Optional [list [str ]]
86+ streaming : Optional [bool ]
8487 temperature : Optional [float ]
8588 top_p : Optional [float ]
8689
@@ -246,11 +249,68 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
246249 """
247250 return cast (StreamEvent , event )
248251
252+ def _has_blocked_guardrail (self , guardrail_data : dict [str , Any ]) -> bool :
253+ """Check if guardrail data contains any blocked policies.
254+
255+ Args:
256+ guardrail_data: Guardrail data from trace information.
257+
258+ Returns:
259+ True if any blocked guardrail is detected, False otherwise.
260+ """
261+ input_assessment = guardrail_data .get ("inputAssessment" , {})
262+ output_assessments = guardrail_data .get ("outputAssessments" , {})
263+
264+ # Check input assessments
265+ if any (self ._find_detected_and_blocked_policy (assessment ) for assessment in input_assessment .values ()):
266+ return True
267+
268+ # Check output assessments
269+ if any (self ._find_detected_and_blocked_policy (assessment ) for assessment in output_assessments .values ()):
270+ return True
271+
272+ return False
273+
274+ def _generate_redaction_events (self ) -> list [StreamEvent ]:
275+ """Generate redaction events based on configuration.
276+
277+ Returns:
278+ List of redaction events to yield.
279+ """
280+ events : List [StreamEvent ] = []
281+
282+ if self .config .get ("guardrail_redact_input" , True ):
283+ logger .debug ("Redacting user input due to guardrail." )
284+ events .append (
285+ {
286+ "redactContent" : {
287+ "redactUserContentMessage" : self .config .get (
288+ "guardrail_redact_input_message" , "[User input redacted.]"
289+ )
290+ }
291+ }
292+ )
293+
294+ if self .config .get ("guardrail_redact_output" , False ):
295+ logger .debug ("Redacting assistant output due to guardrail." )
296+ events .append (
297+ {
298+ "redactContent" : {
299+ "redactAssistantContentMessage" : self .config .get (
300+ "guardrail_redact_output_message" , "[Assistant output redacted.]"
301+ )
302+ }
303+ }
304+ )
305+
306+ return events
307+
@override
def stream(self, request: dict[str, Any]) -> Iterable[StreamEvent]:
    """Send the request to the Bedrock model and yield the response events.

    Calls the Bedrock ``converse_stream`` API when the ``streaming`` config flag
    is truthy (the default when unset). Otherwise calls the ``converse`` API and
    replays the complete response as a sequence of streaming-format events.

    Args:
        request: The formatted request to send to the Bedrock model.

    Returns:
        An iterable of response events in the streaming format.

    Raises:
        ContextWindowOverflowException: If the input exceeds the model's context window.
        ModelThrottledException: If the model service is throttling requests.
        ClientError: Any other Bedrock client error, re-raised unchanged.
    """
    streaming = self.config.get("streaming", True)

    try:
        if streaming:
            response = self.client.converse_stream(**request)
            for chunk in response["stream"]:
                # Guardrail assessments arrive inside the trace of a metadata chunk.
                if (
                    "metadata" in chunk
                    and "trace" in chunk["metadata"]
                    and "guardrail" in chunk["metadata"]["trace"]
                ):
                    guardrail_data = chunk["metadata"]["trace"]["guardrail"]
                    if self._has_blocked_guardrail(guardrail_data):
                        # Redaction events go out ahead of the chunk that carried the trace.
                        yield from self._generate_redaction_events()
                yield chunk
        else:
            response = self.client.converse(**request)

            # Replay the single response as stream-style events.
            yield from self._convert_non_streaming_to_streaming(response)

            # Unlike the streaming branch, redaction events are emitted here after
            # the converted events (including the metadata event).
            if (
                "trace" in response
                and "guardrail" in response["trace"]
                and self._has_blocked_guardrail(response["trace"]["guardrail"])
            ):
                yield from self._generate_redaction_events()

    except ClientError as e:
        # NOTE(review): botocore's EventStreamError derives from ClientError, so
        # errors raised while iterating the stream are expected to land here too.
        error_message = str(e)

        # botocore treats "Error"/"Code" as optional when it builds the exception,
        # so look them up defensively rather than masking the real error with a KeyError.
        error_code = e.response.get("Error", {}).get("Code")

        # Request-level throttling is reported as "ThrottlingException"; throttling
        # delivered through the event stream uses the stream member name
        # "throttlingException" — TODO confirm against a live mid-stream throttle.
        if error_code in ("ThrottlingException", "throttlingException"):
            raise ModelThrottledException(error_message) from e

        if any(overflow_message in error_message for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES):
            logger.warning("bedrock threw context window overflow error")
            raise ContextWindowOverflowException(e) from e

        # Anything else propagates untouched, keeping the original traceback.
        raise
371+ def _convert_non_streaming_to_streaming (self , response : dict [str , Any ]) -> Iterable [StreamEvent ]:
372+ """Convert a non-streaming response to the streaming format.
373+
374+ Args:
375+ response: The non-streaming response from the Bedrock model.
376+
377+ Returns:
378+ An iterable of response events in the streaming format.
379+ """
380+ # Yield messageStart event
381+ yield {"messageStart" : {"role" : response ["output" ]["message" ]["role" ]}}
382+
383+ # Process content blocks
384+ for content in response ["output" ]["message" ]["content" ]:
385+ # Yield contentBlockStart event if needed
386+ if "toolUse" in content :
387+ yield {
388+ "contentBlockStart" : {
389+ "start" : {
390+ "toolUse" : {
391+ "toolUseId" : content ["toolUse" ]["toolUseId" ],
392+ "name" : content ["toolUse" ]["name" ],
393+ }
394+ },
395+ }
396+ }
397+
398+ # For tool use, we need to yield the input as a delta
399+ input_value = json .dumps (content ["toolUse" ]["input" ])
400+
401+ yield {"contentBlockDelta" : {"delta" : {"toolUse" : {"input" : input_value }}}}
402+ elif "text" in content :
403+ # Then yield the text as a delta
404+ yield {
405+ "contentBlockDelta" : {
406+ "delta" : {"text" : content ["text" ]},
407+ }
408+ }
409+ elif "reasoningContent" in content :
410+ # Then yield the reasoning content as a delta
411+ yield {
412+ "contentBlockDelta" : {
413+ "delta" : {"reasoningContent" : {"text" : content ["reasoningContent" ]["reasoningText" ]["text" ]}}
414+ }
415+ }
416+
417+ if "signature" in content ["reasoningContent" ]["reasoningText" ]:
418+ yield {
419+ "contentBlockDelta" : {
420+ "delta" : {
421+ "reasoningContent" : {
422+ "signature" : content ["reasoningContent" ]["reasoningText" ]["signature" ]
423+ }
424+ }
425+ }
426+ }
427+
428+ # Yield contentBlockStop event
429+ yield {"contentBlockStop" : {}}
430+
431+ # Yield messageStop event
432+ yield {
433+ "messageStop" : {
434+ "stopReason" : response ["stopReason" ],
435+ "additionalModelResponseFields" : response .get ("additionalModelResponseFields" ),
436+ }
437+ }
438+
439+ # Yield metadata event
440+ if "usage" in response or "metrics" in response or "trace" in response :
441+ metadata : StreamEvent = {"metadata" : {}}
442+ if "usage" in response :
443+ metadata ["metadata" ]["usage" ] = response ["usage" ]
444+ if "metrics" in response :
445+ metadata ["metadata" ]["metrics" ] = response ["metrics" ]
446+ if "trace" in response :
447+ metadata ["metadata" ]["trace" ] = response ["trace" ]
448+ yield metadata
320449
321450 def _find_detected_and_blocked_policy (self , input : Any ) -> bool :
322451 """Recursively checks if the assessment contains a detected and blocked guardrail.
0 commit comments