@@ -385,18 +385,42 @@ def __init__(
385385 self .state = GraphState ()
386386 self .tracer = get_tracer ()
387387
388- def __call__ (self , task : str | list [ContentBlock ], ** kwargs : Any ) -> GraphResult :
389- """Invoke the graph synchronously."""
388+ def __call__ (
389+ self , task : str | list [ContentBlock ], invocation_state : dict [str , Any ] | None = None , ** kwargs : Any
390+ ) -> GraphResult :
391+ """Invoke the graph synchronously.
392+
393+ Args:
394+ task: The task to execute
395+ invocation_state: Additional state/context passed to underlying agents.
396+ Defaults to None to avoid mutable default argument issues.
397+ **kwargs: Keyword arguments allowing backward compatible future changes.
398+ """
399+ if invocation_state is None :
400+ invocation_state = {}
390401
391402 def execute () -> GraphResult :
392- return asyncio .run (self .invoke_async (task ))
403+ return asyncio .run (self .invoke_async (task , invocation_state ))
393404
394405 with ThreadPoolExecutor () as executor :
395406 future = executor .submit (execute )
396407 return future .result ()
397408
398- async def invoke_async (self , task : str | list [ContentBlock ], ** kwargs : Any ) -> GraphResult :
399- """Invoke the graph asynchronously."""
409+ async def invoke_async (
410+ self , task : str | list [ContentBlock ], invocation_state : dict [str , Any ] | None = None , ** kwargs : Any
411+ ) -> GraphResult :
412+ """Invoke the graph asynchronously.
413+
414+ Args:
415+ task: The task to execute
416+ invocation_state: Additional state/context passed to underlying agents.
417+ Defaults to None to avoid mutable default argument issues - a new empty dict
418+ is created if None is provided.
419+ **kwargs: Keyword arguments allowing backward compatible future changes.
420+ """
421+ if invocation_state is None :
422+ invocation_state = {}
423+
400424 logger .debug ("task=<%s> | starting graph execution" , task )
401425
402426 # Initialize state
@@ -420,7 +444,7 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> G
420444 self .node_timeout or "None" ,
421445 )
422446
423- await self ._execute_graph ()
447+ await self ._execute_graph (invocation_state )
424448
425449 # Set final status based on execution results
426450 if self .state .failed_nodes :
@@ -450,7 +474,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
450474 # Validate Agent-specific constraints for each node
451475 _validate_node_executor (node .executor )
452476
453- async def _execute_graph (self ) -> None :
477+ async def _execute_graph (self , invocation_state : dict [ str , Any ] ) -> None :
454478 """Unified execution flow with conditional routing."""
455479 ready_nodes = list (self .entry_points )
456480
@@ -469,7 +493,7 @@ async def _execute_graph(self) -> None:
469493 ready_nodes .clear ()
470494
471495 # Execute current batch of ready nodes concurrently
472- tasks = [asyncio .create_task (self ._execute_node (node )) for node in current_batch ]
496+ tasks = [asyncio .create_task (self ._execute_node (node , invocation_state )) for node in current_batch ]
473497
474498 for task in tasks :
475499 await task
@@ -506,7 +530,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
506530 )
507531 return False
508532
509- async def _execute_node (self , node : GraphNode ) -> None :
533+ async def _execute_node (self , node : GraphNode , invocation_state : dict [ str , Any ] ) -> None :
510534 """Execute a single node with error handling and timeout protection."""
511535 # Reset the node's state if reset_on_revisit is enabled and it's being revisited
512536 if self .reset_on_revisit and node in self .state .completed_nodes :
@@ -529,11 +553,11 @@ async def _execute_node(self, node: GraphNode) -> None:
529553 if isinstance (node .executor , MultiAgentBase ):
530554 if self .node_timeout is not None :
531555 multi_agent_result = await asyncio .wait_for (
532- node .executor .invoke_async (node_input ),
556+ node .executor .invoke_async (node_input , invocation_state ),
533557 timeout = self .node_timeout ,
534558 )
535559 else :
536- multi_agent_result = await node .executor .invoke_async (node_input )
560+ multi_agent_result = await node .executor .invoke_async (node_input , invocation_state )
537561
538562 # Create NodeResult with MultiAgentResult directly
539563 node_result = NodeResult (
@@ -548,11 +572,11 @@ async def _execute_node(self, node: GraphNode) -> None:
548572 elif isinstance (node .executor , Agent ):
549573 if self .node_timeout is not None :
550574 agent_response = await asyncio .wait_for (
551- node .executor .invoke_async (node_input ),
575+ node .executor .invoke_async (node_input , ** invocation_state ),
552576 timeout = self .node_timeout ,
553577 )
554578 else :
555- agent_response = await node .executor .invoke_async (node_input )
579+ agent_response = await node .executor .invoke_async (node_input , ** invocation_state )
556580
557581 # Extract metrics from agent response
558582 usage = Usage (inputTokens = 0 , outputTokens = 0 , totalTokens = 0 )
0 commit comments