77with various pipeline configurations for efficient model optimization.
88"""
99
10+ from __future__ import annotations
11+
1012import os
1113from datetime import datetime
1214from pathlib import Path
13- from typing import TYPE_CHECKING , Dict , List , Optional , Union
15+ from typing import TYPE_CHECKING
1416
1517from loguru import logger
1618from torch .utils .data import DataLoader
@@ -36,7 +38,7 @@ class Oneshot:
3638 This class handles the entire lifecycle of one-shot calibration, including
3739 preprocessing (model and tokenizer/processor initialization), model optimization
3840 (quantization or sparsification), and postprocessing (saving outputs). The
39- intructions for model optimization can be specified by using a recipe.
41+ instructions for model optimization can be specified by using a recipe.
4042
4143 - **Input Keyword Arguments:**
4244 `kwargs` are parsed into:
@@ -99,7 +101,7 @@ class Oneshot:
99101
100102 def __init__ (
101103 self ,
102- log_dir : Optional [ str ] = None ,
104+ log_dir : str | None = None ,
103105 ** kwargs ,
104106 ):
105107 """
@@ -179,8 +181,8 @@ def __call__(self):
179181
180182 def apply_recipe_modifiers (
181183 self ,
182- calibration_dataloader : Optional [ DataLoader ] ,
183- recipe_stage : Optional [ str ] = None ,
184+ calibration_dataloader : DataLoader | None ,
185+ recipe_stage : str | None = None ,
184186 ):
185187 """
186188 Applies recipe modifiers to the model during the lifecycle.
@@ -198,7 +200,7 @@ def apply_recipe_modifiers(
198200 session = active_session ()
199201 session .reset ()
200202
201- # (Helen INFERENG-661): validate recipe modifiers before intialization
203+ # (Helen INFERENG-661): validate recipe modifiers before initialization
202204 session .initialize (
203205 model = self .model ,
204206 start = - 1 ,
@@ -226,27 +228,27 @@ def apply_recipe_modifiers(
226228
227229def oneshot (
228230 # Model arguments
229- model : Union [ str , PreTrainedModel ] ,
230- distill_teacher : Optional [ str ] = None ,
231- config_name : Optional [ str ] = None ,
232- tokenizer : Optional [ Union [ str , PreTrainedTokenizerBase ]] = None ,
233- processor : Optional [ Union [ str , ProcessorMixin ]] = None ,
231+ model : str | PreTrainedModel ,
232+ distill_teacher : str | None = None ,
233+ config_name : str | None = None ,
234+ tokenizer : str | PreTrainedTokenizerBase | None = None ,
235+ processor : str | ProcessorMixin | None = None ,
234236 use_auth_token : bool = False ,
235237 precision : str = "auto" ,
236238 tie_word_embeddings : bool = True ,
237239 trust_remote_code_model : bool = False ,
238240 save_compressed : bool = True ,
239241 model_revision : str = "main" ,
240242 # Recipe arguments
241- recipe : Optional [ Union [ str , List [str ]]] = None ,
242- recipe_args : Optional [ List [ str ]] = None ,
243+ recipe : str | list [str ] | None = None ,
244+ recipe_args : list [ str ] | None = None ,
243245 clear_sparse_session : bool = False ,
244- stage : Optional [ str ] = None ,
246+ stage : str | None = None ,
245247 # Dataset arguments
246- dataset : Optional [ Union [ str , " Dataset" , " DatasetDict" ]] = None ,
247- dataset_config_name : Optional [ str ] = None ,
248- dataset_path : Optional [ str ] = None ,
249- splits : Optional [ Union [str , List , Dict ]] = None ,
248+ dataset : str | Dataset | DatasetDict | None = None ,
249+ dataset_config_name : str | None = None ,
250+ dataset_path : str | None = None ,
251+ splits : str | list [ str ] | dict [str , str ] | None = None ,
250252 num_calibration_samples : int = 512 ,
251253 shuffle_calibration_samples : bool = True ,
252254 max_seq_length : int = 384 ,
@@ -255,13 +257,13 @@ def oneshot(
255257 concatenate_data : bool = False ,
256258 streaming : bool = False ,
257259 overwrite_cache : bool = False ,
258- preprocessing_num_workers : Optional [ int ] = None ,
259- min_tokens_per_module : Optional [ float ] = None ,
260+ preprocessing_num_workers : int | None = None ,
261+ min_tokens_per_module : float | None = None ,
260262 moe_calibrate_all_experts : bool = True ,
261263 quantization_aware_calibration : bool = True ,
262264 # Miscellaneous arguments
263- output_dir : Optional [ str ] = None ,
264- log_dir : Optional [ str ] = None ,
265+ output_dir : str | None = None ,
266+ log_dir : str | None = None ,
265267 ** kwargs ,
266268) -> PreTrainedModel :
267269 """
@@ -290,7 +292,8 @@ def oneshot(
290292 tag, or commit id).
291293
292294 # Recipe arguments
293- :param recipe: Path to a LLM Compressor sparsification recipe.
295+ :param recipe: Path to an LLM Compressor recipe, or a list of paths
296+ to multiple LLM Compressor recipes.
294297 :param recipe_args: List of recipe arguments to evaluate, in the
295298 format "key1=value1", "key2=value2".
296299 :param clear_sparse_session: Whether to clear CompressionSession/
0 commit comments