1212import os
1313from datetime import datetime
1414from pathlib import Path
15- from typing import TYPE_CHECKING
15+ from typing import TYPE_CHECKING , Any , Callable
1616
1717from loguru import logger
1818from torch .utils .data import DataLoader
@@ -260,8 +260,16 @@ def oneshot(
260260 preprocessing_num_workers : int | None = None ,
261261 min_tokens_per_module : float | None = None ,
262262 moe_calibrate_all_experts : bool = True ,
263+ pipeline : str = "independent" ,
264+ tracing_ignore : list [str ] | None = None ,
265+ raw_kwargs : dict [str , Any ] | None = None ,
266+ preprocessing_func : Callable | None = None ,
267+ max_train_samples : int | None = None ,
268+ remove_columns : list [str ] | None = None ,
269+ dvc_data_repository : str | None = None ,
263270 quantization_aware_calibration : bool = True ,
264- # Miscellaneous arguments
271+ sequential_targets : list [str ] | None = None ,
272+ # Miscellaneous arguments
265273 output_dir : str | None = None ,
266274 log_dir : str | None = None ,
267275 ** kwargs ,
@@ -331,6 +339,16 @@ def oneshot(
331339 during forward pass in calibration. When False, quantization is disabled
332340 during forward pass in calibration. Default is set to True.
333341
342+ :param pipeline: The pipeline configuration to use for calibration. Options include
343+ 'independent', 'sequential', or 'layer_sequential'.
344+ :param tracing_ignore: List of module names to ignore during tracing.
345+ :param raw_kwargs: Dictionary of raw keyword arguments passed to the function.
346+ :param preprocessing_func: Optional callable for preprocessing the dataset.
347+ :param max_train_samples: Maximum number of training samples to use.
348+ :param remove_columns: List of column names to remove from the dataset.
349+ :param dvc_data_repository: Path to the DVC data repository, if applicable.
350+ :param sequential_targets: List of sequential targets for calibration.
351+
334352 # Miscellaneous arguments
335353 :param output_dir: Path to save the output model after calibration.
336354 Nothing is saved if None.
@@ -340,10 +358,18 @@ def oneshot(
340358 :return: The calibrated PreTrainedModel
341359 """
342360
343- # pass all args directly into Oneshot
361+ if sequential_targets and pipeline == "independent" :
362+ raise ValueError (
363+ "Invalid configuration: "
364+ "sequential_targets' cannot be used with 'independent' pipeline. "
365+ "Please use 'sequential' or 'layer_sequential' pipeline when specifying "
366+ "sequential_targets."
367+ )
368+
344369 local_args = {
345370 k : v for k , v in locals ().items () if k not in ("local_args" , "kwargs" )
346371 }
372+
347373 one_shot = Oneshot (** local_args , ** kwargs )
348374 one_shot ()
349375
0 commit comments