Skip to content

Commit c2774c4

Browse files
frankwang28, gemini-code-assist[bot], and chaunceyjiang
authored and committed
[Bugfix] Improve GLM4 MoE Reasoning Parser's is_reasoning_end Condition (vllm-project#25355)
Signed-off-by: frankwang28 <frank.wbb@hotmail.com> Signed-off-by: Frank Wang <41319051+frankwang28@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
1 parent 7063cfd commit c2774c4

File tree

2 files changed

+219
-3
lines changed

2 files changed

+219
-3
lines changed
Lines changed: 203 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,203 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import pytest
5+
from transformers import AutoTokenizer
6+
7+
from tests.reasoning.utils import run_reasoning_extraction
8+
from vllm.reasoning import ReasoningParser, ReasoningParserManager
9+
10+
# Name under which the parser is looked up via ReasoningParserManager.
parser_name = "glm45"

# Literal tags delimiting the reasoning section in model output.
start_token, end_token = "<think>", "</think>"

# Model identifier passed to AutoTokenizer.from_pretrained.
REASONING_MODEL_NAME = "zai-org/GLM-4.5"
15+
16+
17+
@pytest.fixture(scope="module")
def glm45_tokenizer():
    """Return the tokenizer for REASONING_MODEL_NAME (module-scoped)."""
    tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
    return tokenizer
20+
21+
22+
# Each dict below pairs a raw model "output" with the values the parser is
# expected to produce for it: the extracted reasoning text, the remaining
# content, and the result of is_reasoning_end on the output's token ids.

# Complete reasoning section followed by content.
WITH_THINK = {
    "output": "<think>This is a reasoning section</think>This is the rest",
    "reasoning_content": "This is a reasoning section",
    "content": "This is the rest",
    "is_reasoning_end": True,
}

# Streaming expectations match the non-streaming ones (separate dict object).
WITH_THINK_STREAM = dict(WITH_THINK)

# No reasoning tags at all: everything is content.
WITHOUT_THINK = {
    "output": "This is the rest",
    "reasoning_content": None,
    "content": "This is the rest",
    "is_reasoning_end": False,
}

# Streaming expectations match the non-streaming ones (separate dict object).
WITHOUT_THINK_STREAM = dict(WITHOUT_THINK)

# Reasoning section with nothing after the closing tag.
COMPLETE_REASONING = {
    "output": "<think>This is a reasoning section</think>",
    "reasoning_content": "This is a reasoning section",
    "content": None,
    "is_reasoning_end": True,
}

# Newlines inside both the reasoning section and the content.
MULTILINE_REASONING = {
    "output": "<think>This is a reasoning\nsection</think>This is the rest\nThat",
    "reasoning_content": "This is a reasoning\nsection",
    "content": "This is the rest\nThat",
    "is_reasoning_end": True,
}

# Opening tag without a closing tag, non-streaming: the whole output is
# expected back as content.
ONLY_OPEN_TAG = {
    "output": "<think>This is a reasoning section",
    "reasoning_content": None,
    "content": "<think>This is a reasoning section",
    "is_reasoning_end": False,
}

# Opening tag without a closing tag, streaming: the text is expected back as
# reasoning instead.
ONLY_OPEN_TAG_STREAM = {
    "output": "<think>This is a reasoning section",
    "reasoning_content": "This is a reasoning section",
    "content": None,
    "is_reasoning_end": False,
}
76+
77+
# Parametrization for test_reasoning: (streaming flag, expectation dict),
# each tagged with a readable pytest id.
TEST_CASES = [
    pytest.param(use_streaming, expectations, id=case_id)
    for use_streaming, expectations, case_id in (
        (False, WITH_THINK, "with_think"),
        (True, WITH_THINK_STREAM, "with_think_stream"),
        (False, WITHOUT_THINK, "without_think"),
        (True, WITHOUT_THINK_STREAM, "without_think_stream"),
        (False, COMPLETE_REASONING, "complete_reasoning"),
        (True, COMPLETE_REASONING, "complete_reasoning_stream"),
        (False, MULTILINE_REASONING, "multiline_reasoning"),
        (True, MULTILINE_REASONING, "multiline_reasoning_stream"),
        (False, ONLY_OPEN_TAG, "only_open_tag"),
        (True, ONLY_OPEN_TAG_STREAM, "only_open_tag_stream"),
    )
]
129+
130+
# Full chat-formatted prompts used to exercise is_reasoning_end. Each is
# assembled from adjacent string literals, one per line of the prompt.

# Single turn, cut off inside an unclosed <think> section.
STILL_REASONING_PROMPT = (
    "[gMASK]<sop><|system|>\n"
    "You are a helpful assistant.<|user|>\n"
    "What is the capital of France?<|assistant|>\n"
    "<think>The user is asking for the capital of"
)

# Single turn whose <think> section has been closed and answered.
DONE_REASONING_PROMPT = (
    "[gMASK]<sop><|system|>\n"
    "You are a helpful assistant.<|user|>\n"
    "What is the capital of France?<|assistant|>\n"
    "<think>The user is asking for the capital of France.</think>\n"
    "The capital of France is Paris."
)

