Skip to content

Commit fb83b7f

Browse files
dsikkaHDCharles
authored andcommitted
update
1 parent 7772e25 commit fb83b7f

File tree

3 files changed

+8
-2
lines changed

3 files changed

+8
-2
lines changed

src/llmcompressor/modeling/moe_context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class MoECalibrationModule(ABC, torch.nn.Module):
4545

4646
is_permanent: bool = False
4747

48-
def restore(self) -> torch.nn.Module:
48+
def restore(self, original: torch.nn.Module) -> torch.nn.Module:
4949
"""
5050
Restore the original module structure.
5151
@@ -163,5 +163,5 @@ def moe_calibration_context(
163163
# Step 2: Restore non-permanent modules
164164
for name, (original, replacement) in replaced.items():
165165
if not replacement.is_permanent:
166-
restored = replacement.restore()
166+
restored = replacement.restore(original)
167167
model.set_submodule(name, restored)

src/llmcompressor/modeling/qwen3_moe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ def forward(self, hidden_states: torch.Tensor):
9898
)
9999
return final_hidden_states, router_logits
100100

101+
def restore(self, original: torch.nn.Module) -> torch.nn.Module:
102+
return original
103+
101104

102105
# Legacy function for backward compatibility
103106
def replace(

src/llmcompressor/modeling/qwen3_next_moe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
115115
)
116116
return final_hidden_states, router_logits
117117

118+
def restore(self, original: torch.nn.Module) -> torch.nn.Module:
119+
return original
120+
118121

119122
def replace(
120123
config,

0 commit comments

Comments
 (0)