|
| 1 | +"""Tests for the base tool parser.""" |
| 2 | + |
1 | 3 | import unittest |
2 | 4 |
|
3 | | -from mlx_openai_server.handler.parser.base import BaseToolParser |
| 5 | +from mlx_openai_server.handler.parser.base import BaseToolParser # type: ignore[import-not-found] |
4 | 6 |
|
5 | 7 |
|
6 | 8 | class TestBaseToolParser(unittest.TestCase): |
7 | | - def setUp(self): |
| 9 | + """Test cases for BaseToolParser.""" |
| 10 | + |
| 11 | + def setUp(self) -> None: |
| 12 | + """Set up test cases.""" |
8 | 13 | self.test_cases = [ |
9 | 14 | { |
10 | 15 | "name": "simple function call", |
11 | | - "chunks": """#<tool_call># |
12 | | -#{"#name#":# "#get#_weather#",# "#arguments#":# {"#city#":# "#H#ue#"}} |
13 | | -#</tool_call># |
14 | | -#<tool_call># |
15 | | -#{"#name#":# "#get#_weather#",# "#arguments#":# {"#city#":# "#Sy#dney#"}} |
16 | | -#</tool_call>##""".split("#"), |
| 16 | + "chunks": [ |
| 17 | + "Some text before <tool_call>\n", |
| 18 | + '{"name": "get_weather", "arguments": {"city": "Hue"}}\n', |
| 19 | + "</tool_call>\nMore text after\n", |
| 20 | + "<tool_call>\n", |
| 21 | + '{"name": "get_weather", "arguments": {"city": "Sydney"}}\n', |
| 22 | + "</tool_call>\nFinal text", |
| 23 | + ], |
17 | 24 | "expected_outputs": [ |
18 | | - {"name": "get_weather", "arguments": ""}, |
19 | | - {"name": None, "arguments": ' {"'}, |
20 | | - {"name": None, "arguments": "city"}, |
21 | | - {"name": None, "arguments": '":'}, |
22 | | - {"name": None, "arguments": ' "'}, |
23 | | - {"name": None, "arguments": "H"}, |
24 | | - {"name": None, "arguments": "ue"}, |
25 | | - {"name": None, "arguments": '"}'}, |
26 | | - "\n", |
27 | | - {"name": "get_weather", "arguments": ""}, |
28 | | - {"name": None, "arguments": ' {"'}, |
29 | | - {"name": None, "arguments": "city"}, |
30 | | - {"name": None, "arguments": '":'}, |
31 | | - {"name": None, "arguments": ' "'}, |
32 | | - {"name": None, "arguments": "Sy"}, |
33 | | - {"name": None, "arguments": "dney"}, |
34 | | - {"name": None, "arguments": '"}'}, |
| 25 | + "Some text before ", # Text before tool call |
| 26 | + {"name": "get_weather", "arguments": '{"city": "Hue"}'}, # Complete tool call |
| 27 | + "", # Empty string from chunk with opening tag |
| 28 | + { |
| 29 | + "name": "get_weather", |
| 30 | + "arguments": '{"city": "Sydney"}', |
| 31 | + }, # Complete tool call |
35 | 32 | ], |
36 | 33 | }, |
37 | 34 | { |
38 | | - "name": "code function call", |
39 | | - "chunks": r"""<tool_call>@@ |
40 | | -@@{"@@name@@":@@ "@@python@@",@@ "@@arguments@@":@@ {"@@code@@":@@ "@@def@@ calculator@@(a@@,@@ b@@,@@ operation@@):\@@n@@ @@ if@@ operation@@ ==@@ '@@add@@'\@@n@@ @@ return@@ a@@ +@@ b@@\n@@ @@ if@@ operation@@ ==@@ '@@subtract@@'\@@n@@ @@ return@@ a@@ -@@ b@@\n@@ @@ if@@ operation@@ ==@@ '@@multiply@@'\@@n@@ @@ return@@ a@@ *@@ b@@\n@@ @@ if@@ operation@@ ==@@ '@@divide@@'\@@n@@ @@ return@@ a@@ /@@ b@@\n@@ @@ return@@ '@@Invalid@@ operation@@'@@"}} |
41 | | -@@</tool_call>@@@@""".split("@@"), |
| 35 | + "name": "streaming function call", |
| 36 | + "chunks": [ |
| 37 | + "Start <tool_call>\n", |
| 38 | + '{"name": "python", "arguments": ', |
| 39 | + '{"code": "print(\'hello\')"}}\n', |
| 40 | + "</tool_call>\nEnd", |
| 41 | + ], |
42 | 42 | "expected_outputs": [ |
43 | | - {"name": "python", "arguments": ""}, |
44 | | - {"name": None, "arguments": ' {"'}, |
45 | | - {"name": None, "arguments": "code"}, |
46 | | - {"name": None, "arguments": '":'}, |
47 | | - {"name": None, "arguments": ' "'}, |
48 | | - {"name": None, "arguments": "def"}, |
49 | | - {"name": None, "arguments": " calculator"}, |
50 | | - {"name": None, "arguments": "(a"}, |
51 | | - {"name": None, "arguments": ","}, |
52 | | - {"name": None, "arguments": " b"}, |
53 | | - {"name": None, "arguments": ","}, |
54 | | - {"name": None, "arguments": " operation"}, |
55 | | - {"name": None, "arguments": "):\\"}, |
56 | | - {"name": None, "arguments": "n"}, |
57 | | - {"name": None, "arguments": " "}, |
58 | | - {"name": None, "arguments": " if"}, |
59 | | - {"name": None, "arguments": " operation"}, |
60 | | - {"name": None, "arguments": " =="}, |
61 | | - {"name": None, "arguments": " '"}, |
62 | | - {"name": None, "arguments": "add"}, |
63 | | - {"name": None, "arguments": "'\\"}, |
64 | | - {"name": None, "arguments": "n"}, |
65 | | - {"name": None, "arguments": " "}, |
66 | | - {"name": None, "arguments": " return"}, |
67 | | - {"name": None, "arguments": " a"}, |
68 | | - {"name": None, "arguments": " +"}, |
69 | | - {"name": None, "arguments": " b"}, |
70 | | - {"name": None, "arguments": "\\n"}, |
71 | | - {"name": None, "arguments": " "}, |
72 | | - {"name": None, "arguments": " if"}, |
73 | | - {"name": None, "arguments": " operation"}, |
74 | | - {"name": None, "arguments": " =="}, |
75 | | - {"name": None, "arguments": " '"}, |
76 | | - {"name": None, "arguments": "subtract"}, |
77 | | - {"name": None, "arguments": "'\\"}, |
78 | | - {"name": None, "arguments": "n"}, |
79 | | - {"name": None, "arguments": " "}, |
80 | | - {"name": None, "arguments": " return"}, |
81 | | - {"name": None, "arguments": " a"}, |
82 | | - {"name": None, "arguments": " -"}, |
83 | | - {"name": None, "arguments": " b"}, |
84 | | - {"name": None, "arguments": "\\n"}, |
85 | | - {"name": None, "arguments": " "}, |
86 | | - {"name": None, "arguments": " if"}, |
87 | | - {"name": None, "arguments": " operation"}, |
88 | | - {"name": None, "arguments": " =="}, |
89 | | - {"name": None, "arguments": " '"}, |
90 | | - {"name": None, "arguments": "multiply"}, |
91 | | - {"name": None, "arguments": "'\\"}, |
92 | | - {"name": None, "arguments": "n"}, |
93 | | - {"name": None, "arguments": " "}, |
94 | | - {"name": None, "arguments": " return"}, |
95 | | - {"name": None, "arguments": " a"}, |
96 | | - {"name": None, "arguments": " *"}, |
97 | | - {"name": None, "arguments": " b"}, |
98 | | - {"name": None, "arguments": "\\n"}, |
99 | | - {"name": None, "arguments": " "}, |
100 | | - {"name": None, "arguments": " if"}, |
101 | | - {"name": None, "arguments": " operation"}, |
102 | | - {"name": None, "arguments": " =="}, |
103 | | - {"name": None, "arguments": " '"}, |
104 | | - {"name": None, "arguments": "divide"}, |
105 | | - {"name": None, "arguments": "'\\"}, |
106 | | - {"name": None, "arguments": "n"}, |
107 | | - {"name": None, "arguments": " "}, |
108 | | - {"name": None, "arguments": " return"}, |
109 | | - {"name": None, "arguments": " a"}, |
110 | | - {"name": None, "arguments": " /"}, |
111 | | - {"name": None, "arguments": " b"}, |
112 | | - {"name": None, "arguments": "\\n"}, |
113 | | - {"name": None, "arguments": " "}, |
114 | | - {"name": None, "arguments": " return"}, |
115 | | - {"name": None, "arguments": " '"}, |
116 | | - {"name": None, "arguments": "Invalid"}, |
117 | | - {"name": None, "arguments": " operation"}, |
118 | | - {"name": None, "arguments": "'"}, |
119 | | - {"name": None, "arguments": '"}'}, |
| 43 | + "Start ", # Text before tool call |
| 44 | + { |
| 45 | + "name": "python", |
| 46 | + "arguments": '{"code": "print(\'hello\')"}', |
| 47 | + }, # Complete tool call |
120 | 48 | ], |
121 | 49 | }, |
122 | 50 | ] |
123 | 51 |
|
124 | | - def test_parse_stream(self): |
| 52 | + def test_parse_stream(self) -> None: |
| 53 | + """Test parsing stream.""" |
125 | 54 | for test_case in self.test_cases: |
126 | 55 | with self.subTest(msg=test_case["name"]): |
127 | 56 | parser = BaseToolParser("<tool_call>", "</tool_call>") |
128 | 57 | outputs = [] |
129 | 58 |
|
130 | 59 | for chunk in test_case["chunks"]: |
131 | | - result = parser.parse_stream(chunk) |
132 | | - if result: |
133 | | - outputs.append(result) |
| 60 | + parsed, _complete = parser.parse_stream(chunk) |
| 61 | + if parsed is not None: |
| 62 | + if isinstance(parsed, list): |
| 63 | + outputs.extend(parsed) |
| 64 | + else: |
| 65 | + outputs.append(parsed) |
134 | 66 |
|
135 | | - self.assertEqual( |
136 | | - len(outputs), |
137 | | - len(test_case["expected_outputs"]), |
138 | | - f"Expected {len(test_case['expected_outputs'])} outputs, got {len(outputs)}", |
139 | | - ) |
140 | | - |
141 | | - for i, (output, expected) in enumerate(zip(outputs, test_case["expected_outputs"])): |
142 | | - self.assertEqual( |
143 | | - output, expected, f"Chunk {i}: Expected {expected}, got {output}" |
144 | | - ) |
| 67 | + # Get any remaining content |
| 68 | + remaining, _complete = parser.parse_stream(None) |
| 69 | + if remaining is not None: |
| 70 | + outputs.append(remaining) |
145 | 71 |
|
| 72 | + assert len(outputs) == len(test_case["expected_outputs"]), ( |
| 73 | + f"Expected {len(test_case['expected_outputs'])} outputs, got {len(outputs)}" |
| 74 | + ) |
146 | 75 |
|
147 | | -if __name__ == "__main__": |
148 | | - unittest.main() |
| 76 | + for i, (output, expected) in enumerate( |
| 77 | + zip(outputs, test_case["expected_outputs"], strict=True) |
| 78 | + ): |
| 79 | + assert output == expected, f"Chunk {i}: Expected {expected}, got {output}" |
0 commit comments