diff --git a/tests/entrypoints/openai/test_serving_responses.py b/tests/entrypoints/openai/test_serving_responses.py
index 788a1e912182..93e11b61020c 100644
--- a/tests/entrypoints/openai/test_serving_responses.py
+++ b/tests/entrypoints/openai/test_serving_responses.py
@@ -34,6 +34,9 @@ def __init__(self):
     def append_output(self, output) -> None:
         pass
 
+    def append_tool_output(self, output) -> None:
+        pass
+
     async def call_tool(self):
         return []
 
diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py
index 0041db822080..7a41c668d764 100644
--- a/vllm/entrypoints/context.py
+++ b/vllm/entrypoints/context.py
@@ -80,7 +80,11 @@ def copy(self):
 
 class ConversationContext(ABC):
     @abstractmethod
-    def append_output(self, output) -> None:
+    def append_output(self, output: RequestOutput) -> None:
+        pass
+
+    @abstractmethod
+    def append_tool_output(self, output) -> None:
         pass
 
     @abstractmethod
@@ -151,6 +155,9 @@ def append_output(self, output) -> None:
         self.num_cached_tokens = output.num_cached_tokens or 0
         self.num_output_tokens += len(output.outputs[0].token_ids or [])
 
+    def append_tool_output(self, output) -> None:
+        raise NotImplementedError("Should not be called.")
+
     def need_builtin_tool_call(self) -> bool:
         return False
 
@@ -205,28 +212,28 @@ def _update_num_reasoning_tokens(self):
         if self.parser.current_channel in {"analysis", "commentary"}:
             self.num_reasoning_tokens += 1
 
-    def append_output(self, output: RequestOutput | list[Message]) -> None:
-        if isinstance(output, RequestOutput):
-            output_token_ids = output.outputs[0].token_ids
-            self.parser = get_streamable_parser_for_assistant()
-            for token_id in output_token_ids:
-                self.parser.process(token_id)
-                # Check if the current token is part of reasoning content
-                self._update_num_reasoning_tokens()
-            self._update_prefill_token_usage(output)
-            self._update_decode_token_usage(output)
-            # Append current turn to all turn list for next turn's calculations
-            self.all_turn_metrics.append(self.current_turn_metrics.copy())
-            self.current_turn_metrics.reset()
-            # append_output is called only once before tool calling
-            # in non-streaming case
-            # so we can append all the parser messages to _messages
-            output_msgs = self.parser.messages
-            # The responses finish reason is set in the last message
-            self.finish_reason = output.outputs[0].finish_reason
-        else:
-            # Tool output.
-            output_msgs = output
+    def append_output(self, output: RequestOutput) -> None:
+        output_token_ids = output.outputs[0].token_ids
+        self.parser = get_streamable_parser_for_assistant()
+        for token_id in output_token_ids:
+            self.parser.process(token_id)
+            # Check if the current token is part of reasoning content
+            self._update_num_reasoning_tokens()
+        self._update_prefill_token_usage(output)
+        self._update_decode_token_usage(output)
+        # Append current turn to all turn list for next turn's calculations
+        self.all_turn_metrics.append(self.current_turn_metrics.copy())
+        self.current_turn_metrics.reset()
+        # append_output is called only once before tool calling
+        # in non-streaming case
+        # so we can append all the parser messages to _messages
+        output_msgs = self.parser.messages
+        # The responses finish reason is set in the last message
+        self.finish_reason = output.outputs[0].finish_reason
+        self._messages.extend(output_msgs)
+
+    def append_tool_output(self, output: list[Message]) -> None:
+        output_msgs = output
         self._messages.extend(output_msgs)
 
     def _update_prefill_token_usage(self, output: RequestOutput) -> None:
@@ -502,45 +509,45 @@ def __init__(self, *args, **kwargs):
     def messages(self) -> list:
         return self._messages
 
-    def append_output(self, output: RequestOutput | list[Message]) -> None:
-        if isinstance(output, RequestOutput):
-            # append_output is called for each output token in streaming case,
-            # so we only want to add the prompt tokens once for each message.
-            if self.first_tok_of_message:
-                self._update_prefill_token_usage(output)
-            # Reset self.first_tok_of_message if needed:
-            # if the current token is the last one of the current message
-            # (finished=True), then the next token processed will mark the
-            # beginning of a new message
-            self.first_tok_of_message = output.finished
-            for tok in output.outputs[0].token_ids:
-                self.parser.process(tok)
-            self._update_decode_token_usage(output)
-
-            # For streaming, update previous turn when message is complete
-            if output.finished:
-                self.all_turn_metrics.append(self.current_turn_metrics.copy())
-                self.current_turn_metrics.reset()
-            # Check if the current token is part of reasoning content
-            self._update_num_reasoning_tokens()
-            self.last_tok = tok
-            if len(self._messages) - self.num_init_messages < len(self.parser.messages):
-                self._messages.extend(
-                    self.parser.messages[len(self._messages) - self.num_init_messages :]
-                )
-        else:
-            # Handle the case of tool output in direct message format
-            assert len(output) == 1, "Tool output should be a single message"
-            msg = output[0]
-            # Sometimes the recipient is not set for tool messages,
-            # so we set it to "assistant"
-            if msg.author.role == Role.TOOL and msg.recipient is None:
-                msg.recipient = "assistant"
-            toks = self.encoding.render(msg)
-            for tok in toks:
-                self.parser.process(tok)
-            self.last_tok = toks[-1]
-            # TODO: add tool_output messages to self._messages
+    def append_output(self, output: RequestOutput) -> None:
+        # append_output is called for each output token in streaming case,
+        # so we only want to add the prompt tokens once for each message.
+        if self.first_tok_of_message:
+            self._update_prefill_token_usage(output)
+        # Reset self.first_tok_of_message if needed:
+        # if the current token is the last one of the current message
+        # (finished=True), then the next token processed will mark the
+        # beginning of a new message
+        self.first_tok_of_message = output.finished
+        for tok in output.outputs[0].token_ids:
+            self.parser.process(tok)
+        self._update_decode_token_usage(output)
+
+        # For streaming, update previous turn when message is complete
+        if output.finished:
+            self.all_turn_metrics.append(self.current_turn_metrics.copy())
+            self.current_turn_metrics.reset()
+        # Check if the current token is part of reasoning content
+        self._update_num_reasoning_tokens()
+        self.last_tok = tok
+        if len(self._messages) - self.num_init_messages < len(self.parser.messages):
+            self._messages.extend(
+                self.parser.messages[len(self._messages) - self.num_init_messages :]
+            )
+
+    def append_tool_output(self, output: list[Message]) -> None:
+        # Handle the case of tool output in direct message format
+        assert len(output) == 1, "Tool output should be a single message"
+        msg = output[0]
+        # Sometimes the recipient is not set for tool messages,
+        # so we set it to "assistant"
+        if msg.author.role == Role.TOOL and msg.recipient is None:
+            msg.recipient = "assistant"
+        toks = self.encoding.render(msg)
+        for tok in toks:
+            self.parser.process(tok)
+        self.last_tok = toks[-1]
+        # TODO: add tool_output messages to self._messages
 
     def is_expecting_start(self) -> bool:
         return self.parser.state == StreamState.EXPECT_START
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 30b8499b08d5..1456727a3cdd 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -1227,7 +1227,7 @@ async def _generate_with_builtin_tools(
 
         # Call the tool and update the context with the result.
         tool_output = await context.call_tool()
-        context.append_output(tool_output)
+        context.append_tool_output(tool_output)
        # TODO: uncomment this and enable tool output streaming
        # yield context
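
For reference, a minimal runnable sketch of the contract this refactor establishes: model output flows through append_output() and tool results through append_tool_output(), so neither method needs the old isinstance() dispatch. All names below (SketchContext, FakeRequestOutput, FakeToolMessage) are illustrative stand-ins, not vLLM code; only the append_output / append_tool_output / call_tool shape follows the interface in this diff.

# Sketch of the split ConversationContext contract. The classes here are
# hypothetical stand-ins for RequestOutput, a harmony Message, and a
# ConversationContext implementation.
import asyncio
from dataclasses import dataclass, field


@dataclass
class FakeRequestOutput:  # stand-in for vllm.outputs.RequestOutput
    text: str


@dataclass
class FakeToolMessage:  # stand-in for an openai_harmony Message
    content: str


@dataclass
class SketchContext:  # stand-in for a ConversationContext implementation
    messages: list = field(default_factory=list)

    def append_output(self, output: FakeRequestOutput) -> None:
        # Model output only; no isinstance() branch needed anymore.
        self.messages.append(("assistant", output.text))

    def append_tool_output(self, output: list[FakeToolMessage]) -> None:
        # Tool results only, delivered as a list of messages.
        self.messages.extend(("tool", m.content) for m in output)

    async def call_tool(self) -> list[FakeToolMessage]:
        return [FakeToolMessage(content="tool result")]


async def main() -> None:
    ctx = SketchContext()
    ctx.append_output(FakeRequestOutput(text="model turn"))
    # Mirrors the updated call site in _generate_with_builtin_tools:
    ctx.append_tool_output(await ctx.call_tool())
    print(ctx.messages)


asyncio.run(main())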