
Commit 1890321

[Bugfix] Fix and add tests for GptOss reasoning parser (#28000)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
1 parent d0ceb38 commit 1890321
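
In short: the parser previously recognized the end of reasoning only as the single contiguous token sequence "<|start|>assistant<|channel|>final<|message|>". gpt-oss models can emit extra special tokens between "final" and "<|message|>" (the tests below use "<|constrain|> JSON", for example), so that sequence never matched and the reasoning section was never considered closed. The fix splits the marker into a prefix ("<|channel|>final") and a suffix ("<|message|>") and allows a bounded gap between them. A minimal sketch of the matching logic, using plain strings in place of real tokenizer ids (illustration only, not part of the commit):

# Toy sketch: string "tokens" stand in for tokenizer ids to show why the
# old check failed on this output and how the new prefix/suffix scan works.

OLD_END = ["<|start|>", "assistant", "<|channel|>", "final", "<|message|>"]

output = [
    "<|channel|>", "analysis", "<|message|>", "reasoning...",
    "<|end|>", "<|start|>", "assistant",
    "<|channel|>", "final", "<|constrain|>", "JSON",  # extra tokens after "final"
    "<|message|>", "content...",
]

def contains(haystack, needle):
    # True if needle occurs as a contiguous subsequence of haystack.
    return any(
        haystack[i : i + len(needle)] == needle
        for i in range(len(haystack) - len(needle) + 1)
    )

# Old behavior: one contiguous sequence, so this output never matched.
assert not contains(output, OLD_END)

# New behavior: match the prefix, then allow a bounded gap before the suffix.
PREFIX, SUFFIX, MAX_GAP = ["<|channel|>", "final"], ["<|message|>"], 20

def is_reasoning_end(ids):
    # Search backwards so the most recent "final" header is checked first.
    for i in range(len(ids) - len(PREFIX), -1, -1):
        if ids[i : i + len(PREFIX)] == PREFIX:
            start = i + len(PREFIX)
            for j in range(start, len(ids) - len(SUFFIX) + 1):
                if j - start >= MAX_GAP:
                    break
                if ids[j : j + len(SUFFIX)] == SUFFIX:
                    return True
    return False

assert is_reasoning_end(output)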

File tree: 2 files changed (+151, -7 lines)
New test file: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
from transformers import AutoTokenizer

from vllm.reasoning import ReasoningParser
from vllm.reasoning.gptoss_reasoning_parser import GptOssReasoningParser

REASONING_MODEL_NAME = "openai/gpt-oss-120b"


@pytest.fixture(scope="module")
def gpt_oss_tokenizer():
    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)


USER_MESSAGE_START = "<|start|>user<|message|>"
REASONING_SECTION_START = "<|end|><|start|>assistant<|channel|>analysis<|message|>"
ASSISTANT_CONTENT_START_PREFIX = "<|end|><|start|>assistant<|channel|>final"
ASSISTANT_CONTENT_START_SUFFIX = "<|message|>"
ASSISTANT_CONTENT_START = (
    ASSISTANT_CONTENT_START_PREFIX + ASSISTANT_CONTENT_START_SUFFIX
)

BASIC_CONTENT = {
    "output": REASONING_SECTION_START
    + "This is reasoning"
    + ASSISTANT_CONTENT_START
    + "This is the rest",
    "is_reasoning_end": True,
}

BASIC_REASONING_ONLY = {
    "output": REASONING_SECTION_START + "This is reasoning" + "<|end|>",
    "is_reasoning_end": False,
}
BASIC_NO_REASONING_NO_ASSISTANT = {
    "output": USER_MESSAGE_START + "This is a user message",
    "is_reasoning_end": False,
}

# Edge-case where the model omits the assistant tag entirely.
BASIC_NO_REASONING_ASSISTANT = {
    "output": USER_MESSAGE_START + "This is a user message<|end|><|channel|>final",
    "is_reasoning_end": True,
}

COMPLEX_CONTENT_INCOMPLETE_PREFIX_ONLY = {
    "output": REASONING_SECTION_START
    + "This is reasoning"
    + ASSISTANT_CONTENT_START_PREFIX,
    "is_reasoning_end": False,
}

COMPLEX_CONTENT_SUFFIX_ONLY = {
    "output": REASONING_SECTION_START
    + "This is reasoning"
    + ASSISTANT_CONTENT_START_SUFFIX,
    "is_reasoning_end": False,
}

COMPLEX_CONTENT_1_NO_SUFFIX = {
    "output": REASONING_SECTION_START
    + "This is reasoning"
    + ASSISTANT_CONTENT_START_PREFIX
    + "<|constrain|> JSON ",
    "is_reasoning_end": False,
}

COMPLEX_CONTENT_1 = {
    "output": REASONING_SECTION_START
    + "This is reasoning"
    + ASSISTANT_CONTENT_START_PREFIX
    + "<|constrain|> JSON "
    + ASSISTANT_CONTENT_START_SUFFIX,
    "is_reasoning_end": True,
}

COMPLEX_CONTENT_1_WITH_CONTENT = {
    "output": REASONING_SECTION_START
    + "This is reasoning"
    + ASSISTANT_CONTENT_START_PREFIX
    + "<|constrain|> JSON "
    + ASSISTANT_CONTENT_START_SUFFIX
    + "This is the rest",
    "is_reasoning_end": True,
}

COMPLEX_CONTENT_2 = {
    "output": REASONING_SECTION_START
    + "This is reasoning"
    + ASSISTANT_CONTENT_START_PREFIX
    + "<|constrain|>ReplyAction "
    + ASSISTANT_CONTENT_START_SUFFIX
    + "This is the rest",
    "is_reasoning_end": True,
}

TEST_CASES = [
    BASIC_CONTENT,
    BASIC_REASONING_ONLY,
    COMPLEX_CONTENT_INCOMPLETE_PREFIX_ONLY,
    COMPLEX_CONTENT_SUFFIX_ONLY,
    COMPLEX_CONTENT_1_NO_SUFFIX,
    COMPLEX_CONTENT_1,
    COMPLEX_CONTENT_1_WITH_CONTENT,
    COMPLEX_CONTENT_2,
]


@pytest.mark.parametrize(
    "output, is_reasoning_end",
    [(t["output"], t["is_reasoning_end"]) for t in TEST_CASES],
)
def test_gptoss_is_reasoning_end(
    output,
    is_reasoning_end,
    gpt_oss_tokenizer,
):
    output = gpt_oss_tokenizer.tokenize(output)
    parser: ReasoningParser = GptOssReasoningParser(gpt_oss_tokenizer)

    # Test is_reasoning_end
    output_ids = gpt_oss_tokenizer.convert_tokens_to_ids(output)
    actual_is_reasoning_end = parser.is_reasoning_end(output_ids)
    assert is_reasoning_end == actual_is_reasoning_end
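
COMPLEX_CONTENT_1, COMPLEX_CONTENT_1_WITH_CONTENT, and COMPLEX_CONTENT_2 exercise exactly the failure mode fixed here: a "<|constrain|>" payload sitting between "<|channel|>final" and "<|message|>". The *_PREFIX_ONLY, *_SUFFIX_ONLY, and *_NO_SUFFIX cases pin down that an incomplete final header still counts as reasoning-in-progress.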

vllm/reasoning/gptoss_reasoning_parser.py

Lines changed: 24 additions & 7 deletions
@@ -67,18 +67,35 @@ class GptOssReasoningParser(ReasoningParser):

     def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
         super().__init__(tokenizer, *args, **kwargs)
-        self.reasoning_end_token_ids = self.model_tokenizer.encode(
-            "<|start|>assistant<|channel|>final<|message|>"
+        # The model can output some special tokens between "final" and "<|message|>"
+        # So we need to look for both sequences to determine the end of reasoning.
+        self.reasoning_end_token_ids_prefix = self.model_tokenizer.encode(
+            "<|channel|>final"
         )
+        self.reasoning_end_token_ids_suffix = self.model_tokenizer.encode("<|message|>")
+        self.reasoning_max_num_between_tokens = 20

     def is_reasoning_end(self, input_ids: list[int]) -> bool:
-        end_token_ids = self.reasoning_end_token_ids
-        assert len(end_token_ids) > 0, "reasoning_end_token_ids is empty"
+        end_token_ids_prefix = self.reasoning_end_token_ids_prefix
+        end_token_ids_suffix = self.reasoning_end_token_ids_suffix
+        assert len(end_token_ids_prefix) > 0, "reasoning_end_token_ids_prefix is empty"
+        assert len(end_token_ids_suffix) > 0, "reasoning_end_token_ids_suffix is empty"
         # Check if the end sequence is present in the input_ids.
         # We search from the end of input_ids to find the last match.
-        for i in range(len(input_ids) - len(end_token_ids), -1, -1):
-            if input_ids[i : i + len(end_token_ids)] == end_token_ids:
-                return True
+        for i in range(len(input_ids) - len(end_token_ids_prefix), -1, -1):
+            if input_ids[i : i + len(end_token_ids_prefix)] == end_token_ids_prefix:
+                # We have found the prefix, now we look for the suffix after the prefix.
+                suffix_start = i + len(end_token_ids_prefix)
+                for j in range(
+                    suffix_start, len(input_ids) - len(end_token_ids_suffix) + 1
+                ):
+                    if j - suffix_start >= self.reasoning_max_num_between_tokens:
+                        break
+                    if (
+                        input_ids[j : j + len(end_token_ids_suffix)]
+                        == end_token_ids_suffix
+                    ):
+                        return True
         return False

     def extract_content_ids(self, input_ids: list[int]) -> list[int]:
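
Note the bounded search: once "<|channel|>final" is found, the parser scans at most reasoning_max_num_between_tokens (20) positions ahead for "<|message|>" before giving up, so a "<|message|>" token belonging to some unrelated later span cannot be paired with an earlier final header, and the inner scan stays cheap. The outer loop still walks backwards from the end of input_ids, so the most recent final header is tried first.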
