-
Notifications
You must be signed in to change notification settings - Fork 294
[docs] Add autoround doc #2055
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+145
−0
Merged
[docs] Add autoround doc #2055
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
baad120
add init doc
yiliu30 e73b98d
update docs
yiliu30 03b420e
update docs
yiliu30 37a03ec
update
yiliu30 f993f80
update
yiliu30 374d991
update
yiliu30 73fdcf1
fix
yiliu30 ead09b3
update
yiliu30 9eedf88
add more
yiliu30 6e9623b
fix
yiliu30 87d7162
update
yiliu30 1a76018
Update examples/autoround/README.md
yiliu30 5e7df30
refine
yiliu30 627c0ee
Merge branch 'main' into autoround-doc
yiliu30 c5b7a1d
update
yiliu30 4023f15
Merge branch 'autoround-doc' of https://github.com/yiliu30/llm-compre…
yiliu30 e49cbb7
Merge branch 'main' into autoround-doc
HDCharles 39915b0
update readme
yiliu30 d02c95a
update
yiliu30 4a54ade
update
yiliu30 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,141 @@ | ||
| # `AutoRound` Quantization | ||
|
|
||
| `llm-compressor` supports [AutoRound](https://aclanthology.org/2024.findings-emnlp.662.pdf), an advanced quantization technique that delivers **high-accuracy**, **low-bit quantization**. The quantized results are fully compatible with `compressed-tensors` and can be served directly with vLLM. | ||
|
|
||
| AutoRound introduces three trainable parameters (V, α, and β) to optimize rounding values and clipping ranges during quantization. The method processes each decoder layer sequentially, using block-wise output reconstruction error as the training objective to fine-tune these parameters. This approach combines the efficiency of post-training quantization with the adaptability of parameter tuning, delivering robust compression for large language models while maintaining strong performance. | ||
|
|
||
| ## Installation | ||
|
|
||
| To get started, install: | ||
|
|
||
| ```bash | ||
| git clone https://github.com/vllm-project/llm-compressor.git | ||
| cd llm-compressor | ||
| pip install -e . | ||
| ``` | ||
|
|
||
| ## Quickstart | ||
|
|
||
| The example includes an end-to-end script for applying the AutoRound quantization algorithm. | ||
|
|
||
| ```bash | ||
| python3 llama3_example.py | ||
| ``` | ||
|
|
||
| The resulting model `Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound` is ready to be loaded into vLLM. | ||
|
|
||
| ## Code Walkthrough | ||
|
|
||
| Now, we will step through the code in the example. There are four steps: | ||
| 1) Load model | ||
| 2) Prepare calibration data | ||
| 3) Apply quantization | ||
| 4) Evaluate accuracy in vLLM | ||
|
|
||
| ### 1) Load Model | ||
|
|
||
| Load the model using `AutoModelForCausalLM` for handling quantized saving and loading. | ||
|
|
||
| ```python | ||
| from transformers import AutoTokenizer, AutoModelForCausalLM | ||
|
|
||
| MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" | ||
| model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") | ||
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) | ||
| ``` | ||
|
|
||
| ### 2) Prepare Calibration Data | ||
|
|
||
| When quantizing model weights with AutoRound, you’ll need a small set of sample data to run the algorithm. By default, we are using [NeelNanda/pile-10k](https://huggingface.co/datasets/NeelNanda/pile-10k) as our calibration dataset. | ||
| Recommended starting points: | ||
| - 128 samples — typically sufficient for stable calibration (increase if accuracy degrades). | ||
| - 2048 sequence length — a good baseline for most LLMs. | ||
| - 200 tuning steps — usually enough to converge (increase if accuracy drops). | ||
|
|
||
| ```python | ||
| # Select calibration dataset. | ||
| from auto_round.calib_dataset import get_dataset | ||
|
|
||
| NUM_CALIBRATION_SAMPLES = 128 | ||
| MAX_SEQUENCE_LENGTH = 2048 | ||
|
|
||
| # Get aligned calibration dataset. | ||
| ds = get_dataset( | ||
| tokenizer=tokenizer, | ||
| seqlen=MAX_SEQUENCE_LENGTH, | ||
| nsamples=NUM_CALIBRATION_SAMPLES, | ||
| ) | ||
| ``` | ||
|
|
||
| ### 3) Apply Quantization | ||
|
|
||
| With the dataset ready, we will now apply AutoRound quantization to the model. | ||
|
|
||
| ```python | ||
| from llmcompressor import oneshot | ||
| from llmcompressor.modifiers.autoround import AutoRoundModifier | ||
|
|
||
| # Configure the quantization algorithm to run. | ||
| recipe = AutoRoundModifier( | ||
| targets="Linear", scheme="W4A16", ignore=["lm_head"], iters=200 | ||
| ) | ||
|
|
||
| # Apply quantization. | ||
| oneshot( | ||
| model=model, | ||
| dataset=ds, | ||
| recipe=recipe, | ||
| max_seq_length=MAX_SEQUENCE_LENGTH, | ||
| num_calibration_samples=NUM_CALIBRATION_SAMPLES, | ||
| # disable shuffling to get slightly better mmlu score | ||
| shuffle_calibration_samples=False, | ||
| ) | ||
|
|
||
|
|
||
| # Save to disk compressed. | ||
| SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W4A16-G128-AutoRound" | ||
| model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
| tokenizer.save_pretrained(SAVE_DIR) | ||
| ``` | ||
|
|
||
| We have successfully created an `int4` model! | ||
|
|
||
| ### 4) Evaluate Accuracy | ||
|
|
||
| With the model created, we can now load and run in vLLM (after installing). | ||
|
|
||
| ```python | ||
| from vllm import LLM | ||
| model = LLM("./Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound") | ||
| ``` | ||
|
|
||
| We can evaluate accuracy with `lm_eval` (`pip install lm-eval==0.4.9.1`): | ||
| > Note: quantized models can be sensitive to the presence of the `bos` token. `lm_eval` does not add a `bos` token by default, so make sure to include the `add_bos_token=True` argument when running your evaluations. | ||
|
|
||
| Run the following to test accuracy on GSM-8K: | ||
|
|
||
| ```bash | ||
| lm_eval --model vllm \ | ||
| --model_args pretrained="./Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound",add_bos_token=true \ | ||
| --tasks gsm8k \ | ||
| --num_fewshot 5 \ | ||
| --limit 1000 \ | ||
| --batch_size 'auto' | ||
| ``` | ||
|
|
||
| We can see the resulting scores look good! | ||
|
|
||
| ```bash | ||
| | Tasks | Version | Filter | n-shot | Metric | | Value | | Stderr | | ||
| | ----- | ------: | ---------------- | -----: | ----------- | --- | ----: | --- | -----: | | ||
| | gsm8k | 3 | flexible-extract | 5 | exact_match | ↑ | 0.737 | ± | 0.0139 | | ||
| | | | strict-match | 5 | exact_match | ↑ | 0.736 | ± | 0.0139 | | ||
| ``` | ||
| > Note: quantized model accuracy may vary slightly due to nondeterminism. | ||
|
|
||
| ### Known Issues | ||
| Currently, `llm-compressor` supports applying AutoRound only on the `wNa16` quantization schemes. Support for additional schemes is planned. You can follow progress in the [RFC](https://github.com/vllm-project/llm-compressor/issues/1968). | ||
|
|
||
| ### Questions or Feature Request? | ||
|
|
||
| Please open up an issue on [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor) or [intel/auto-round](https://github.com/intel/auto-round). | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.