 
 Hacked together by / Copyright 2020 Ross Wightman
 """
+from typing import List, Union
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -180,3 +182,44 @@ def forward(self, x):
 
     def extra_repr(self):
         return f'drop_prob={round(self.drop_prob,3):0.3f}'
+
+
+def calculate_drop_path_rates(
+        drop_path_rate: float,
+        depths: Union[int, List[int]],
+        stagewise: bool = False,
+) -> Union[List[float], List[List[float]]]:
+    """Generate drop path rates for stochastic depth.
+
+    This function handles two common patterns for drop path rate scheduling:
+    1. Per-block: Linear increase from 0 to drop_path_rate across all blocks, as one flat list
+    2. Stage-wise: The same per-block linear increase, grouped into one list of rates per stage
+
+    Args:
+        drop_path_rate: Maximum drop path rate (reached at the final block).
+        depths: Either a single int for total depth (per-block mode) or
+            list of ints for depths per stage (stage-wise mode).
+        stagewise: If True, return rates grouped per stage; requires depths
+            to be a list of stage depths. If False, return a flat per-block list.
+
+    Returns:
+        For per-block mode: List of drop rates, one per block.
+        For stage-wise mode: List of lists, per-block drop rates grouped by stage.
+    """
+    if isinstance(depths, int):
+        # Single depth value - per-block pattern
+        if stagewise:
+            raise ValueError("stagewise=True requires depths to be a list of stage depths")
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths, device='cpu')]
+        return dpr
+    else:
+        # List of depths - can be either pattern
+        total_depth = sum(depths)
+        if stagewise:
+            # Per-block rates from one linear ramp, split into one list per stage
+            dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, total_depth, device='cpu').split(depths)]
+            return dpr
+        else:
+            # Per-block rates across all stages, returned as a single flat list
+            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth, device='cpu')]
+            return dpr
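
For reference, a minimal usage sketch of the new helper. The timm.layers.drop import path is assumed from the file touched above, and the printed values are approximate since torch.linspace returns float32:

from timm.layers.drop import calculate_drop_path_rates  # import path assumed

# Per-block mode: one flat linear ramp over 6 blocks
print(calculate_drop_path_rates(0.1, 6))
# ~ [0.0, 0.02, 0.04, 0.06, 0.08, 0.1]

# Stage-wise mode: the same ramp, grouped per stage for a 3-stage model with 2 blocks per stage
print(calculate_drop_path_rates(0.1, [2, 2, 2], stagewise=True))
# ~ [[0.0, 0.02], [0.04, 0.06], [0.08, 0.1]]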