Commit f1b8e5a
[Observers] Refactor for better FP4 support, static and memoryless observers (#1903)
* FP4
* Fix a bug, discovered in #1830 (comment), where `dynamic="local"`
nvfp4 calculations would increment the observer twice as often as
intended
* Enable MSE observer to be used with FP4
```pseudocode
mse_quant_error(x) := mean((x - fake_quant(x))**2)
global_scale <- argmin over (min_vals, max_vals, global_scale) of mse_quant_error(x)
scale, zp    <- argmin over (min_vals, max_vals) of mse_quant_error(x, global_scale)
```
```
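For concreteness, a minimal sketch of that search as a simplified,
symmetric shrink-based grid search. The `fake_quant` stand-in, the
shrink schedule, and the integer rounding grid are illustrative
assumptions, not the library's actual FP4 kernels:

```python
import torch

FP4_MAX = 6.0  # largest magnitude representable in FP4 (e2m1)

def fake_quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in for the real fake-quantization kernel:
    # scale, round, clamp to the representable range, and rescale.
    q = torch.clamp(torch.round(x / scale), -FP4_MAX, FP4_MAX)
    return q * scale

def mse_quant_error(x: torch.Tensor, scale: torch.Tensor) -> float:
    # mse_quant_error := mean((x - fake_quant(x))**2)
    return ((x - fake_quant(x, scale)) ** 2).mean().item()

def mse_search(x: torch.Tensor, steps: int = 100) -> torch.Tensor:
    # Shrink the absmax-derived scale and keep whichever candidate
    # minimizes the MSE quantization error.
    base_scale = x.abs().amax() / FP4_MAX
    best_scale, best_err = base_scale, float("inf")
    for i in range(steps):
        candidate = base_scale * (1.0 - i / (2 * steps))
        err = mse_quant_error(x, candidate)
        if err < best_err:
            best_scale, best_err = candidate, err
    return best_scale
```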
* Simplification
* Make supporting attention calibration easier by separating out
weight/activation/attention reshaping
* Improve readability of observer code by removing many levels of
function indirection
* Drop support for calibration with non-divisible group sizes. This is
not really a loss, since [forward
passes](https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/lifecycle/forward.py#L279)
also make this assumption
* New observers
* `memoryless_minmax` computes min and max values on the fly in a
dynamic-quantization style. This observer is useful for PTQ weight
quantization
* `static_minmax` computes absolute min and max values across all
observations. This observer is useful for PTQ activation quantization
* `memoryless_mse` computes best qparams w.r.t. MSE loss for each
observation. This observer is useful for PTQ weight quantization
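To make the memoryless/static distinction concrete, a minimal sketch
under assumed names (the real observers live behind the library's
`Observer` base class; the `MemorylessMinMax`/`StaticMinMax` classes and
`get_min_max` method here are illustrative, not the actual API):

```python
import torch

class MemorylessMinMax:
    # Recomputes min/max from scratch on every observation
    # (dynamic-quantization style); nothing persists between calls.
    def get_min_max(self, observed: torch.Tensor):
        # observed: (num_observations, *qparams_shape, group_size)
        return observed.amin(dim=(0, -1)), observed.amax(dim=(0, -1))

class StaticMinMax:
    # Tracks the absolute min/max seen across all observations,
    # suitable for PTQ activation calibration.
    def __init__(self):
        self.min_vals = None
        self.max_vals = None

    def get_min_max(self, observed: torch.Tensor):
        min_vals = observed.amin(dim=(0, -1))
        max_vals = observed.amax(dim=(0, -1))
        if self.min_vals is not None:
            min_vals = torch.minimum(min_vals, self.min_vals)
            max_vals = torch.maximum(max_vals, self.max_vals)
        self.min_vals, self.max_vals = min_vals, max_vals
        return self.min_vals, self.max_vals
```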
* Memory improvements
* All observers no longer store copies of scales and zero points,
reducing the amount of required memory
* Newly introduced "memoryless" observers do not store any quantization
parameters, which greatly reduces the memory requirements for PTQ weight
quantization of very large models
| Diagrams |
| - |
| Before |
| <img width="886" height="595" alt="before" src="https://github.com/user-attachments/assets/660d94c2-3ac8-4e05-9e9b-53d21145abac" /> |
| After |
| <img width="1527" height="595" alt="after" src="https://github.com/user-attachments/assets/51a0107e-3fbd-413c-a7a6-03ddc3612169" /> |
* Standardize reshaping using `flatten_for_calibration`
* This function reshapes all observed values to `(num_observations,
*qparams_shape, group_size)`
* This function removes the complexity associated with passing "reduce
dims" and trying to handle weights, activations, and attention states
all in the same function
* In the future, this function could be applied to the quantization
forward pass, although there's probably no need to do so outside of
standardization
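As an assumed example of this shape contract for group-quantized
weights (the helper name and weight-only scope are illustrative; the
real `flatten_for_calibration` also covers activations and attention
states):

```python
import torch

def flatten_weight_for_calibration(weight: torch.Tensor, group_size: int) -> torch.Tensor:
    # weight: (out_features, in_features). Groups must divide evenly,
    # matching the dropped support for non-divisible group sizes.
    out_features, in_features = weight.shape
    assert in_features % group_size == 0
    num_groups = in_features // group_size
    # Weights are a single observation, hence the leading dim of 1;
    # qparams_shape for group quantization is (out_features, num_groups).
    return weight.reshape(1, out_features, num_groups, group_size)

w = torch.randn(128, 512)
flat = flatten_weight_for_calibration(w, group_size=64)  # (1, 128, 8, 64)
min_vals = flat.amin(dim=-1)  # one value per (observation, row, group)
max_vals = flat.amax(dim=-1)
```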
* Implement `get_global_scale` on `Observer` base
* This function decouples minmax calculations from regular qparam
calculations (avoiding the double increment bug)
* This function enables the MSE observer to be used with FP4 global
scales
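A sketch of what such a decoupled helper might compute for NVFP4,
assuming the usual scheme of FP8 (e4m3) local scales over FP4 (e2m1)
values; the constants and formula are assumptions for illustration, not
a quote of the implementation:

```python
import torch

FP4_MAX = 6.0    # e2m1 max magnitude
FP8_MAX = 448.0  # e4m3 max magnitude, the dtype of the local scales

def get_global_scale(min_vals: torch.Tensor, max_vals: torch.Tensor) -> torch.Tensor:
    # Derive a single tensor-wide scale from minmax statistics so the
    # largest per-group scale fits the FP8 dynamic range. Because this
    # is computed once from minmax stats, independently of the per-group
    # qparam search, the observer is only incremented once per pass.
    amax = torch.maximum(min_vals.abs(), max_vals.abs()).amax()
    return FP8_MAX * FP4_MAX / amax
```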
* Added additional minmax tests which check exact scale values. These
tests pass on both main and this branch, demonstrating that minmax
observer behavior remains unchanged
* Added additional MSE tests which check exact values of MSE losses.
These tests pass on both main and this branch, demonstrating that MSE
observer behavior remains unchanged
* Added FP4 MSE test
```
nvfp4-static-minmax
| Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
|--------|------:|------|-----:|--------|---|-----:|---|------|
|mmmu_val| 0|none | 0|mmmu_acc|↑ |0.6167|± | N/A|
```
```
nvfp4-minmax
| Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
|--------|------:|------|-----:|--------|---|-----:|---|------|
|mmmu_val| 0|none | 0|mmmu_acc|↑ |0.6011|± | N/A|
```
---------
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Dan Huang <dan.huang@neuralmagic.com>
Co-authored-by: dhuangnm <74931910+dhuangnm@users.noreply.github.com>