Skip to content

Commit 983644a

Browse files
committed
Padded PTQ support for TPU FP4
1 parent 3e3f039 commit 983644a

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

tpu_inference/models/jax/utils/quantization/quantization_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
from flax.typing import PRNGKey
1414
from jax.sharding import Mesh, NamedSharding
1515
from jax.sharding import PartitionSpec as P
16-
from qwix._src.core.qarray import QArray
17-
from qwix._src.providers import ptq
16+
from qwix.contrib.padded_qarray import PaddedQArray as QArray
17+
from qwix.contrib.padded_qarray import PaddedPtqProvider
18+
from qwix.contrib import padded_qarray as ptq
1819

1920
if TYPE_CHECKING:
2021
from vllm.config import VllmConfig
@@ -221,7 +222,7 @@ def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
221222
query_start_loc=query_start_loc,
222223
request_distribution=request_distribution),
223224
}
224-
model = qwix.quantize_model(model, qwix.PtqProvider(qwix_rules),
225+
model = qwix.quantize_model(model, PaddedPtqProvider(qwix_rules),
225226
**model_input)
226227
return model
227228

0 commit comments

Comments
 (0)