
Commit a88a404

[BUG] fixed memory leak in BaseModel by detaching some tensors (#1924)
#### Reference Issues/PRs

#1369, #1461

#### What does this implement/fix? Explain your changes.

1. Detached the tensors in the log dictionary before appending them to the `training/validation/testing_step_outputs` lists. This fixes a memory leak caused by retaining the computation graph of every batch for the entire epoch.
2. Detached the loss tensor within the `step()` method before logging it.
3. Moved prediction results to the CPU to prevent VRAM growth.

#### Did you add any tests for the change?

I ran my training code for 5 epochs using a memory profiler. Here are two comparison plots:

before

![before](https://github.com/user-attachments/assets/45a6696a-efe6-4f06-897e-d80daee79977)

after

![after](https://github.com/user-attachments/assets/c3ba5187-97ee-4b6b-b4cb-2d641b2d0d88)
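For context, the leak follows the standard PyTorch pattern: any tensor that still carries a `grad_fn` keeps the computation graph of its batch alive for as long as it is referenced, so accumulating un-detached losses over an epoch grows memory linearly with the number of batches. A minimal sketch of the pattern (illustrative code, not taken from the patch):

```python
import torch

layer = torch.nn.Linear(64, 1)
step_outputs = []

for _ in range(1000):  # one "epoch" of batches
    loss = layer(torch.randn(32, 64)).sum()
    # Leaks: each appended loss pins its batch's entire autograd graph
    # until the list is cleared at epoch end.
    step_outputs.append(loss)
    # Fix: keep only the value, drop the graph.
    # step_outputs.append(loss.detach())
```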
1 parent 81b5303

File tree

1 file changed (+8, −4 lines)


pytorch_forecasting/models/base/_base_model.py

Lines changed: 8 additions & 4 deletions
```diff
@@ -54,8 +54,10 @@
     apply_to_list,
     concat_sequences,
     create_mask,
+    detach,
     get_embedding_size,
     groupby_apply,
+    move_to_device,
     to_list,
 )
 from pytorch_forecasting.utils._classproperty import classproperty
```
```diff
@@ -308,6 +310,8 @@ def on_predict_batch_end(
         else:
             raise ValueError(f"Unknown mode {self.mode} - see docs for valid arguments")

+        out = move_to_device(detach(out), "cpu")
+        x = move_to_device(detach(x), "cpu")
         self._output.append(out)
         out = dict(output=out)
         if self.return_x:
```
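`detach` and `move_to_device` here are utilities from `pytorch_forecasting.utils` that walk the (possibly nested) output structure. Their exact implementations are not shown in this diff; a minimal sketch of the idea, assuming outputs nest via plain dicts, lists, and tuples (`detach_nested`/`move_nested` are hypothetical names, not the library's):

```python
import torch

def detach_nested(obj):
    """Recursively detach tensors in nested dicts/lists/tuples (sketch)."""
    if isinstance(obj, torch.Tensor):
        return obj.detach()
    if isinstance(obj, dict):
        return {k: detach_nested(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(detach_nested(v) for v in obj)
    return obj

def move_nested(obj, device):
    """Recursively move tensors to a device, e.g. "cpu" (sketch)."""
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_nested(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_nested(v, device) for v in obj)
    return obj
```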
```diff
@@ -720,7 +724,7 @@ def training_step(self, batch, batch_idx):
         """
         x, y = batch
         log, out = self.step(x, y, batch_idx)
-        self.training_step_outputs.append(log)
+        self.training_step_outputs.append(detach(log))
         return log

     def on_train_epoch_end(self):
```
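The same append-a-detached-copy pattern repeats in `training_step`, `validation_step`, and `test_step` below: store a detached copy for end-of-epoch bookkeeping, but return the live `log` so Lightning can still backpropagate its loss. A condensed, hypothetical module showing that shape (not the library's actual code, and assuming the standalone `lightning` package):

```python
import torch
import lightning.pytorch as pl  # assumption: standalone `lightning` package

class SketchModule(pl.LightningModule):  # hypothetical, for illustration only
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 1)
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        log = {"loss": loss}
        # Store a detached copy: keeps the scalar, drops the graph.
        self.training_step_outputs.append({k: v.detach() for k, v in log.items()})
        return log  # the live loss is what Lightning backpropagates

    def on_train_epoch_end(self):
        losses = torch.stack([o["loss"] for o in self.training_step_outputs])
        self.training_step_outputs.clear()  # free the per-epoch buffer
        self.log("train_loss_epoch_mean", losses.mean())

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)
```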
```diff
@@ -739,7 +743,7 @@ def validation_step(self, batch, batch_idx):
         x, y = batch
         log, out = self.step(x, y, batch_idx)
         log.update(self.create_log(x, y, out, batch_idx))
-        self.validation_step_outputs.append(log)
+        self.validation_step_outputs.append(detach(log))
         return log

     def on_validation_epoch_end(self):
```
```diff
@@ -750,7 +754,7 @@ def test_step(self, batch, batch_idx):
         x, y = batch
         log, out = self.step(x, y, batch_idx)
         log.update(self.create_log(x, y, out, batch_idx))
-        self.testing_step_outputs.append(log)
+        self.testing_step_outputs.append(detach(log))
         return log

     def on_test_epoch_end(self):
```
```diff
@@ -934,7 +938,7 @@ def step(
         loss.requires_grad_(True)
         self.log(
             f"{self.current_stage}_loss",
-            loss,
+            detach(loss),
             on_step=self.training,
             on_epoch=True,
             prog_bar=True,
```
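Detaching what goes to `self.log` is essentially free: `detach()` returns a tensor sharing the same storage but stripped of autograd history, so the live `loss` remains available for backpropagation. A quick check of that property (illustrative):

```python
import torch

pred = torch.randn(8, requires_grad=True)
loss = (pred ** 2).mean()

logged = loss.detach()           # what gets logged
assert logged.grad_fn is None    # no graph attached
assert loss.grad_fn is not None  # original can still backpropagate
assert logged.item() == loss.item()
```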
