Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
de2b6bd
feat: add optional gradient checkpointing to unet
Sep 3, 2025
66edcb5
fix: small ruff issue
Sep 3, 2025
e66e357
Update monai/networks/nets/unet.py
ferreirafabio80 Sep 4, 2025
feefcaa
docs: update docstrings
Sep 4, 2025
e112457
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 4, 2025
f673ca1
fix: avoid BatchNorm subblocks
Sep 4, 2025
69540ff
fix: revert batch norm changes
Sep 4, 2025
42ec757
refactor: creates a subclass of UNet and overrides the get connection…
Oct 1, 2025
a2e8474
chore: remove use checkpointing from doc string
Oct 1, 2025
4c4782e
fix: linting issues
Oct 2, 2025
515c659
feat: add activation checkpointing to down and up paths to be more ef…
Oct 8, 2025
da5a3a4
refactor: move activation checkpointing wrapper to blocks
Nov 4, 2025
43dec88
chore: add docstrings to checkpointed unet
Nov 4, 2025
84c0f48
test: add checkpoint unet test
Nov 7, 2025
5805515
fix: change test name
Nov 7, 2025
1aa8e3c
fix: simplify test and make sure that checkpoint unet runs well in tr…
Nov 7, 2025
447d9f2
fix: set seed
Nov 7, 2025
b20a19e
fix: fix testing bugs
Nov 7, 2025
41f000f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2025
a068c0e
chore: add test docstrings
Nov 10, 2025
26668cd
DCO Remediation Commit for Fabio Ferreira <f.ferreira@qureight.com>
Nov 10, 2025
814fa80
fix: remove test script save
Nov 13, 2025
c45ee48
fix: tighten tolerance for numerical equivalence
Nov 13, 2025
4349d3f
chore: update doc strings
Nov 14, 2025
885993b
Merge branch 'dev' into feat/add_activation_checkpointing_to_unet
KumoLiu Nov 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions monai/networks/blocks/activation_checkpointing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import cast

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ActivationCheckpointWrapper(nn.Module):
"""Wrapper applying activation checkpointing to a module during training.
Args:
module: The module to wrap with activation checkpointing.
"""

def __init__(self, module: nn.Module) -> None:
super().__init__()
self.module = module

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass with optional activation checkpointing.
Args:
x: Input tensor.
Returns:
Output tensor from the wrapped module.
"""
return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
9 changes: 9 additions & 0 deletions monai/networks/nets/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch
import torch.nn as nn

from monai.networks.blocks.activation_checkpointing import ActivationCheckpointWrapper
from monai.networks.blocks.convolutions import Convolution, ResidualUnit
from monai.networks.layers.factories import Act, Norm
from monai.networks.layers.simplelayers import SkipConnection
Expand Down Expand Up @@ -298,4 +299,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class CheckpointUNet(UNet):
def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
subblock = ActivationCheckpointWrapper(subblock)
down_path = ActivationCheckpointWrapper(down_path)
up_path = ActivationCheckpointWrapper(up_path)
return super()._get_connection_block(down_path, up_path, subblock)


Unet = UNet
Loading