Commit 1c85a66
[model_free_ptq] Add pathway for day-zero weight quantization support (#1971)
## Purpose ##

* Create a pathway which can quantize model weights without needing a model definition or the use of a calibration pipeline. Such a pathway provides fast and reliable support for models which:
  * Do not have a HF model definition yet
  * Have complications with sequential pipelines (very large vision towers, tracing failures, long calibration runtimes)

## Usage ##

```python
model_free_ptq(
    model_stub="meta-llama/Llama-3.2-1B-Instruct",
    save_directory="Llama-3.2-1B-Instruct-FP8_block",
    scheme="FP8_BLOCK",
    ignore=["model.embed_tokens", "lm_head"],
    max_workers=15,
    device="cuda:0",
)
```

## Testing ##

* Added `test_model_free_ptq_matches_oneshot`, which checks that saved tensors and configs exactly match between the `model_free_ptq` and `oneshot` entrypoints for the same arguments (a sketch of such a comparison appears below). This test takes about 10 seconds to run.

## Future Extensions ##

* Mixed-precision quantization (multiple recipes/targets)
* Multi-GPU support (work is already parallelized across threads, but if the GPU is the bottleneck we can split the work across GPUs)
* Multi-process support (if Python processing is the bottleneck, we can replace multithreading with multiprocessing)

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
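The matching test itself is not included in the diff below. As a rough illustration only (the helper name `assert_saved_checkpoints_match` is hypothetical and not part of this PR), an exact-match comparison of two save directories could look like:

```python
from pathlib import Path

import torch
from safetensors.torch import load_file


def assert_saved_checkpoints_match(dir_a: str, dir_b: str) -> None:
    """Assert that two saved checkpoints contain identical tensors."""
    checkpoints = []
    for directory in (dir_a, dir_b):
        tensors = {}
        # gather every tensor across all safetensors shards in the directory
        for shard in sorted(Path(directory).glob("*.safetensors")):
            tensors.update(load_file(shard))
        checkpoints.append(tensors)

    assert checkpoints[0].keys() == checkpoints[1].keys()
    for name in checkpoints[0]:
        assert torch.equal(checkpoints[0][name], checkpoints[1][name]), name
```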
1 parent df4d206 commit 1c85a66

File tree

8 files changed: +581 −1 lines changed
src/llmcompressor/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -26,4 +26,4 @@
     create_session,
     reset_session,
 )
-from llmcompressor.entrypoints import Oneshot, oneshot, train
+from llmcompressor.entrypoints import Oneshot, oneshot, train, model_free_ptq
```

src/llmcompressor/entrypoints/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -9,4 +9,5 @@
 
 from .oneshot import Oneshot, oneshot
 from .train import train
+from .model_free import model_free_ptq
 from .utils import post_process, pre_process
```

src/llmcompressor/entrypoints/model_free/__init__.py

Lines changed: 138 additions & 0 deletions
```python
import os
import shutil
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Optional

import torch
import tqdm
from compressed_tensors.quantization import QuantizationScheme
from compressed_tensors.utils.match import _match_name
from loguru import logger
from safetensors.torch import load_file, save_file

from llmcompressor.entrypoints.model_free.helpers import (
    gpu_if_available,
    validate_scheme,
)
from llmcompressor.entrypoints.model_free.lifecycle import (
    calibrate_weights,
    compress_module,
    initialize_quantized_linear,
)
from llmcompressor.entrypoints.model_free.model_utils import (
    get_checkpoint_files,
    is_weights_file,
)
from llmcompressor.entrypoints.model_free.save_utils import (
    update_config,
    update_safetensors_index,
)

__all__ = ["model_free_ptq"]


def model_free_ptq(
    model_stub: str | os.PathLike,
    save_directory: str | os.PathLike,
    scheme: QuantizationScheme | str,
    ignore: Optional[list[str]] = None,
    max_workers: int = 1,
    device: Optional[torch.device | str] = None,
):
    """
    Quantize a model without the need for a model definition. This function operates on
    a model stub or folder containing weights saved in safetensors files

    :param model_stub: Hugging Face model stub or path to local weights files
    :param save_directory: directory to save the quantized model to
    :param scheme: weight quantization scheme or preset scheme name
    :param ignore: modules to ignore. Modules ending with "norm" are automatically
        ignored
    :param max_workers: number of worker threads to process files with
    :param device: gpu device to accelerate quantization with
    """
    # validate arguments
    model_files = get_checkpoint_files(model_stub)
    scheme_name, scheme = validate_scheme(scheme)
    device = gpu_if_available(device)
    ignore = ignore if ignore is not None else []  # default to no ignored modules

    # 0. collect safetensors files, copy files
    jobs = []
    for file_path, resolved_path in model_files:
        save_path = Path(save_directory) / file_path

        if file_path.endswith("safetensors"):
            jobs.append(
                (_process_file, resolved_path, save_path, scheme, ignore, device)
            )

        else:
            if is_weights_file(file_path):
                logger.warning(f"Skipping weights file {file_path}")
            save_path.parent.mkdir(parents=True, exist_ok=True)
            logger.info(f"Copying {file_path} to {save_path}")
            shutil.copyfile(resolved_path, save_path)

    # 1-4. quantize and compress weights
    with ThreadPoolExecutor(max_workers) as executor:
        futures = [executor.submit(*job) for job in jobs]

        total_size = 0
        weight_map = dict()
        for future in tqdm.tqdm(
            as_completed(futures), total=len(futures), desc="Quantizing"
        ):
            _total_size, _weight_map = future.result()
            total_size += _total_size
            weight_map.update(_weight_map)

    # 5. update config and safetensors index
    update_config(save_directory, scheme_name, scheme, ignore)
    update_safetensors_index(save_directory, total_size, weight_map)


def _process_file(
    file_path: str | os.PathLike,
    save_path: str | os.PathLike,
    scheme: QuantizationScheme,
    ignore: str | list[str],
    device: str | torch.device,
) -> tuple[int, dict[str, str]]:
    """
    Quantize and compress tensors in a given safetensors file

    :param file_path: safetensors file to process
    :param save_path: save path of file with quantized weights
    :param scheme: quantization scheme to apply to tensors
    :param ignore: modules to ignore. Modules ending with "norm" are automatically
        ignored
    :param device: device used to quantize and compress weights
    """
    tensors = load_file(file_path)

    for name in list(tensors.keys()):
        module_name, param_name = name.rsplit(".", 1)
        is_linear_weight = param_name == "weight" and not module_name.endswith("norm")
        is_ignored = any(_match_name(module_name, ign) for ign in ignore)
        if not is_linear_weight or is_ignored:
            continue

        # 1. initialize module with qparams (on device)
        module = initialize_quantized_linear(tensors[name], scheme, device)

        # 2. calibrate weight qparams
        calibrate_weights(module)

        # 3. compress module using qparams
        compress_module(module)

        # 4. save compressed data (on cpu)
        del tensors[name]
        prefix = module_name + "."
        for key, value in module.state_dict(prefix=prefix).items():
            tensors[key] = value.to("cpu")

    save_file(tensors, save_path)
    total_size = sum(tensor.nbytes for tensor in tensors.values())
    weight_map = {key: os.path.basename(save_path) for key in tensors.keys()}
    return total_size, weight_map
```
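The quantize-or-skip decision inside `_process_file` hinges on the tensor name alone. Below is a standalone sketch of that filter; `is_quantizable` is a hypothetical helper, and plain string equality stands in for `_match_name`, which also supports `re:` regex targets:

```python
def is_quantizable(tensor_name: str, ignore: list[str]) -> bool:
    """Simplified stand-in for the name filter used in _process_file."""
    module_name, param_name = tensor_name.rsplit(".", 1)
    is_linear_weight = param_name == "weight" and not module_name.endswith("norm")
    is_ignored = any(module_name == ign for ign in ignore)  # _match_name in the real code
    return is_linear_weight and not is_ignored


ignore = ["model.embed_tokens", "lm_head"]
print(is_quantizable("model.layers.0.self_attn.q_proj.weight", ignore))  # True
print(is_quantizable("model.layers.0.input_layernorm.weight", ignore))   # False: norm module
print(is_quantizable("model.embed_tokens.weight", ignore))               # False: ignored
print(is_quantizable("model.layers.0.self_attn.q_proj.bias", ignore))    # False: not a weight
```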

src/llmcompressor/entrypoints/model_free/helpers.py

Lines changed: 75 additions & 0 deletions
```python
from typing import Optional

import torch
from compressed_tensors.quantization import QuantizationScheme, preset_name_to_scheme
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.match import _match_name
from loguru import logger

__all__ = ["validate_scheme", "gpu_if_available", "is_match_name"]


def validate_scheme(scheme: QuantizationScheme | str) -> tuple[str, QuantizationScheme]:
    # treat strings as preset schemes
    if isinstance(scheme, str):
        scheme_name, scheme = scheme, preset_name_to_scheme(scheme, [])
    else:
        scheme_name = "config_group_0"

    # weight quantization must be provided
    if scheme.weights is None:
        raise ValueError(
            "Must provide a weights quantization scheme to perform weights-only PTQ"
        )

    # activation quantization must be dynamic
    input_dynamic = getattr_chain(scheme, "input_activations.dynamic", True)
    output_dynamic = getattr_chain(scheme, "output_activations.dynamic", True)
    if input_dynamic is not True or output_dynamic is not True:
        raise ValueError(
            "Model Free PTQ cannot calibrate activations. "
            "Please use `oneshot` instead."
        )

    # override with static observers
    # Remove after https://github.com/vllm-project/compressed-tensors/pull/489
    if scheme.weights.observer in ("minmax", "mse"):
        new_observer = f"static_{scheme.weights.observer}"
        logger.warning(
            f"Scheme uses {scheme.weights.observer} weight observer. "
            f"Using {new_observer} instead"
        )
        scheme.weights.observer = new_observer

    # target all modules; filter by ignore list
    # technically this should be "re:.*", but vllm's
    # ct moe layer has a hard coded check for "Linear"
    scheme.targets = ["Linear"]
    return scheme_name, scheme


def gpu_if_available(device: torch.device | str | None) -> torch.device:
    if device is not None:
        return torch.device(device)

    elif torch.cuda.is_available():
        return torch.device("cuda:0")

    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu:0")

    else:
        logger.warning("CUDA/XPU is not available! Compressing model on CPU instead")
        return torch.device("cpu")


def is_match_name(
    name: str, targets: list[str], ignore: Optional[str | list[str]] = None
) -> bool:
    targets = targets if isinstance(targets, list) else [targets]
    ignore = ignore if isinstance(ignore, list) else [ignore] if ignore is not None else []

    matches_target = any(_match_name(name, target) for target in targets)
    matches_ignore = any(_match_name(name, ign) for ign in ignore)

    return matches_target and not matches_ignore
```
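A quick usage sketch of the helpers above (the outputs noted in comments follow from the code as written):

```python
from llmcompressor.entrypoints.model_free.helpers import gpu_if_available, validate_scheme

# preset scheme names resolve via preset_name_to_scheme; targets are forced to ["Linear"]
scheme_name, scheme = validate_scheme("FP8_BLOCK")
print(scheme_name)     # "FP8_BLOCK"
print(scheme.targets)  # ["Linear"]

# prefers cuda:0, then xpu:0, otherwise warns and falls back to cpu
device = gpu_if_available(None)
print(device)
```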

src/llmcompressor/entrypoints/model_free/lifecycle.py

Lines changed: 73 additions & 0 deletions
```python
import torch
from compressed_tensors.compressors import BaseCompressor
from compressed_tensors.config.format import _get_quant_compression_format
from compressed_tensors.quantization import (
    QuantizationScheme,
    QuantizationStrategy,
    initialize_module_for_quantization,
)

from llmcompressor.modifiers.quantization.calibration import (
    apply_calibration_status,
    freeze_module_quantization,
    initialize_observer,
    update_weight_global_scale,
    update_weight_zp_scale,
)

__all__ = [
    "initialize_quantized_linear",
    "calibrate_weights",
    "compress_module",
]


def initialize_quantized_linear(
    weight: torch.Tensor, scheme: QuantizationScheme, device: str | torch.device
) -> torch.nn.Module:
    out_features, in_features = weight.shape
    module = torch.nn.Linear(
        in_features, out_features, bias=False, device=device, dtype=weight.dtype
    )
    module.weight.data.copy_(weight)
    initialize_module_for_quantization(module, scheme, force_zero_point=False)

    return module


def calibrate_weights(module: torch.nn.Linear):
    scheme: QuantizationScheme = getattr(module, "quantization_scheme")
    initialize_observer(module, "weight")

    apply_calibration_status(module)
    if scheme.weights.strategy == QuantizationStrategy.TENSOR_GROUP:
        update_weight_global_scale(module)
    update_weight_zp_scale(module)

    freeze_module_quantization(module)


def compress_module(module: torch.nn.Linear):
    scheme: QuantizationScheme = getattr(module, "quantization_scheme")

    format = _get_quant_compression_format(scheme.input_activations, scheme.weights)
    scheme.format = format.value

    compressor = BaseCompressor.load_from_registry(format.value)
    data = compressor.compress_weight(
        module.weight,
        quantization_args=scheme.weights,
        scale=getattr(module, "weight_scale"),
        zero_point=getattr(module, "weight_zero_point", None),
        global_scale=getattr(module, "weight_global_scale", None),
    )

    # `compress_weight` is a messy api
    delattr(module, "weight")
    for key, value in data.items():
        if hasattr(module, key):
            getattr(module, key).data = value
        else:
            module.register_parameter(
                key, torch.nn.Parameter(value, requires_grad=False)
            )
```
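As a rough sketch (not a supported entrypoint), the per-weight lifecycle that `_process_file` drives can be exercised on a single tensor; the scheme resolution mirrors `validate_scheme` above, and the exact keys left in the state dict depend on the chosen scheme and compressor:

```python
import torch

from llmcompressor.entrypoints.model_free.helpers import validate_scheme
from llmcompressor.entrypoints.model_free.lifecycle import (
    calibrate_weights,
    compress_module,
    initialize_quantized_linear,
)

# resolve a preset and force static weight observers, as model_free_ptq does
_, scheme = validate_scheme("FP8_DYNAMIC")
weight = torch.randn(256, 512)

module = initialize_quantized_linear(weight, scheme, device="cpu")  # attach qparams
calibrate_weights(module)  # observe the weight, set scale/zero_point, freeze
compress_module(module)    # swap the fp weight for its compressed representation

print(sorted(module.state_dict().keys()))  # compressed weight + quantization parameters
```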

src/llmcompressor/entrypoints/model_free/model_utils.py

Lines changed: 48 additions & 0 deletions
```python
import os

from huggingface_hub import list_repo_files
from transformers.utils.hub import cached_file

__all__ = ["get_checkpoint_files", "is_weights_file"]

weights_files = [
    ".bin",
    ".safetensors",
    ".pth",
    ".msgpack",
    ".pt",
]


def is_weights_file(file_name: str) -> bool:
    return any(file_name.endswith(suffix) for suffix in weights_files)


def get_checkpoint_files(model_stub: str | os.PathLike) -> list[tuple[str, str]]:
    # In the future, this function can accept and pass download kwargs to cached_file

    if os.path.exists(model_stub):
        file_paths = walk_file_paths(model_stub, ignore=".cache")
    else:
        file_paths = list_repo_files(model_stub)

    return [(file_path, cached_file(model_stub, file_path)) for file_path in file_paths]


def walk_file_paths(root_dir: str, ignore: str | None = None) -> list[str]:
    """
    Return all file paths relative to the root directory
    """

    all_files = []
    for dirpath, _, filenames in os.walk(root_dir):
        for filename in filenames:
            rel_path = os.path.relpath(os.path.join(dirpath, filename), root_dir)
            if not (ignore and rel_path.startswith(ignore)):
                all_files.append(rel_path)
    return all_files


# distinguish relative file paths from absolute/resolved file paths
# relative file paths are used to find the save path
# resolved file paths are what are used to load data
```
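For illustration, `get_checkpoint_files` yields (relative, resolved) path pairs, whether the stub is a Hugging Face Hub model id or a local directory:

```python
from llmcompressor.entrypoints.model_free.model_utils import get_checkpoint_files

# relative paths determine where outputs are written under save_directory;
# resolved paths (local cache locations) are what actually get loaded
for file_path, resolved_path in get_checkpoint_files("meta-llama/Llama-3.2-1B-Instruct"):
    print(file_path, "->", resolved_path)
```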
