3 changes: 3 additions & 0 deletions README.md
@@ -37,6 +37,7 @@ Big updates have landed in LLM Compressor! To get a more in-depth look, check ou

Some of the exciting new features include:

* **AutoRound Quantization Support**: Added [`AutoRoundModifier`](examples/autoround/llama3_example.py) for quantization using [AutoRound](https://aclanthology.org/2024.findings-emnlp.662.pdf), an advanced post-training algorithm that optimizes rounding and clipping ranges through sign-gradient descent. This approach combines the efficiency of post-training quantization with the adaptability of parameter tuning, delivering robust compression for large language models while maintaining strong performance.
* **Qwen3 Next and Qwen3 VL MoE Quantization Support**: Quantize the Qwen3 Next and Qwen3 VL MoE models and seamlessly run them in vLLM. Examples for [NVFP4](examples/quantization_w4a4_fp4/qwen3_next_example.py) and [FP8](examples/quantization_w8a8_fp8/qwen3_next_example.py) quantization have been added for Qwen3-Next-80B-A3B-Instruct. For Qwen3 VL MoE, support has been added for the data-free pathway, specifically [FP8 Quantization](examples/quantization_w8a8_fp8/qwen3_vl_moe_fp8_example.py) (e.g., channel-wise and block-wise quantization). NOTE: these models are not supported in transformers<=4.56.2; you may need to install transformers from source.
* **Quantization with Multiple Modifiers**: Multiple quantization modifiers can now be applied to the same model for mixed-precision quantization, for example applying AWQ W4A16 to a model's `self_attn` layers and GPTQ W8A8 to its `mlp` layers. This is an advanced usage of `llm-compressor` and an active area of research. See the [non-uniform quantization support](examples/quantization_non_uniform) section for more detail and [example usage](examples/quantization_non_uniform/quantization_multiple_modifiers.py).
* **QuIP and SpinQuant-style Transforms**: The newly added [`QuIPModifier`](examples/transform/quip_example.py) and [`SpinQuantModifier`](examples/transform/spinquant_example.py) allow users to quantize their models after injecting Hadamard weights into the computation graph, reducing quantization error and greatly improving accuracy recovery for low-bit weight and activation quantization.
@@ -55,6 +56,7 @@ Some of the exciting new features include:
* AWQ
* SmoothQuant
* SparseGPT
* AutoRound

### When to Use Which Optimization

@@ -78,6 +80,7 @@ Applying quantization with `llmcompressor`:
* [Weight only quantization to `fp4`](examples/quantization_w4a16_fp4/llama3_example.py)
* [Weight only quantization to `int4` using GPTQ](examples/quantization_w4a16/README.md)
* [Weight only quantization to `int4` using AWQ](examples/awq/README.md)
* [Weight only quantization to `int4` using AutoRound](examples/autoround/README.md)
* [Quantizing MoE LLMs](examples/quantizing_moe/README.md)
* [Quantizing Vision-Language Models](examples/multimodal_vision/README.md)
* [Quantizing Audio-Language Models](examples/multimodal_audio/README.md)
1 change: 1 addition & 0 deletions docs/getting-started/compress.md
@@ -33,6 +33,7 @@ Compression schemes use quantization methods including the following:
| **AWQ** | Uses channelwise scaling to better preserve important outliers in weights and activations | Better accuracy recovery with faster runtime than GPTQ |
| **SmoothQuant** | Smooths outliers in activations by folding them into weights, ensuring better accuracy for weight and activation quantized models | Good accuracy recovery with minimal calibration time; composable with other methods |
| **Round-To-Nearest (RTN)** | Simple quantization technique that rounds each value to the nearest representable level in the target precision. | Provides moderate accuracy recovery in most scenarios. Computationally cheap and fast to implement, making it suitable for real-time or resource-constrained environments. |
| **AutoRound** | Optimizes rounding values and clipping ranges via sign-gradient descent. | Delivers leading 4-bit and superior sub-4-bit accuracy compared to GPTQ/AWQ, with runtime faster than GPTQ and on par with AWQ. |

For this guide, we'll use `GPTQ` composed with `SmoothQuant` to create an `INT W8A8` quantized model. This combination provides a good balance of performance, accuracy, and compatibility across a wide range of hardware.

141 changes: 141 additions & 0 deletions examples/autoround/README.md
@@ -0,0 +1,141 @@
# `AutoRound` Quantization

`llm-compressor` supports [AutoRound](https://aclanthology.org/2024.findings-emnlp.662.pdf), an advanced quantization technique that delivers **high-accuracy**, **low-bit quantization**. The quantized results are fully compatible with `compressed-tensors` and can be served directly with vLLM.

AutoRound introduces three trainable parameters (V, α, and β) to optimize rounding values and clipping ranges during quantization. The method processes each decoder layer sequentially, using block-wise output reconstruction error as the training objective to fine-tune these parameters. This approach combines the efficiency of post-training quantization with the adaptability of parameter tuning, delivering robust compression for large language models while maintaining strong performance.
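As a rough mental model, the per-group fake quantization and sign-gradient update look something like the sketch below. This is an illustrative simplification, not the `AutoRoundModifier` implementation; the tensor shapes, learning rate, and straight-through rounding trick are assumptions made for demonstration only.

```python
import torch

def fake_quant(weight, V, alpha, beta, num_bits=4, group_size=128):
    """Per-group asymmetric fake quantization with a learnable rounding
    perturbation (V) and learnable clipping-range scales (alpha, beta)."""
    w = weight.reshape(-1, group_size)
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    # alpha/beta shrink or stretch the clipping range before scales are computed
    wmax = w.amax(dim=1, keepdim=True) * alpha
    wmin = w.amin(dim=1, keepdim=True) * beta
    scale = (wmax - wmin) / (qmax - qmin)
    zero = torch.round(-wmin / scale) + qmin
    y = w / scale + V                      # V nudges each weight's rounding decision
    q = y + (torch.round(y) - y).detach()  # straight-through estimator for round()
    q = torch.clamp(q + zero, qmin, qmax)
    return ((q - zero) * scale).reshape_as(weight)

# One sign-SGD step against block-wise output reconstruction error
weight = torch.randn(512, 512)
x = torch.randn(8, 512)                    # cached inputs to this block
target = x @ weight.T                      # full-precision block output
groups = weight.numel() // 128
V = torch.zeros(groups, 128, requires_grad=True)
alpha = torch.ones(groups, 1, requires_grad=True)
beta = torch.ones(groups, 1, requires_grad=True)

loss = torch.nn.functional.mse_loss(x @ fake_quant(weight, V, alpha, beta).T, target)
loss.backward()
with torch.no_grad():
    for p in (V, alpha, beta):
        p -= 5e-3 * p.grad.sign()          # sign-gradient descent update
        p.grad = None
```

In the real algorithm this loop runs for a configurable number of iterations per decoder block, reusing the block's calibration inputs, before moving on to the next block.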

## Installation

To get started, install `llm-compressor` from source:

```bash
git clone https://github.com/vllm-project/llm-compressor.git
cd llm-compressor
pip install -e .
```

## Quickstart

The example includes an end-to-end script for applying the AutoRound quantization algorithm.

```bash
python3 llama3_example.py
```

The resulting model `Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound` is ready to be loaded into vLLM.

## Code Walkthrough

Now, we will step through the code in the example. There are four steps:
1) Load model
2) Prepare calibration data
3) Apply quantization
4) Evaluate accuracy in vLLM

### 1) Load Model

Load the model using `AutoModelForCausalLM`, which handles quantized saving and loading.

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
```

### 2) Prepare Calibration Data

When quantizing model weights with AutoRound, you'll need a small set of sample data to run the algorithm. By default, we use [NeelNanda/pile-10k](https://huggingface.co/datasets/NeelNanda/pile-10k) as the calibration dataset.
Recommended starting points:
- 128 samples — typically sufficient for stable calibration (increase if accuracy degrades).
- 2048 sequence length — a good baseline for most LLMs.
- 200 tuning steps — usually enough to converge (increase if accuracy drops).

```python
# Select calibration dataset.
from auto_round.calib_dataset import get_dataset

NUM_CALIBRATION_SAMPLES = 128
MAX_SEQUENCE_LENGTH = 2048

# Get aligned calibration dataset.
ds = get_dataset(
    tokenizer=tokenizer,
    seqlen=MAX_SEQUENCE_LENGTH,
    nsamples=NUM_CALIBRATION_SAMPLES,
)
```

### 3) Apply Quantization

With the dataset ready, we will now apply AutoRound quantization to the model.

```python
from llmcompressor import oneshot
from llmcompressor.modifiers.autoround import AutoRoundModifier

# Configure the quantization algorithm to run.
recipe = AutoRoundModifier(
    targets="Linear", scheme="W4A16", ignore=["lm_head"], iters=200
)

# Apply quantization.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    # disable shuffling to get a slightly better MMLU score
    shuffle_calibration_samples=False,
)


# Save to disk compressed.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W4A16-G128-AutoRound"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```

We have successfully created an `int4` model!

### 4) Evaluate Accuracy

With the model created, we can now load and run it in vLLM (after installing vLLM).

```python
from vllm import LLM
model = LLM("./Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound")
```
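As a quick sanity check before running full evaluations, you can generate a few tokens from the compressed checkpoint. This is an optional, self-contained sketch; the prompt and sampling settings below are arbitrary placeholders.

```python
from vllm import LLM, SamplingParams

llm = LLM("./Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound")
outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)  # expect a coherent completion
```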

We can evaluate accuracy with `lm_eval` (`pip install lm-eval==0.4.9.1`):
> Note: quantized models can be sensitive to the presence of the `bos` token. `lm_eval` does not add a `bos` token by default, so make sure to include the `add_bos_token=True` argument when running your evaluations.

Run the following to test accuracy on GSM-8K:

```bash
lm_eval --model vllm \
--model_args pretrained="./Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound",add_bos_token=true \
--tasks gsm8k \
--num_fewshot 5 \
--limit 1000 \
--batch_size 'auto'
```

We can see the resulting scores look good!

```bash
| Tasks | Version | Filter | n-shot | Metric | | Value | | Stderr |
| ----- | ------: | ---------------- | -----: | ----------- | --- | ----: | --- | -----: |
| gsm8k | 3 | flexible-extract | 5 | exact_match | ↑ | 0.737 | ± | 0.0139 |
| | | strict-match | 5 | exact_match | ↑ | 0.736 | ± | 0.0139 |
```
> Note: quantized model accuracy may vary slightly due to nondeterminism.

### Known Issues
Currently, `llm-compressor` supports applying AutoRound only with `wNa16` quantization schemes. Support for additional schemes is planned. You can follow progress in the [RFC](https://github.com/vllm-project/llm-compressor/issues/1968).

### Questions or Feature Requests?

Please open an issue on [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor) or [intel/auto-round](https://github.com/intel/auto-round).