Skip to content

Commit 3f58765

Browse files
committed
Add calculate_drop_path_rates helper. Make force calcs on cpu device to avoid issue with device contexts/meta device use. Fix a few small inconsistencies that were noticed.
1 parent a510490 commit 3f58765

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+153
-104
lines changed

timm/layers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from .create_conv2d import create_conv2d
4343
from .create_norm import get_norm_layer, create_norm_layer
4444
from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
45-
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
45+
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path, calculate_drop_path_rates
4646
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
4747
from .evo_norm import (
4848
EvoNorm2dB0,

timm/layers/drop.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
1515
Hacked together by / Copyright 2020 Ross Wightman
1616
"""
17+
from typing import List, Union
18+
1719
import torch
1820
import torch.nn as nn
1921
import torch.nn.functional as F
@@ -180,3 +182,44 @@ def forward(self, x):
180182

181183
def extra_repr(self):
182184
return f'drop_prob={round(self.drop_prob,3):0.3f}'
185+
186+
187+
def calculate_drop_path_rates(
188+
drop_path_rate: float,
189+
depths: Union[int, List[int]],
190+
stagewise: bool = False,
191+
) -> Union[List[float], List[List[float]]]:
192+
"""Generate drop path rates for stochastic depth.
193+
194+
This function handles two common patterns for drop path rate scheduling:
195+
1. Per-block: Linear increase from 0 to drop_path_rate across all blocks
196+
2. Stage-wise: Linear increase across stages, with same rate within each stage
197+
198+
Args:
199+
drop_path_rate: Maximum drop path rate (at the end).
200+
depths: Either a single int for total depth (per-block mode) or
201+
list of ints for depths per stage (stage-wise mode).
202+
stagewise: If True, use stage-wise pattern. If False, use per-block pattern.
203+
When depths is a list, stagewise defaults to True.
204+
205+
Returns:
206+
For per-block mode: List of drop rates, one per block.
207+
For stage-wise mode: List of lists, drop rates per stage.
208+
"""
209+
if isinstance(depths, int):
210+
# Single depth value - per-block pattern
211+
if stagewise:
212+
raise ValueError("stagewise=True requires depths to be a list of stage depths")
213+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths, device='cpu')]
214+
return dpr
215+
else:
216+
# List of depths - can be either pattern
217+
total_depth = sum(depths)
218+
if stagewise:
219+
# Stage-wise pattern: same drop rate within each stage
220+
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, total_depth, device='cpu').split(depths)]
221+
return dpr
222+
else:
223+
# Per-block pattern across all stages
224+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth, device='cpu')]
225+
return dpr

timm/models/beit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
import torch.nn.functional as F
4747

4848
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
49-
from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn
49+
from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, calculate_drop_path_rates, trunc_normal_, use_fused_attn
5050
from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid
5151

5252
from ._builder import build_model_with_cfg
@@ -448,7 +448,7 @@ def __init__(
448448
else:
449449
self.rel_pos_bias = None
450450

451-
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
451+
dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule
452452
self.blocks = nn.ModuleList([
453453
Block(
454454
dim=embed_dim,

timm/models/byobnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
4040
from timm.layers import (
4141
ClassifierHead, NormMlpClassifierHead, ConvNormAct, BatchNormAct2d, EvoNorm2dS0a,
42-
AttentionPool2d, RotAttentionPool2d, DropPath, AvgPool2dSame,
42+
AttentionPool2d, RotAttentionPool2d, DropPath, calculate_drop_path_rates, AvgPool2dSame,
4343
create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple,
4444
)
4545
from ._builder import build_model_with_cfg
@@ -1212,7 +1212,7 @@ def create_byob_stages(
12121212
feature_info = []
12131213
block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks]
12141214
depths = [sum([bc.d for bc in stage_bcs]) for stage_bcs in block_cfgs]
1215-
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
1215+
dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
12161216
dilation = 1
12171217
net_stride = stem_feat['reduction']
12181218
prev_chs = stem_feat['num_chs']

timm/models/coat.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,9 +417,7 @@ def __init__(
417417
self.crpe3 = ConvRelPosEnc(head_chs=embed_dims[2] // num_heads, num_heads=num_heads, window=crpe_window)
418418
self.crpe4 = ConvRelPosEnc(head_chs=embed_dims[3] // num_heads, num_heads=num_heads, window=crpe_window)
419419

420-
# Disable stochastic depth.
421420
dpr = drop_path_rate
422-
assert dpr == 0.0
423421
skwargs = dict(
424422
num_heads=num_heads,
425423
qkv_bias=qkv_bias,

timm/models/convit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import torch.nn as nn
2828

2929
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
30-
from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp, LayerNorm, HybridEmbed
30+
from timm.layers import DropPath, calculate_drop_path_rates, trunc_normal_, PatchEmbed, Mlp, LayerNorm, HybridEmbed
3131
from ._builder import build_model_with_cfg
3232
from ._features_fx import register_notrace_module
3333
from ._registry import register_model, generate_default_cfgs
@@ -292,7 +292,7 @@ def __init__(
292292
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
293293
trunc_normal_(self.pos_embed, std=.02)
294294

295-
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
295+
dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule
296296
self.blocks = nn.ModuleList([
297297
Block(
298298
dim=embed_dim,

timm/models/convnext.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
import torch.nn as nn
4545

4646
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
47-
from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
47+
from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, calculate_drop_path_rates, Mlp, GlobalResponseNormMlp, \
4848
LayerNorm2d, LayerNorm, RmsNorm2d, RmsNorm, create_conv2d, get_act_layer, get_norm_layer, make_divisible, to_ntuple
4949
from timm.layers import SimpleNorm2d, SimpleNorm
5050
from timm.layers import NormMlpClassifierHead, ClassifierHead
@@ -377,7 +377,7 @@ def __init__(
377377
stem_stride = 4
378378

379379
self.stages = nn.Sequential()
380-
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
380+
dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
381381
stages = []
382382
prev_chs = dims[0]
383383
curr_stride = stem_stride

timm/models/crossvit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import torch.nn as nn
2828

2929
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
30-
from timm.layers import DropPath, to_2tuple, trunc_normal_, _assert
30+
from timm.layers import DropPath, calculate_drop_path_rates, to_2tuple, trunc_normal_, _assert
3131
from ._builder import build_model_with_cfg
3232
from ._features_fx import register_notrace_function
3333
from ._registry import register_model, generate_default_cfgs
@@ -346,7 +346,7 @@ def __init__(
346346
self.pos_drop = nn.Dropout(p=pos_drop_rate)
347347

348348
total_depth = sum([sum(x[-2:]) for x in depth])
349-
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)] # stochastic depth decay rule
349+
dpr = calculate_drop_path_rates(drop_path_rate, total_depth) # stochastic depth decay rule
350350
dpr_ptr = 0
351351
self.blocks = nn.ModuleList()
352352
for idx, block_cfg in enumerate(depth):

timm/models/cspnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import torch.nn as nn
2121

2222
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
23-
from timm.layers import ClassifierHead, ConvNormAct, DropPath, get_attn, create_act_layer, make_divisible
23+
from timm.layers import ClassifierHead, ConvNormAct, DropPath, calculate_drop_path_rates, get_attn, create_act_layer, make_divisible
2424
from ._builder import build_model_with_cfg
2525
from ._manipulate import named_apply, MATCH_PREV_GROUP
2626
from ._registry import register_model, generate_default_cfgs
@@ -569,7 +569,7 @@ def create_csp_stages(
569569
cfg_dict = asdict(cfg.stages)
570570
num_stages = len(cfg.stages.depth)
571571
cfg_dict['block_dpr'] = [None] * num_stages if not drop_path_rate else \
572-
[x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.stages.depth)).split(cfg.stages.depth)]
572+
calculate_drop_path_rates(drop_path_rate, cfg.stages.depth, stagewise=True)
573573
stage_args = [dict(zip(cfg_dict.keys(), values)) for values in zip(*cfg_dict.values())]
574574
block_kwargs = dict(
575575
act_layer=cfg.act_layer,

timm/models/davit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from torch import Tensor
2121

2222
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
23-
from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn
23+
from timm.layers import DropPath, calculate_drop_path_rates, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn
2424
from timm.layers import NormMlpClassifierHead, ClassifierHead
2525
from ._builder import build_model_with_cfg
2626
from ._features import feature_take_indices
@@ -555,7 +555,7 @@ def __init__(
555555
self.stem = Stem(in_chans, embed_dims[0], norm_layer=norm_layer)
556556
in_chs = embed_dims[0]
557557

558-
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
558+
dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
559559
stages = []
560560
for i in range(num_stages):
561561
out_chs = embed_dims[i]

0 commit comments

Comments
 (0)