
Commit 9e1a040

Authored by: BigFaceBoy, fangxuwei, gemini-code-assist[bot], dsikka

add internvl3-8b-hf quantize example (#2028)

SUMMARY: LLM Compressor doesn't currently have any examples for InternVL3. It truly took me a lot of time to quantize it successfully, so I want to share this example.

TEST PLAN: "please outline how the changes were tested"

Signed-off-by: xuwei fang <977502733@qq.com>
Co-authored-by: fangxuwei <fangxuwei@ezviz.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>

1 parent 160c92d · commit 9e1a040

File tree

2 files changed: +259 −0 lines changed
Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@

# Quantizing InternVL3-8B-hf

This document shows an example of quantizing InternVL3-8B-hf.

## Step 1: Load the Model and Processor

```python
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "OpenGVLab/InternVL3-8B-hf"
model = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(model_id)
```

## Step 2: Load the Dataset

Use the `ultrachat_200k` dataset for calibration.

```python
from datasets import load_dataset

DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 512

ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)
```
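
To see what a calibration sample looks like before preprocessing, a quick inspection can help. This snippet is not part of the original example; it only assumes the dataset has been loaded as above:

```python
# Each ultrachat_200k sample stores the conversation as a list of chat turns
# under the "messages" key; this is what preprocess_and_tokenize receives.
print(ds[0]["messages"][:1])
```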

## Step 3: Preprocess and tokenize

```python
def preprocess_and_tokenize(example):
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": example["messages"],
                },
            ],
        }
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    return inputs


ds = ds.map(preprocess_and_tokenize)
```

## Step 4: Adding Your Own Data Collator

We need a custom data collator to satisfy the model-specific input requirements.

```python
def data_collator(batch):
    assert len(batch) == 1
    item = {key: value for key, value in batch[0].items()}
    item["attention_mask"] = torch.tensor([item["attention_mask"]])
    item["input_ids"] = torch.LongTensor([item["input_ids"]])

    return item
```

## Step 5: Define the recipe

```python
from llmcompressor.modifiers.quantization import GPTQModifier

recipe = GPTQModifier(
    targets="Linear",
    scheme="FP8",
    ignore=["re:.*lm_head", "re:.*vision_tower.*", "re:.*multi_modal_projector.*"],
)
```

Note: We also tried `ignore=["re:.*lm_head", "re:.*multi_modal_projector.*"]`, which quantizes the vision tower as well. However, that quantized model did not produce meaningful output for prompts with images, so we quantize only the LLM part.
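
If you adapt this recipe to a different multimodal model, it helps to confirm which module names the `ignore` regexes must match. A minimal sketch, not part of the original example, assuming the model from Step 1 is loaded:

```python
# Print the first two levels of module names so the ignore patterns
# (e.g. "vision_tower", "multi_modal_projector", "lm_head") can be checked
# against the model's actual submodule names.
prefixes = sorted({".".join(name.split(".")[:2]) for name, _ in model.named_modules() if name})
print(prefixes)
```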

## Step 6: Oneshot and save

```python
from llmcompressor import oneshot

oneshot(
    model=model,
    tokenizer=model_id,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=data_collator,
)

SAVE_DIR = "OpenGVLab/InternVL3-8B-hf-FP8-GPTQ"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
```

## Step 7: Evaluate

With the quantized model saved, we can now load and run it in vLLM.
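
As a quick text-only sanity check before the full evaluations below, the compressed checkpoint can be loaded with vLLM's offline API. This is a minimal sketch, not part of the original example; it assumes the checkpoint was saved to `OpenGVLab/InternVL3-8B-hf-FP8-GPTQ` as in Step 6:

```python
from vllm import LLM, SamplingParams

# Load the compressed checkpoint saved in Step 6.
llm = LLM(model="OpenGVLab/InternVL3-8B-hf-FP8-GPTQ", max_model_len=6144)

# Generate a short completion to confirm the quantized weights produce sane text.
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["Give a one-sentence description of the Eiffel Tower."], sampling_params)
print(outputs[0].outputs[0].text)
```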

### Accuracy

We can evaluate the accuracy of the quantized multimodal model with [VLMEvalKit](https://github.com/open-compass/VLMEvalKit.git):

```
torchrun --nproc-per-node=2 run.py --data MMStar --model InternVL3-8B-hf_FP8_GPTQ --verbose
```

### Performance

We can evaluate performance with vLLM.

First, serve the model with vLLM:

```
vllm serve OpenGVLab/InternVL3-8B-hf-FP8-GPTQ \
    --served-model-name InternVL3-8B-hf-FP8-GPTQ \
    --gpu-memory-utilization 0.9 \
    --uvicorn_log_level error \
    --disable-log-stats \
    --trust-remote-code \
    --allowed-local-media-path /path/to/sharegpt4v/images \
    --limit-mm-per-prompt '{"image": 20}' \
    --mm-processor-kwargs '{"max_dynamic_patch": 1}' \
    --no-enable-prefix-caching \
    --disable-mm-preprocessor-cache \
    --max-model-len 6144
```
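
Once the server is up, you can smoke-test the multimodal path through the OpenAI-compatible chat API. This is a hypothetical example, not part of the original test plan; the image path is a placeholder under the `--allowed-local-media-path` directory, and the server is assumed to listen on the default port 8000:

```python
from openai import OpenAI

# Point the OpenAI client at the local vLLM server started above.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="InternVL3-8B-hf-FP8-GPTQ",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image in one sentence."},
                # file:// URLs are allowed because --allowed-local-media-path is set on the server.
                {"type": "image_url", "image_url": {"url": "file:///path/to/sharegpt4v/images/example.jpg"}},
            ],
        }
    ],
    max_tokens=64,
)
print(response.choices[0].message.content)
```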

Second, run `vllm bench serve`:

```
vllm bench serve \
    --backend openai-chat \
    --dataset-name sharegpt \
    --dataset-path /path/to/ShareGPT4V/sharegpt4v_instruct_gpt4-vision_cap100k_coco.json \
    --num-prompts 500 \
    --endpoint /v1/chat/completions \
    --max-concurrency 100 \
    --percentile-metrics='ttft,tpot,itl,e2el' \
    --model InternVL3-8B-hf-FP8-GPTQ
```

The result of InternVL3-8B-hf:

```
============ Serving Benchmark Result ============
Successful requests:                     500
Maximum request concurrency:             100
Benchmark duration (s):                  251.15
Total input tokens:                      6193
Total generated tokens:                  30487
Request throughput (req/s):              1.99
Output token throughput (tok/s):         121.39
Peak output token throughput (tok/s):    1055.00
Peak concurrent requests:                107.00
Total Token throughput (tok/s):          146.05
---------------Time to First Token----------------
Mean TTFT (ms):                          25349.32
Median TTFT (ms):                        25498.75
P99 TTFT (ms):                           45969.02
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          415.96
Median TPOT (ms):                        414.47
P99 TPOT (ms):                           857.11
---------------Inter-token Latency----------------
Mean ITL (ms):                           379.15
Median ITL (ms):                         410.25
P99 ITL (ms):                            524.16
----------------End-to-end Latency----------------
Mean E2EL (ms):                          48444.73
Median E2EL (ms):                        50202.88
P99 E2EL (ms):                           91705.52
==================================================
```

The result of InternVL3-8B-hf-FP8-GPTQ:

```
============ Serving Benchmark Result ============
Successful requests:                     500
Maximum request concurrency:             100
Benchmark duration (s):                  163.36
Total input tokens:                      6193
Total generated tokens:                  34831
Request throughput (req/s):              3.06
Output token throughput (tok/s):         213.22
Peak output token throughput (tok/s):    1787.00
Peak concurrent requests:                109.00
Total Token throughput (tok/s):          251.13
---------------Time to First Token----------------
Mean TTFT (ms):                          14510.84
Median TTFT (ms):                        14371.25
P99 TTFT (ms):                           28978.21
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          257.52
Median TPOT (ms):                        270.19
P99 TPOT (ms):                           330.58
---------------Inter-token Latency----------------
Mean ITL (ms):                           247.34
Median ITL (ms):                         268.16
P99 ITL (ms):                            386.99
----------------End-to-end Latency----------------
Mean E2EL (ms):                          31725.93
Median E2EL (ms):                        32227.14
P99 E2EL (ms):                           64293.40
==================================================
```

Compared with the BF16 baseline, the FP8-GPTQ model raises output token throughput from about 121 to 213 tok/s, lowers mean TTFT from about 25.3 s to 14.5 s, and lowers mean TPOT from about 416 ms to 258 ms.
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@

```python
import torch
from datasets import load_dataset
from transformers import AutoModelForImageTextToText, AutoProcessor

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

# Load model.
model_id = "OpenGVLab/InternVL3-8B-hf"
model = AutoModelForImageTextToText.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
)
processor = AutoProcessor.from_pretrained(model_id)

# Load datasets
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 512
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


# Apply the chat template and tokenize each calibration sample.
def preprocess_and_tokenize(example):
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": example["messages"]},
            ],
        }
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    return inputs


ds = ds.map(preprocess_and_tokenize)


# Custom data collator: calibration runs with batch size 1, so wrap the single
# sample's fields into tensors.
def data_collator(batch):
    assert len(batch) == 1
    item = {key: value for key, value in batch[0].items()}
    item["attention_mask"] = torch.tensor([item["attention_mask"]])
    item["input_ids"] = torch.LongTensor([item["input_ids"]])

    return item


# Recipe: FP8 GPTQ on Linear layers of the language model only; the vision
# tower and multimodal projector are left unquantized.
recipe = GPTQModifier(
    targets="Linear",
    scheme="FP8",
    ignore=["re:.*lm_head", "re:.*vision_tower.*", "re:.*multi_modal_projector.*"],
)

# Perform oneshot
oneshot(
    model=model,
    tokenizer=model_id,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=data_collator,
)

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-FP8"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
```
