Commit 3f6cfd6

Modernize entrypoints module with type hints and use generic types (#1965)
SUMMARY: This is part of #1927. Modernize type annotations using the `|` operator and built-in generics in the `entrypoints` module as part of the codebase modernization effort. I also fixed 2 typos as part of this PR.

TEST PLAN:
- make style
- make quality
- make tests

Notes: Happy to address any comments! Also please let me know if you don't want the typos to be fixed as part of this PR. Thank you!
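For readers unfamiliar with the two PEPs involved, here is a minimal sketch of the before/after shape of these annotations (illustrative only, not code from this PR; the `load` helper is hypothetical). `from __future__ import annotations` defers annotation evaluation, so the new syntax also parses on Python versions before 3.10:

```python
from __future__ import annotations  # lets the new syntax parse on Python < 3.10

# Before: aliases imported from typing
# from typing import Dict, List, Optional, Union
# def load(recipe: Optional[Union[str, List[str]]] = None) -> Dict[str, str]: ...

# After: PEP 604 unions and PEP 585 built-in generics
def load(recipe: str | list[str] | None = None) -> dict[str, str]:
    # Hypothetical helper used only to illustrate the annotation style.
    return {} if recipe is None else {"recipe": str(recipe)}
```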
1 parent ee755a4 commit 3f6cfd6

2 files changed: +34 -32 lines changed

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 26 additions & 23 deletions
```diff
@@ -7,10 +7,12 @@
 with various pipeline configurations for efficient model optimization.
 """
 
+from __future__ import annotations
+
 import os
 from datetime import datetime
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING
 
 from loguru import logger
 from torch.utils.data import DataLoader
@@ -36,7 +38,7 @@ class Oneshot:
     This class handles the entire lifecycle of one-shot calibration, including
     preprocessing (model and tokenizer/processor initialization), model optimization
     (quantization or sparsification), and postprocessing (saving outputs). The
-    intructions for model optimization can be specified by using a recipe.
+    instructions for model optimization can be specified by using a recipe.
 
     - **Input Keyword Arguments:**
       `kwargs` are parsed into:
@@ -99,7 +101,7 @@ class Oneshot:
 
     def __init__(
         self,
-        log_dir: Optional[str] = None,
+        log_dir: str | None = None,
        **kwargs,
     ):
         """
@@ -179,8 +181,8 @@ def __call__(self):
 
     def apply_recipe_modifiers(
         self,
-        calibration_dataloader: Optional[DataLoader],
-        recipe_stage: Optional[str] = None,
+        calibration_dataloader: DataLoader | None,
+        recipe_stage: str | None = None,
     ):
         """
         Applies recipe modifiers to the model during the lifecycle.
@@ -198,7 +200,7 @@ def apply_recipe_modifiers(
         session = active_session()
         session.reset()
 
-        # (Helen INFERENG-661): validate recipe modifiers before intialization
+        # (Helen INFERENG-661): validate recipe modifiers before initialization
         session.initialize(
             model=self.model,
             start=-1,
@@ -226,27 +228,27 @@ def apply_recipe_modifiers(
 
 def oneshot(
     # Model arguments
-    model: Union[str, PreTrainedModel],
-    distill_teacher: Optional[str] = None,
-    config_name: Optional[str] = None,
-    tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
-    processor: Optional[Union[str, ProcessorMixin]] = None,
+    model: str | PreTrainedModel,
+    distill_teacher: str | None = None,
+    config_name: str | None = None,
+    tokenizer: str | PreTrainedTokenizerBase | None = None,
+    processor: str | ProcessorMixin | None = None,
     use_auth_token: bool = False,
     precision: str = "auto",
     tie_word_embeddings: bool = True,
     trust_remote_code_model: bool = False,
     save_compressed: bool = True,
     model_revision: str = "main",
     # Recipe arguments
-    recipe: Optional[Union[str, List[str]]] = None,
-    recipe_args: Optional[List[str]] = None,
+    recipe: str | list[str] | None = None,
+    recipe_args: list[str] | None = None,
     clear_sparse_session: bool = False,
-    stage: Optional[str] = None,
+    stage: str | None = None,
     # Dataset arguments
-    dataset: Optional[Union[str, "Dataset", "DatasetDict"]] = None,
-    dataset_config_name: Optional[str] = None,
-    dataset_path: Optional[str] = None,
-    splits: Optional[Union[str, List, Dict]] = None,
+    dataset: str | Dataset | DatasetDict | None = None,
+    dataset_config_name: str | None = None,
+    dataset_path: str | None = None,
+    splits: str | list[str] | dict[str, str] | None = None,
     num_calibration_samples: int = 512,
     shuffle_calibration_samples: bool = True,
     max_seq_length: int = 384,
@@ -255,13 +257,13 @@ def oneshot(
     concatenate_data: bool = False,
     streaming: bool = False,
     overwrite_cache: bool = False,
-    preprocessing_num_workers: Optional[int] = None,
-    min_tokens_per_module: Optional[float] = None,
+    preprocessing_num_workers: int | None = None,
+    min_tokens_per_module: float | None = None,
     moe_calibrate_all_experts: bool = True,
     quantization_aware_calibration: bool = True,
     # Miscellaneous arguments
-    output_dir: Optional[str] = None,
-    log_dir: Optional[str] = None,
+    output_dir: str | None = None,
+    log_dir: str | None = None,
     **kwargs,
 ) -> PreTrainedModel:
     """
@@ -290,7 +292,8 @@ def oneshot(
         tag, or commit id).
 
     # Recipe arguments
-    :param recipe: Path to a LLM Compressor sparsification recipe.
+    :param recipe: Path to a LLM Compressor recipe, or a list of paths
+        to multiple LLM Compressor recipes.
     :param recipe_args: List of recipe arguments to evaluate, in the
         format "key1=value1", "key2=value2".
     :param clear_sparse_session: Whether to clear CompressionSession/
```

src/llmcompressor/entrypoints/utils.py

Lines changed: 8 additions & 9 deletions
```diff
@@ -10,7 +10,6 @@
 import inspect
 import os
 from pathlib import PosixPath
-from typing import Optional, Tuple
 
 from compressed_tensors.utils import remove_dispatch
 from loguru import logger
@@ -47,7 +46,7 @@
 def pre_process(
     model_args: ModelArguments,
     dataset_args: DatasetArguments,
-    output_dir: Optional[str],
+    output_dir: str | None,
 ):
     """
     Prepares the model and tokenizer/processor for calibration.
@@ -102,9 +101,9 @@ def pre_process(
 
 
 def post_process(
-    model_args: Optional["ModelArguments"] = None,
-    recipe_args: Optional["RecipeArguments"] = None,
-    output_dir: Optional[str] = None,
+    model_args: ModelArguments | None = None,
+    recipe_args: RecipeArguments | None = None,
+    output_dir: str | None = None,
 ):
     """
     Saves the model and tokenizer/processor to the output directory if model_args,
@@ -151,8 +150,8 @@ def post_process(
 
 def initialize_model_from_path(
     model_args: ModelArguments,
-    training_args: Optional[TrainingArguments] = None,
-) -> Tuple[PreTrainedModel, Optional[PreTrainedModel]]:
+    training_args: TrainingArguments | None = None,
+) -> tuple[PreTrainedModel, PreTrainedModel | None]:
     # Load pretrained model
     # The .from_pretrained methods guarantee that only one local process can
     # concurrently download model & vocab.
@@ -240,7 +239,7 @@ def initialize_model_from_path(
 def initialize_processor_from_path(
     model_args: ModelArguments,
     model: PreTrainedModel,
-    teacher: Optional[PreTrainedModel] = None,
+    teacher: PreTrainedModel | None = None,
 ) -> Processor:
     processor_src = model_args.processor or get_processor_name_from_model(
         model, teacher
@@ -279,7 +278,7 @@ def initialize_processor_from_path(
     return processor
 
 
-def get_processor_name_from_model(student: Module, teacher: Optional[Module]) -> str:
+def get_processor_name_from_model(student: Module, teacher: Module | None) -> str:
     """
     Get a processor/tokenizer source used for both student and teacher, assuming
     that they could be shared
```
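A general note on why the `from typing import Optional, Tuple` line could be deleted outright (standard Python behavior, not specific to this PR): on Python 3.10+, built-in generics (PEP 585) and `|` unions (PEP 604) work without any `typing` import. A minimal sketch with a hypothetical signature mirroring the shapes rewritten in `utils.py`:

```python
# No typing import needed on Python 3.10+.
def initialize(with_teacher: bool = False) -> tuple[str, str | None]:
    # Mirrors the Tuple[PreTrainedModel, Optional[PreTrainedModel]] ->
    # tuple[PreTrainedModel, PreTrainedModel | None] rewrite above.
    return ("model", "teacher" if with_teacher else None)

print(initialize(True))   # ('model', 'teacher')
print(initialize(False))  # ('model', None)
```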
