Skip to content

Commit 017fec0

Browse files
committed
Add handling for Qwen3VLMoe on older transformers versions
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
1 parent 560bb9c commit 017fec0

File tree

2 files changed

+26
-10
lines changed

2 files changed

+26
-10
lines changed

src/llmcompressor/modeling/qwen3_vl_moe.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
15
import torch
2-
from transformers import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
3-
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
4-
Qwen3VLMoeTextSparseMoeBlock as OriginalQwen3VLMoeTextSparseMoeBlock,
5-
)
66

77
from llmcompressor.modeling.moe_context import MoECalibrationModule
88
from llmcompressor.utils.dev import skip_weights_initialize
99

10+
if TYPE_CHECKING:
11+
from transformers import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
12+
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
13+
Qwen3VLMoeTextSparseMoeBlock,
14+
)
15+
1016

1117
@MoECalibrationModule.register("Qwen3VLMoeTextSparseMoeBlock")
1218
class CalibrateQwen3VLMoeTextSparseMoeBlock(MoECalibrationModule):
@@ -19,7 +25,7 @@ class CalibrateQwen3VLMoeTextSparseMoeBlock(MoECalibrationModule):
1925

2026
def __init__(
2127
self,
22-
original: OriginalQwen3VLMoeTextSparseMoeBlock,
28+
original: Qwen3VLMoeTextSparseMoeBlock,
2329
config: Qwen3VLMoeConfig,
2430
calibrate_all_experts: bool,
2531
):
@@ -116,7 +122,7 @@ def __init__(self, config, original):
116122

117123
def replace(
118124
config: Qwen3VLMoeConfig,
119-
original: OriginalQwen3VLMoeTextSparseMoeBlock,
125+
original: Qwen3VLMoeTextSparseMoeBlock,
120126
calibrate_all_experts: bool,
121127
):
122128
return CalibrateQwen3VLMoeTextSparseMoeBlock(

tests/llmcompressor/modeling/test_calib_qwen3_vl_moe.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
1+
import pytest
12
import torch
2-
from transformers import Qwen3VLMoeConfig
3-
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
4-
Qwen3VLMoeTextSparseMoeBlock,
5-
)
63

74
from llmcompressor.modeling.qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock
85
from llmcompressor.utils.helpers import calibration_forward_context
96
from tests.testing_utils import requires_gpu
107

8+
try:
9+
from transformers import Qwen3VLMoeConfig
10+
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
11+
Qwen3VLMoeTextSparseMoeBlock,
12+
)
13+
except ImportError:
14+
Qwen3VLMoeConfig = None
15+
Qwen3VLMoeTextSparseMoeBlock = None
16+
1117

18+
@pytest.mark.skipif(
19+
Qwen3VLMoeConfig is None,
20+
reason="Qwen3VLMoe not available in this version of transformers",
21+
)
1222
@requires_gpu
1323
def test_calib_qwen3_vl_moe_module():
1424
config = Qwen3VLMoeConfig()

0 commit comments

Comments
 (0)