-
Notifications
You must be signed in to change notification settings - Fork 465
swarm - switch to handoff node only after current node stops #1147
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -156,6 +156,7 @@ class SwarmState: | |
| # Total metrics across all agents | ||
| accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) | ||
| execution_time: int = 0 # Total execution time in milliseconds | ||
| handoff_node: SwarmNode | None = None # The agent to execute next | ||
| handoff_message: str | None = None # Message passed during agent handoff | ||
|
|
||
| def should_continue( | ||
|
|
@@ -537,7 +538,7 @@ def handoff_to_agent(agent_name: str, message: str, context: dict[str, Any] | No | |
| # Execute handoff | ||
| swarm_ref._handle_handoff(target_node, message, context) | ||
|
|
||
| return {"status": "success", "content": [{"text": f"Handed off to {agent_name}: {message}"}]} | ||
| return {"status": "success", "content": [{"text": f"Handing off to {agent_name}: {message}"}]} | ||
| except Exception as e: | ||
| return {"status": "error", "content": [{"text": f"Error in handoff: {str(e)}"}]} | ||
|
|
||
|
|
@@ -553,21 +554,19 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st | |
| ) | ||
| return | ||
|
|
||
| # Update swarm state | ||
| previous_agent = cast(SwarmNode, self.state.current_node) | ||
| self.state.current_node = target_node | ||
| current_node = cast(SwarmNode, self.state.current_node) | ||
|
|
||
| # Store handoff message for the target agent | ||
| self.state.handoff_node = target_node | ||
| self.state.handoff_message = message | ||
|
|
||
| # Store handoff context as shared context | ||
| if context: | ||
| for key, value in context.items(): | ||
| self.shared_context.add_context(previous_agent, key, value) | ||
| self.shared_context.add_context(current_node, key, value) | ||
|
|
||
| logger.debug( | ||
| "from_node=<%s>, to_node=<%s> | handed off from agent to agent", | ||
| previous_agent.node_id, | ||
| "from_node=<%s>, to_node=<%s> | handing off from agent to agent", | ||
| current_node.node_id, | ||
| target_node.node_id, | ||
| ) | ||
|
|
||
|
|
@@ -667,7 +666,6 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato | |
| logger.debug("reason=<%s> | stopping execution", reason) | ||
| break | ||
|
|
||
| # Get current node | ||
| current_node = self.state.current_node | ||
| if not current_node or current_node.node_id not in self.nodes: | ||
| logger.error("node=<%s> | node not found", current_node.node_id if current_node else "None") | ||
|
|
@@ -680,14 +678,10 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato | |
| len(self.state.node_history) + 1, | ||
| ) | ||
|
|
||
| # Store the current node before execution to detect handoffs | ||
| previous_node = current_node | ||
|
|
||
| # Execute node with timeout protection | ||
| # TODO: Implement cancellation token to stop _execute_node from continuing | ||
| try: | ||
| # Execute with timeout wrapper for async generator streaming | ||
| self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, current_node.node_id, invocation_state)) | ||
|
|
||
| node_stream = self._stream_with_timeout( | ||
| self._execute_node(current_node, self.state.task, invocation_state), | ||
| self.node_timeout, | ||
|
|
@@ -697,28 +691,31 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato | |
| yield event | ||
|
|
||
| self.state.node_history.append(current_node) | ||
|
|
||
| # After self.state add current node, swarm state finish updating, we persist here | ||
| self.hooks.invoke_callbacks(AfterNodeCallEvent(self, current_node.node_id, invocation_state)) | ||
|
||
|
|
||
| logger.debug("node=<%s> | node execution completed", current_node.node_id) | ||
|
|
||
| # Check if handoff occurred during execution | ||
| if self.state.current_node is not None and self.state.current_node != previous_node: | ||
| # Emit handoff event (single node transition in Swarm) | ||
| # Check if handoff requested during execution | ||
| if self.state.handoff_node: | ||
| previous_node = current_node | ||
| current_node = self.state.handoff_node | ||
|
|
||
| self.state.handoff_node = None | ||
| self.state.current_node = current_node | ||
|
|
||
| handoff_event = MultiAgentHandoffEvent( | ||
| from_node_ids=[previous_node.node_id], | ||
| to_node_ids=[self.state.current_node.node_id], | ||
| to_node_ids=[current_node.node_id], | ||
| message=self.state.handoff_message or "Agent handoff occurred", | ||
| ) | ||
| yield handoff_event | ||
| logger.debug( | ||
| "from_node=<%s>, to_node=<%s> | handoff detected", | ||
| previous_node.node_id, | ||
| self.state.current_node.node_id, | ||
| current_node.node_id, | ||
| ) | ||
|
|
||
| else: | ||
| # No handoff occurred, mark swarm as complete | ||
| logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id) | ||
| self.state.completion_status = Status.COMPLETED | ||
| break | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,7 +6,7 @@ | |
|
|
||
| from ..experimental.hooks.multiagent.events import ( | ||
| AfterMultiAgentInvocationEvent, | ||
| AfterNodeCallEvent, | ||
| BeforeNodeCallEvent, | ||
| MultiAgentInitializedEvent, | ||
| ) | ||
| from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent | ||
|
|
@@ -44,7 +44,7 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: | |
| registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent)) | ||
|
|
||
| registry.add_callback(MultiAgentInitializedEvent, lambda event: self.initialize_multi_agent(event.source)) | ||
| registry.add_callback(AfterNodeCallEvent, lambda event: self.sync_multi_agent(event.source)) | ||
| registry.add_callback(BeforeNodeCallEvent, lambda event: self.sync_multi_agent(event.source)) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's say we have successfully executed one node and are now executing the handoff node. If we crash on the handoff node, we would be left in different states depending on which event we persist on:
In short, persisting on |
||
| registry.add_callback(AfterMultiAgentInvocationEvent, lambda event: self.sync_multi_agent(event.source)) | ||
|
|
||
| @abstractmethod | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed a few inline comments because I felt the code was already self explanatory.