-
Notifications
You must be signed in to change notification settings - Fork 292
[Oneshot] Add validation for empty dataset and enhance oneshot function parameters #1957
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -12,7 +12,7 @@ | |||||
| import os | ||||||
| from datetime import datetime | ||||||
| from pathlib import Path | ||||||
| from typing import TYPE_CHECKING | ||||||
| from typing import TYPE_CHECKING, Any, Callable | ||||||
|
|
||||||
| from loguru import logger | ||||||
| from torch.utils.data import DataLoader | ||||||
|
|
@@ -260,8 +260,16 @@ def oneshot( | |||||
| preprocessing_num_workers: int | None = None, | ||||||
| min_tokens_per_module: float | None = None, | ||||||
| moe_calibrate_all_experts: bool = True, | ||||||
| pipeline: str = "independent", | ||||||
| tracing_ignore: list[str] | None = None, | ||||||
| raw_kwargs: dict[str, Any] | None = None, | ||||||
| preprocessing_func: Callable | None = None, | ||||||
| max_train_samples: int | None = None, | ||||||
| remove_columns: list[str] | None = None, | ||||||
| dvc_data_repository: str | None = None, | ||||||
| quantization_aware_calibration: bool = True, | ||||||
| # Miscellaneous arguments | ||||||
| sequential_targets: list[str] | None = None, | ||||||
| # Miscellaneous arguments | ||||||
| output_dir: str | None = None, | ||||||
| log_dir: str | None = None, | ||||||
| **kwargs, | ||||||
|
|
@@ -331,6 +339,16 @@ def oneshot( | |||||
| during forward pass in calibration. When False, quantization is disabled | ||||||
| during forward pass in calibration. Default is set to True. | ||||||
|
|
||||||
| :param pipeline: The pipeline configuration to use for calibration. Options include | ||||||
| 'independent', 'sequential', or 'layer_sequential'. | ||||||
| :param tracing_ignore: List of module names to ignore during tracing. | ||||||
| :param raw_kwargs: Dictionary of raw keyword arguments passed to the function. | ||||||
| :param preprocessing_func: Optional callable for preprocessing the dataset. | ||||||
| :param max_train_samples: Maximum number of training samples to use. | ||||||
| :param remove_columns: List of column names to remove from the dataset. | ||||||
| :param dvc_data_repository: Path to the DVC data repository, if applicable. | ||||||
| :param sequential_targets: List of sequential targets for calibration. | ||||||
|
|
||||||
| # Miscellaneous arguments | ||||||
| :param output_dir: Path to save the output model after calibration. | ||||||
| Nothing is saved if None. | ||||||
|
|
@@ -340,10 +358,18 @@ def oneshot( | |||||
| :return: The calibrated PreTrainedModel | ||||||
| """ | ||||||
|
|
||||||
| # pass all args directly into Oneshot | ||||||
| if sequential_targets and pipeline == "independent": | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| raise ValueError( | ||||||
ArkaSanka marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
| "Invalid configuration: " | ||||||
| "sequential_targets' cannot be used with 'independent' pipeline. " | ||||||
| "Please use 'sequential' or 'layer_sequential' pipeline when specifying " | ||||||
| "sequential_targets." | ||||||
| ) | ||||||
|
|
||||||
| local_args = { | ||||||
| k: v for k, v in locals().items() if k not in ("local_args", "kwargs") | ||||||
| } | ||||||
|
|
||||||
| one_shot = Oneshot(**local_args, **kwargs) | ||||||
| one_shot() | ||||||
|
|
||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,15 @@ | ||
| import logging | ||
|
|
||
| import pytest | ||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
|
||
| from llmcompressor import oneshot | ||
| from tests.llmcompressor.transformers.oneshot.dataset_processing import get_data_utils | ||
| from tests.testing_utils import parse_params | ||
|
|
||
| logging.basicConfig(level=logging.INFO) | ||
| logger = logging.getLogger(__name__) | ||
|
|
||
| CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/oneshot/oneshot_configs" | ||
|
|
||
| # TODO: Seems better to mark test type (smoke, sanity, regression) as a marker as | ||
|
|
@@ -42,15 +47,41 @@ def wrapped_preprocess_func(sample): | |
| dataset_config_name=config.get("dataset_config_name"), | ||
| ) | ||
|
|
||
| args["pipeline"] = config.get("pipeline", "independent") | ||
| args["sequential_targets"] = config.get("sequential_targets", None) | ||
| args["tracing_ignore"] = config.get("tracing_ignore", []) | ||
| args["raw_kwargs"] = config.get("raw_kwargs", {}) | ||
| args["preprocessing_func"] = config.get("preprocessing_func", lambda x: x) | ||
| args["max_train_samples"] = config.get("max_train_samples", 50) | ||
| args["remove_columns"] = config.get("remove_columns", None) | ||
| args["dvc_data_repository"] = config.get("dvc_data_repository", None) | ||
| args["splits"] = config.get("splits", {"calibration": "train[:50]"}) | ||
| args["log_dir"] = config.get("log_dir", "sparse_logs") | ||
|
|
||
| return args | ||
|
|
||
|
|
||
| @pytest.mark.smoke | ||
| @pytest.mark.integration | ||
| def test_one_shot_inputs(one_shot_args, tmp_path): | ||
| oneshot( | ||
| **one_shot_args, | ||
| output_dir=tmp_path, | ||
| num_calibration_samples=10, | ||
| pad_to_max_length=False, | ||
| ) | ||
| logger.info(f"Dataset type: {type(one_shot_args.get('dataset'))}") | ||
| if isinstance(one_shot_args.get("dataset"), str): | ||
| logger.info(f"Dataset name: {one_shot_args.get('dataset')}") | ||
| logger.info(f"Dataset config: {one_shot_args.get('dataset_config_name')}") | ||
| try: | ||
| # Call oneshot with all parameters as flat arguments | ||
| oneshot( | ||
| **one_shot_args, | ||
| output_dir=tmp_path, | ||
| num_calibration_samples=10, | ||
| pad_to_max_length=False, | ||
| ) | ||
|
|
||
| except ValueError as e: | ||
| if "num_samples should be a positive integer value" in str( | ||
| e | ||
| ) or "Dataset is empty. Cannot create a calibration dataloader" in str(e): | ||
| logger.warning(f"Dataset is empty: {one_shot_args.get('dataset')}") | ||
| pytest.skip(f"Dataset is empty: {one_shot_args.get('dataset')}") | ||
| else: | ||
| raise # Re-raise other ValueError exceptions | ||
|
Comment on lines
+67
to
+87
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you explain why you needed to add these changes? If you are asserting that a certain pathway raises an error in a test, you can do that with `with pytest.raises(ValueError):` — there are examples of this in the code.
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When running, I couldn't define a pathway for this, so I handled these as exceptions. I'm open to suggestions to address this.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the error stacktrace? We should resolve it there instead. For example,
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Traceback (most recent call last):
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It runs fine on the latest main. I then added
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, leave as |
||
Uh oh!
There was an error while loading. Please reload this page.