Skip to content

Commit b21f59a

Browse files
committed
Improve docstrings and type hints in scheduling_amused.py
- Add complete type hints for helper functions (gumbel_noise, mask_by_random_topk) - Enhance AmusedSchedulerOutput with proper Optional typing - Add comprehensive docstrings for AmusedScheduler class - Improve __init__, set_timesteps, step, and add_noise methods - Fix type hints to match documentation conventions - All changes follow project standards from issue #9567
1 parent 67f931b commit b21f59a

File tree

1 file changed

+114
-8
lines changed

1 file changed

+114
-8
lines changed

src/diffusers/schedulers/scheduling_amused.py

Lines changed: 114 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,48 @@
99
from .scheduling_utils import SchedulerMixin
1010

1111

12-
def gumbel_noise(t, generator=None):
12+
def gumbel_noise(t: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
13+
"""
14+
Generate Gumbel noise for sampling.
15+
16+
Args:
17+
t (`torch.Tensor`):
18+
Input tensor to match the shape and dtype of the output noise.
19+
generator (`torch.Generator`, *optional*):
20+
A random number generator for reproducible sampling.
21+
22+
Returns:
23+
`torch.Tensor`:
24+
Gumbel-distributed noise with the same shape as the input tensor.
25+
"""
1326
device = generator.device if generator is not None else t.device
1427
noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device)
1528
return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20))
1629

1730

18-
def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
31+
def mask_by_random_topk(
32+
mask_len: torch.Tensor,
33+
probs: torch.Tensor,
34+
temperature: float = 1.0,
35+
generator: Optional[torch.Generator] = None,
36+
) -> torch.Tensor:
37+
"""
38+
Mask tokens by selecting the top-k lowest confidence scores with temperature-based randomness.
39+
40+
Args:
41+
mask_len (`torch.Tensor`):
42+
Number of tokens to mask per sample in the batch.
43+
probs (`torch.Tensor`):
44+
Probability scores for each token.
45+
temperature (`float`, *optional*, defaults to 1.0):
46+
Temperature parameter for controlling randomness in the masking process.
47+
generator (`torch.Generator`, *optional*):
48+
A random number generator for reproducible sampling.
49+
50+
Returns:
51+
`torch.Tensor`:
52+
Boolean mask indicating which tokens should be masked.
53+
"""
1954
confidence = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator)
2055
sorted_confidence = torch.sort(confidence, dim=-1).values
2156
cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
@@ -38,20 +73,31 @@ class AmusedSchedulerOutput(BaseOutput):
3873
"""
3974

4075
prev_sample: torch.Tensor
41-
pred_original_sample: torch.Tensor = None
76+
pred_original_sample: Optional[torch.Tensor] = None
4277

4378

4479
class AmusedScheduler(SchedulerMixin, ConfigMixin):
4580
order = 1
4681

47-
temperatures: torch.Tensor
82+
temperatures: Optional[torch.Tensor]
83+
timesteps: Optional[torch.Tensor]
4884

4985
@register_to_config
5086
def __init__(
5187
self,
5288
mask_token_id: int,
5389
masking_schedule: str = "cosine",
5490
):
91+
"""
92+
Create a new AmusedScheduler instance.
93+
94+
Args:
95+
mask_token_id (`int`):
96+
The token ID used to represent masked tokens in the sequence.
97+
masking_schedule (`str`, *optional*, defaults to `"cosine"`):
98+
The schedule type for determining the mask ratio at each timestep. Can be either `"cosine"` or
99+
`"linear"`.
100+
"""
55101
self.temperatures = None
56102
self.timesteps = None
57103

@@ -60,7 +106,21 @@ def set_timesteps(
60106
num_inference_steps: int,
61107
temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
62108
device: Union[str, torch.device] = None,
63-
):
109+
) -> None:
110+
"""
111+
Set the discrete timesteps used for the diffusion chain.
112+
113+
Args:
114+
num_inference_steps (`int`):
115+
The number of diffusion steps used when generating samples with a pre-trained model.
116+
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to `(2, 0)`):
117+
Temperature parameter(s) for controlling the randomness of sampling. If a tuple or list is provided,
118+
temperatures will be linearly interpolated between the first and second values across all timesteps. If
119+
a single value is provided, temperatures will be linearly interpolated from that value to 0.01.
120+
device (`Union[str, torch.device]`, *optional*):
121+
The device to which the timesteps and temperatures should be moved. If not specified, uses the default
122+
device.
123+
"""
64124
self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)
65125

66126
if isinstance(temperature, (tuple, list)):
@@ -71,12 +131,38 @@ def set_timesteps(
71131
def step(
72132
self,
73133
model_output: torch.Tensor,
74-
timestep: torch.long,
134+
timestep: int,
75135
sample: torch.LongTensor,
76-
starting_mask_ratio: int = 1,
136+
starting_mask_ratio: float = 1.0,
77137
generator: Optional[torch.Generator] = None,
78138
return_dict: bool = True,
79139
) -> Union[AmusedSchedulerOutput, Tuple]:
140+
"""
141+
Predict the sample at the previous timestep by masking tokens based on confidence scores.
142+
143+
Args:
144+
model_output (`torch.Tensor`):
145+
The direct output from the learned diffusion model. Typically of shape `(batch_size, num_tokens,
146+
codebook_size)` or `(batch_size, codebook_size, height, width)` for 2D inputs.
147+
timestep (`int`):
148+
The current discrete timestep in the diffusion chain.
149+
sample (`torch.LongTensor`):
150+
A current instance of a sample created by the diffusion process. Contains token IDs, with masked
151+
positions indicated by `mask_token_id`.
152+
starting_mask_ratio (`float`, *optional*, defaults to 1.0):
153+
A multiplier applied to the mask ratio schedule. Values less than 1.0 will result in fewer tokens being
154+
masked at each step.
155+
generator (`torch.Generator`, *optional*):
156+
A random number generator for reproducible sampling.
157+
return_dict (`bool`, *optional*, defaults to `True`):
158+
Whether to return an [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or a plain tuple.
159+
160+
Returns:
161+
[`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or `tuple`:
162+
If `return_dict` is `True`, returns [`~schedulers.scheduling_amused.AmusedSchedulerOutput`], otherwise
163+
returns a tuple where the first element is the sample tensor and the second element is the predicted
164+
original sample tensor.
165+
"""
80166
two_dim_input = sample.ndim == 3 and model_output.ndim == 4
81167

82168
if two_dim_input:
@@ -137,7 +223,27 @@ def step(
137223

138224
return AmusedSchedulerOutput(prev_sample, pred_original_sample)
139225

140-
def add_noise(self, sample, timesteps, generator=None):
226+
def add_noise(
227+
self,
228+
sample: torch.LongTensor,
229+
timesteps: int,
230+
generator: Optional[torch.Generator] = None,
231+
) -> torch.LongTensor:
232+
"""
233+
Add noise to a sample by randomly masking tokens according to the masking schedule.
234+
235+
Args:
236+
sample (`torch.LongTensor`):
237+
The input sample containing token IDs to be partially masked.
238+
timesteps (`int`):
239+
The timestep that determines how much masking to apply. Higher timesteps result in more masking.
240+
generator (`torch.Generator`, *optional*):
241+
A random number generator for reproducible masking.
242+
243+
Returns:
244+
`torch.LongTensor`:
245+
The sample with some tokens replaced by `mask_token_id` according to the masking schedule.
246+
"""
141247
step_idx = (self.timesteps == timesteps).nonzero()
142248
ratio = (step_idx + 1) / len(self.timesteps)
143249

0 commit comments

Comments
 (0)