# Quantizing InternVL3-8B-hf
This example shows how to quantize InternVL3-8B-hf to FP8 with `llmcompressor` and evaluate the result in vLLM.

## Step 1: Load the model and processor

```python
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "OpenGVLab/InternVL3-8B-hf"
model = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(model_id)
```

## Step 2: Load the dataset
Use the `ultrachat_200k` dataset for calibration.
```python
from datasets import load_dataset

DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 512

ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)
```

## Step 3: Preprocess and tokenize
```python
def preprocess_and_tokenize(example):
    # Pass the raw `messages` field of the calibration example as the text
    # content of a single user turn, then tokenize with the chat template.
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": example["messages"],
                },
            ],
        }
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    return inputs

ds = ds.map(preprocess_and_tokenize)
```
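
To see what the map produces, you can peek at one processed example (optional; the `datasets` library keeps the original ultrachat columns alongside the tokenized outputs):

```python
# Optional: inspect one preprocessed calibration example.
sample = ds[0]
print(sorted(sample.keys()))  # original ultrachat columns plus the processor outputs
```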

## Step 4: Adding Your Own Data Collator
Calibration runs with a batch size of 1, so we define a custom data collator that rebuilds batched tensors from a single example.
```python
def data_collator(batch):
    # Calibration feeds one example at a time.
    assert len(batch) == 1
    item = {key: value for key, value in batch[0].items()}
    # Rebuild batched tensors (shape [1, seq_len]) from the stored lists.
    item["attention_mask"] = torch.tensor([item["attention_mask"]])
    item["input_ids"] = torch.LongTensor([item["input_ids"]])

    return item
```
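
For illustration, here is how the collator behaves on a single toy example (the values below are made up, not taken from the calibration set):

```python
# Illustration only: the collator receives a list containing one example and
# returns batched tensors of shape [1, seq_len].
toy_example = {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]}
batch = data_collator([toy_example])
print(batch["input_ids"].shape)       # torch.Size([1, 3])
print(batch["attention_mask"].shape)  # torch.Size([1, 3])
```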

## Step 5: Define the recipe
```python
from llmcompressor.modifiers.quantization import GPTQModifier

recipe = GPTQModifier(
    targets="Linear",
    scheme="FP8",
    ignore=["re:.*lm_head", "re:.*vision_tower.*", "re:.*multi_modal_projector.*"],
)
```
Note: We also tried `ignore=["re:.*lm_head", "re:.*multi_modal_projector.*"]` (i.e., quantizing the vision tower as well). However, the resulting model did not produce meaningful output for prompts containing images, so we quantize only the LLM part.
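
As an optional sanity check (not part of the original recipe), you can list which `Linear` modules the ignore patterns would skip. The snippet below assumes `model` from Step 1 is in scope and uses `re.fullmatch` as an approximation of how the `re:` patterns are resolved:

```python
import re

import torch

ignore_patterns = [r".*lm_head", r".*vision_tower.*", r".*multi_modal_projector.*"]

# Collect the names of Linear layers that match any ignore pattern.
ignored = [
    name
    for name, module in model.named_modules()
    if isinstance(module, torch.nn.Linear)
    and any(re.fullmatch(pattern, name) for pattern in ignore_patterns)
]
print(f"{len(ignored)} Linear modules will be left unquantized")
print(ignored[:5])
```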

## Step 6: Oneshot and save
```python
from llmcompressor import oneshot

oneshot(
    model=model,
    tokenizer=model_id,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=data_collator,
)

SAVE_DIR = "OpenGVLab/InternVL3-8B-hf-FP8-GPTQ"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
```
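
Before serving, it can be useful to sanity-check the compressed model directly in transformers. This is a minimal, optional sketch; the prompt is illustrative and `processor.tokenizer` is assumed to expose the underlying tokenizer:

```python
# Optional sanity check (not part of the original recipe): text-only generation
# with the compressed model.
messages = [
    {"role": "user", "content": [{"type": "text", "text": "Describe this model in one sentence."}]}
]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device)
generated = model.generate(**inputs, max_new_tokens=64)
print(processor.tokenizer.decode(generated[0], skip_special_tokens=True))
```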

## Step 7: Evaluate
With the compressed model saved, we can now load and run it in vLLM.
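
For a quick load test, the offline vLLM API can be used (a minimal sketch; the text-only prompt and sampling settings are illustrative):

```python
from vllm import LLM, SamplingParams

# Minimal sketch: load the compressed checkpoint with vLLM's offline API.
llm = LLM(model="OpenGVLab/InternVL3-8B-hf-FP8-GPTQ", trust_remote_code=True, max_model_len=6144)
outputs = llm.generate(
    ["What does FP8 quantization change about a model?"],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```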

### Accuracy
We can evaluate the accuracy of the quantized multimodal model with [VLMEvalKit](https://github.com/open-compass/VLMEvalKit.git):
```
torchrun --nproc-per-node=2 run.py --data MMStar --model InternVL3-8B-hf_FP8_GPTQ --verbose
```

### Performance
We can evaluate serving performance with vLLM.
First, serve the model:
```
vllm serve OpenGVLab/InternVL3-8B-hf-FP8-GPTQ \
    --served-model-name InternVL3-8B-hf-FP8-GPTQ \
    --gpu-memory-utilization 0.9 \
    --uvicorn-log-level error \
    --disable-log-stats \
    --trust-remote-code \
    --allowed-local-media-path /path/to/sharegpt4v/images \
    --limit-mm-per-prompt '{"image": 20}' \
    --mm-processor-kwargs '{"max_dynamic_patch": 1}' \
    --no-enable-prefix-caching \
    --disable-mm-preprocessor-cache \
    --max-model-len 6144
```
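
Optionally, before benchmarking, you can smoke-test the endpoint with the OpenAI Python client (a sketch; it assumes the server is reachable at the default `http://localhost:8000`):

```python
from openai import OpenAI

# Text-only smoke test against the vLLM server started above.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="InternVL3-8B-hf-FP8-GPTQ",
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
    max_tokens=32,
)
print(response.choices[0].message.content)
```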

Then, run the benchmark with `vllm bench serve`:
```
vllm bench serve \
    --backend openai-chat \
    --dataset-name sharegpt \
    --dataset-path /path/to/ShareGPT4V/sharegpt4v_instruct_gpt4-vision_cap100k_coco.json \
    --num-prompts 500 \
    --endpoint /v1/chat/completions \
    --max-concurrency 100 \
    --percentile-metrics='ttft,tpot,itl,e2el' \
    --model InternVL3-8B-hf-FP8-GPTQ
```

The result of InternVL3-8B-hf:
```
============ Serving Benchmark Result ============
Successful requests: 500
Maximum request concurrency: 100
Benchmark duration (s): 251.15
Total input tokens: 6193
Total generated tokens: 30487
Request throughput (req/s): 1.99
Output token throughput (tok/s): 121.39
Peak output token throughput (tok/s): 1055.00
Peak concurrent requests: 107.00
Total Token throughput (tok/s): 146.05
---------------Time to First Token----------------
Mean TTFT (ms): 25349.32
Median TTFT (ms): 25498.75
P99 TTFT (ms): 45969.02
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 415.96
Median TPOT (ms): 414.47
P99 TPOT (ms): 857.11
---------------Inter-token Latency----------------
Mean ITL (ms): 379.15
Median ITL (ms): 410.25
P99 ITL (ms): 524.16
----------------End-to-end Latency----------------
Mean E2EL (ms): 48444.73
Median E2EL (ms): 50202.88
P99 E2EL (ms): 91705.52
==================================================
```
The result of InternVL3-8B-hf-FP8-GPTQ (about 1.5x the request throughput of the baseline above, with noticeably lower TTFT and TPOT):
```
============ Serving Benchmark Result ============
Successful requests: 500
Maximum request concurrency: 100
Benchmark duration (s): 163.36
Total input tokens: 6193
Total generated tokens: 34831
Request throughput (req/s): 3.06
Output token throughput (tok/s): 213.22
Peak output token throughput (tok/s): 1787.00
Peak concurrent requests: 109.00
Total Token throughput (tok/s): 251.13
---------------Time to First Token----------------
Mean TTFT (ms): 14510.84
Median TTFT (ms): 14371.25
P99 TTFT (ms): 28978.21
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 257.52
Median TPOT (ms): 270.19
P99 TPOT (ms): 330.58
---------------Inter-token Latency----------------
Mean ITL (ms): 247.34
Median ITL (ms): 268.16
P99 ITL (ms): 386.99
----------------End-to-end Latency----------------
Mean E2EL (ms): 31725.93
Median E2EL (ms): 32227.14
P99 E2EL (ms): 64293.40
==================================================
```