# Two turns: the earlier turn contains a closed <think></think> pair, while
# the latest turn is still inside an unclosed <think> section.
MULTI_TURN_STILL_REASONING_PROMPT = (
    "[gMASK]<sop><|system|>\n"
    "You are a helpful assistant.<|user|>\n"
    "What is the capital of France?<|assistant|>\n"
    "<think></think>\n"
    "The capital of France is Paris.<|user|>\n"
    "What about Chile?<|assistant|>\n"
    "<think>The user is asking for the capital of"
)

# Two turns, with the latest turn's <think> section closed and answered.
MULTI_TURN_DONE_REASONING_PROMPT = (
    "[gMASK]<sop><|system|>\n"
    "You are a helpful assistant.<|user|>\n"
    "What is the capital of France?<|assistant|>\n"
    "<think></think>\n"
    "The capital of France is Paris.<|user|>\n"
    "What about Chile?<|assistant|>\n"
    "<think>The user is asking for the capital of Chile.</think>\n"
    "The capital of Chile is Santiago."
)
157+
158+
# Parametrization for test_is_reasoning_end_full_prompt:
# (full prompt text, expected is_reasoning_end result).
REASONING_END_TEST_CASES = [
    pytest.param(prompt_text, expected_end, id=case_id)
    for prompt_text, expected_end, case_id in (
        (STILL_REASONING_PROMPT, False, "still_reasoning"),
        (DONE_REASONING_PROMPT, True, "done_reasoning"),
        (MULTI_TURN_STILL_REASONING_PROMPT, False,
         "multi_turn_still_reasoning"),
        (MULTI_TURN_DONE_REASONING_PROMPT, True,
         "multi_turn_done_reasoning"),
    )
]
168+
169+
170+
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
    streaming: bool,
    param_dict: dict,
    glm45_tokenizer,
):
    """Check extracted reasoning/content and is_reasoning_end for one case.

    The case's "output" text is tokenized, each token is converted back to
    its string form, and the resulting pieces are passed to
    run_reasoning_extraction in streaming or non-streaming mode. The
    results are compared with the expectations stored in param_dict.
    """
    raw_tokens = glm45_tokenizer.tokenize(param_dict["output"])
    token_texts: list[str] = []
    for raw_token in raw_tokens:
        token_texts.append(
            glm45_tokenizer.convert_tokens_to_string([raw_token]))

    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
    parser: ReasoningParser = parser_cls(glm45_tokenizer)

    extracted = run_reasoning_extraction(parser,
                                         token_texts,
                                         streaming=streaming)
    reasoning, content = extracted

    assert reasoning == param_dict["reasoning_content"]
    assert content == param_dict["content"]

    # is_reasoning_end works on token ids, so convert the raw tokens.
    token_ids = glm45_tokenizer.convert_tokens_to_ids(raw_tokens)
    reasoning_ended = parser.is_reasoning_end(token_ids)
    assert reasoning_ended == param_dict["is_reasoning_end"]
193+
194+
195+
@pytest.mark.parametrize("prompt, is_reasoning_end", REASONING_END_TEST_CASES)
def test_is_reasoning_end_full_prompt(prompt: str, is_reasoning_end: bool,
                                      glm45_tokenizer):
    """Check is_reasoning_end on the token ids of a full chat prompt."""
    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
    parser: ReasoningParser = parser_cls(glm45_tokenizer)

    prompt_ids = glm45_tokenizer.convert_tokens_to_ids(
        glm45_tokenizer.tokenize(prompt))

    assert parser.is_reasoning_end(prompt_ids) == is_reasoning_end

vllm/reasoning/glm4_moe_reasoning_parser.py

Lines changed: 16 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
3030
super().__init__(tokenizer, *args, **kwargs)
3131
self.think_start_token = "<think>"
3232
self.think_end_token = "</think>"
33+
self.assistant_token = "<|assistant|>"
3334

3435
if not self.model_tokenizer:
3536
raise ValueError(
@@ -38,14 +39,26 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
3839

3940
self.think_start_token_id = self.vocab.get(self.think_start_token)
4041
self.think_end_token_id = self.vocab.get(self.think_end_token)
42+
self.assistant_token_id = self.vocab.get(self.assistant_token)
4143
if (self.think_start_token_id is None
42-
or self.think_end_token_id is None):
44+
or self.think_end_token_id is None
45+
or self.assistant_token_id is None):
4346
raise RuntimeError(
4447
"Glm4MoeModel reasoning parser could not locate "
45-
"think start/end tokens in the tokenizer!")
48+
"think start/end or assistant tokens in the tokenizer!")
4649

4750
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    """
    Report whether reasoning has ended in the most recent assistant turn.

    GLM's chat template has <think></think> tokens after every
    <|assistant|> token. Thus, we need to check if </think> is
    after the most recent <|assistant|> token (if present).

    Args:
        input_ids: Token ids to inspect.

    Returns:
        True if a </think> token id is found, scanning from the end,
        before any <|assistant|> token id. False if <|assistant|> is
        reached first or if neither token is present.
    """
    # reversed() walks the list lazily from the end; the slice
    # input_ids[::-1] would copy the entire list on every call.
    for token_id in reversed(input_ids):
        if token_id == self.think_end_token_id:
            return True
        if token_id == self.assistant_token_id:
            return False
    return False
4962

5063
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
5164
"""

0 commit comments

Comments
 (0)