
Commit 8d366bd

[Transform] SpinQuant fix OOM (#1976)
SUMMARY: "When using SpinQuantModifier, the torch.no_grad decorator must be added to some fuse operations; otherwise PyTorch captures the gradient graph by default, leading to a gradual increase in memory usage. I ran into a CUDA OOM error when rotating an MoE model, and the OOM was resolved after this fix."

TEST PLAN: "Performed code quality evaluation locally."

Signed-off-by: LeiZhang <isleizhang@outlook.com>
1 parent 296d48f commit 8d366bd

File tree

1 file changed: +1 addition, -0 deletions

  • src/llmcompressor/modifiers/transform/spinquant/base.py

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 1 addition & 0 deletions
@@ -144,6 +144,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         return True
 
+    @torch.no_grad()
     def on_start(self, state: State, event: Event, **kwargs):
         self.started_ = True
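For context, here is a minimal, hypothetical sketch of the pattern this commit guards (not the actual SpinQuant fuse code; fuse_rotation and the layer sizes are invented for illustration). Fusing a rotation matrix into a weight touches parameters that require grad, and without torch.no_grad() autograd records those operations while iterating over layers, steadily growing memory.

import torch
import torch.nn as nn

# Hypothetical fuse step: fold a rotation matrix into a Linear weight in place.
# Under torch.no_grad(), the matmul and copy_ are not recorded by autograd,
# so no gradient graph accumulates while looping over a model's layers.
@torch.no_grad()
def fuse_rotation(linear: nn.Linear, rotation: torch.Tensor) -> None:
    linear.weight.copy_(linear.weight @ rotation)

# Usage sketch: apply a random orthogonal rotation to one layer.
layer = nn.Linear(64, 64)
q, _ = torch.linalg.qr(torch.randn(64, 64))
fuse_rotation(layer, q)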

Comments (0)