Skip to content

Commit 5b38f40

Browse files
committed
Fix mock server type errors
Signed-off-by: Jared O'Connell <joconnel@redhat.com>
1 parent a47c3c3 commit 5b38f40

File tree

2 files changed

+16
-21
lines changed

2 files changed

+16
-21
lines changed

src/guidellm/mock_server/handlers/chat_completions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ async def _handle_non_stream(self, req: ChatCompletionsRequest) -> HTTPResponse:
136136

137137
# Token counts
138138
prompt_text = self.tokenizer.apply_chat_template(req.messages)
139-
prompt_tokens = len(self.tokenizer(prompt_text))
139+
prompt_tokens = len(self.tokenizer(prompt_text)) # type: ignore[arg-type]
140140
max_tokens = req.max_completion_tokens or req.max_tokens or math.inf
141141
completion_tokens_count = min(
142142
sample_number(self.config.output_tokens, self.config.output_tokens_std),
@@ -197,7 +197,7 @@ async def generate_stream(stream_response):
197197

198198
# Token counts
199199
prompt_text = self.tokenizer.apply_chat_template(req.messages)
200-
prompt_tokens = len(self.tokenizer(prompt_text))
200+
prompt_tokens = len(self.tokenizer(prompt_text)) # type: ignore[arg-type]
201201
max_tokens = req.max_completion_tokens or req.max_tokens or math.inf
202202
completion_tokens_count = int(
203203
min(

src/guidellm/mock_server/utils.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,15 @@ def __call__(self, text: str | list[str], **kwargs) -> list[int]: # noqa: ARG00
5858
return self.convert_tokens_to_ids(tokens)
5959
elif isinstance(text, list):
6060
# Handle batch processing
61-
return [self.__call__(t) for t in text]
61+
result = []
62+
for t in text:
63+
result.extend(self.__call__(t))
64+
return result
6265
else:
6366
msg = f"text input must be of type `str` or `list[str]`, got {type(text)}"
6467
raise ValueError(msg)
6568

66-
def tokenize(self, text: TextInput, **_kwargs) -> list[str]:
69+
def tokenize(self, text: TextInput, **_kwargs) -> list[str]: # type: ignore[override]
6770
"""
6871
Tokenize input text into a list of token strings.
6972
@@ -76,7 +79,7 @@ def tokenize(self, text: TextInput, **_kwargs) -> list[str]:
7679
# Split text into tokens: words, spaces, and punctuation
7780
return re.findall(r"\w+|[^\w\s]|\s+", text)
7881

79-
def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
82+
def convert_tokens_to_ids(self, tokens: str | list[str]) -> list[int]:
8083
"""
8184
Convert token strings to numeric token IDs.
8285
@@ -87,12 +90,12 @@ def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
8790
:return: List of token IDs (a single-element list for a single token string)
8891
"""
8992
if isinstance(tokens, str):
90-
return hash(tokens) % self.VocabSize
93+
return [hash(tokens) % self.VocabSize]
9194
return [hash(token) % self.VocabSize for token in tokens]
9295

93-
def convert_ids_to_tokens(
94-
self, ids: int | list[int], _skip_special_tokens: bool = False
95-
) -> str | list[str]:
96+
def convert_ids_to_tokens( # type: ignore[override]
97+
self, ids: list[int], _skip_special_tokens: bool = False
98+
) -> list[str]:
9699
"""
97100
Convert numeric token IDs back to token strings.
98101
@@ -102,17 +105,9 @@ def convert_ids_to_tokens(
102105
:param ids: List of token IDs to convert
103106
:return: List of token strings
104107
"""
105-
if not ids and not isinstance(ids, list):
106-
return ""
107-
elif not ids:
108+
if not ids:
108109
return [""]
109110

110-
if isinstance(ids, int):
111-
fake = Faker()
112-
fake.seed_instance(ids % self.VocabSize)
113-
114-
return fake.word()
115-
116111
fake = Faker()
117112
fake.seed_instance(sum(ids) % self.VocabSize)
118113

@@ -162,7 +157,7 @@ def _add_tokens(
162157
"""
163158
return 0
164159

165-
def apply_chat_template(
160+
def apply_chat_template( # type: ignore[override]
166161
self,
167162
conversation: list,
168163
tokenize: bool = False, # Changed default to False to match transformers
@@ -193,7 +188,7 @@ def apply_chat_template(
193188
return self.convert_tokens_to_ids(self.tokenize(formatted_text))
194189
return formatted_text
195190

196-
def decode(
191+
def decode( # type: ignore[override]
197192
self,
198193
token_ids: list[int],
199194
skip_special_tokens: bool = True,
@@ -255,7 +250,7 @@ def create_fake_tokens_str(
255250
fake = Faker()
256251
fake.seed_instance(seed)
257252

258-
tokens = []
253+
tokens: list[str] = []
259254

260255
while len(tokens) < num_tokens:
261256
text = fake.text(

0 commit comments

Comments
 (0)