99from .scheduling_utils import SchedulerMixin
1010
1111
def gumbel_noise(t: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """
    Generate Gumbel noise for sampling.

    Uniform samples `u` in `[0, 1)` are drawn and mapped through `-log(-log(u))`. Both logarithm arguments are clamped
    to a minimum of `1e-20` so the result stays finite.

    Args:
        t (`torch.Tensor`):
            Tensor whose shape, dtype and device the returned noise matches.
        generator (`torch.Generator`, *optional*):
            A random number generator for reproducible sampling.

    Returns:
        `torch.Tensor`:
            Gumbel-distributed noise with the same shape as the input tensor, placed on `t.device`.
    """
    # The uniform samples are drawn on the generator's device (when a generator is given) and only afterwards moved
    # to the input's device, so a generator living on a different device than `t` can still be used.
    device = generator.device if generator is not None else t.device
    noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device)
    return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20))
1629
1730
18- def mask_by_random_topk (mask_len , probs , temperature = 1.0 , generator = None ):
31+ def mask_by_random_topk (
32+ mask_len : torch .Tensor ,
33+ probs : torch .Tensor ,
34+ temperature : float = 1.0 ,
35+ generator : Optional [torch .Generator ] = None ,
36+ ) -> torch .Tensor :
37+ """
38+ Mask tokens by selecting the top-k lowest confidence scores with temperature-based randomness.
39+
40+ Args:
41+ mask_len (`torch.Tensor`):
42+ Number of tokens to mask per sample in the batch.
43+ probs (`torch.Tensor`):
44+ Probability scores for each token.
45+ temperature (`float`, *optional*, defaults to 1.0):
46+ Temperature parameter for controlling randomness in the masking process.
47+ generator (`torch.Generator`, *optional*):
48+ A random number generator for reproducible sampling.
49+
50+ Returns:
51+ `torch.Tensor`:
52+ Boolean mask indicating which tokens should be masked.
53+ """
1954 confidence = torch .log (probs .clamp (1e-20 )) + temperature * gumbel_noise (probs , generator = generator )
2055 sorted_confidence = torch .sort (confidence , dim = - 1 ).values
2156 cut_off = torch .gather (sorted_confidence , 1 , mask_len .long ())
@@ -38,20 +73,31 @@ class AmusedSchedulerOutput(BaseOutput):
3873 """
3974
4075 prev_sample : torch .Tensor
41- pred_original_sample : torch .Tensor = None
76+ pred_original_sample : Optional [ torch .Tensor ] = None
4277
4378
4479class AmusedScheduler (SchedulerMixin , ConfigMixin ):
4580 order = 1
4681
47- temperatures : torch .Tensor
82+ temperatures : Optional [torch .Tensor ]
83+ timesteps : Optional [torch .Tensor ]
4884
4985 @register_to_config
5086 def __init__ (
5187 self ,
5288 mask_token_id : int ,
5389 masking_schedule : str = "cosine" ,
5490 ):
91+ """
92+ Create a new AmusedScheduler instance.
93+
94+ Args:
95+ mask_token_id (`int`):
96+ The token ID used to represent masked tokens in the sequence.
97+ masking_schedule (`str`, *optional*, defaults to `"cosine"`):
98+ The schedule type for determining the mask ratio at each timestep. Can be either `"cosine"` or
99+ `"linear"`.
100+ """
55101 self .temperatures = None
56102 self .timesteps = None
57103
@@ -60,7 +106,21 @@ def set_timesteps(
60106 num_inference_steps : int ,
61107 temperature : Union [int , Tuple [int , int ], List [int ]] = (2 , 0 ),
62108 device : Union [str , torch .device ] = None ,
63- ):
109+ ) -> None :
110+ """
111+ Set the discrete timesteps used for the diffusion chain.
112+
113+ Args:
114+ num_inference_steps (`int`):
115+ The number of diffusion steps used when generating samples with a pre-trained model.
116+ temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to `(2, 0)`):
117+ Temperature parameter(s) for controlling the randomness of sampling. If a tuple or list is provided,
118+ temperatures will be linearly interpolated between the first and second values across all timesteps. If
119+ a single value is provided, temperatures will be linearly interpolated from that value to 0.01.
120+ device (`Union[str, torch.device]`, *optional*):
121+ The device to which the timesteps and temperatures should be moved. If not specified, uses the default
122+ device.
123+ """
64124 self .timesteps = torch .arange (num_inference_steps , device = device ).flip (0 )
65125
66126 if isinstance (temperature , (tuple , list )):
@@ -71,12 +131,38 @@ def set_timesteps(
71131 def step (
72132 self ,
73133 model_output : torch .Tensor ,
74- timestep : torch . long ,
134+ timestep : int ,
75135 sample : torch .LongTensor ,
76- starting_mask_ratio : int = 1 ,
136+ starting_mask_ratio : float = 1.0 ,
77137 generator : Optional [torch .Generator ] = None ,
78138 return_dict : bool = True ,
79139 ) -> Union [AmusedSchedulerOutput , Tuple ]:
140+ """
141+ Predict the sample at the previous timestep by masking tokens based on confidence scores.
142+
143+ Args:
144+ model_output (`torch.Tensor`):
145+ The direct output from the learned diffusion model. Typically of shape `(batch_size, num_tokens,
146+ codebook_size)` or `(batch_size, codebook_size, height, width)` for 2D inputs.
147+ timestep (`int`):
148+ The current discrete timestep in the diffusion chain.
149+ sample (`torch.LongTensor`):
150+ A current instance of a sample created by the diffusion process. Contains token IDs, with masked
151+ positions indicated by `mask_token_id`.
152+ starting_mask_ratio (`float`, *optional*, defaults to 1.0):
153+ A multiplier applied to the mask ratio schedule. Values less than 1.0 will result in fewer tokens being
154+ masked at each step.
155+ generator (`torch.Generator`, *optional*):
156+ A random number generator for reproducible sampling.
157+ return_dict (`bool`, *optional*, defaults to `True`):
158+ Whether to return an [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or a plain tuple.
159+
160+ Returns:
161+ [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or `tuple`:
162+ If `return_dict` is `True`, returns [`~schedulers.scheduling_amused.AmusedSchedulerOutput`], otherwise
163+ returns a tuple where the first element is the sample tensor and the second element is the predicted
164+ original sample tensor.
165+ """
80166 two_dim_input = sample .ndim == 3 and model_output .ndim == 4
81167
82168 if two_dim_input :
@@ -137,7 +223,27 @@ def step(
137223
138224 return AmusedSchedulerOutput (prev_sample , pred_original_sample )
139225
140- def add_noise (self , sample , timesteps , generator = None ):
226+ def add_noise (
227+ self ,
228+ sample : torch .LongTensor ,
229+ timesteps : int ,
230+ generator : Optional [torch .Generator ] = None ,
231+ ) -> torch .LongTensor :
232+ """
233+ Add noise to a sample by randomly masking tokens according to the masking schedule.
234+
235+ Args:
236+ sample (`torch.LongTensor`):
237+ The input sample containing token IDs to be partially masked.
238+ timesteps (`int`):
239+ The timestep that determines how much masking to apply. Higher timesteps result in more masking.
240+ generator (`torch.Generator`, *optional*):
241+ A random number generator for reproducible masking.
242+
243+ Returns:
244+ `torch.LongTensor`:
245+ The sample with some tokens replaced by `mask_token_id` according to the masking schedule.
246+ """
141247 step_idx = (self .timesteps == timesteps ).nonzero ()
142248 ratio = (step_idx + 1 ) / len (self .timesteps )
143249
0 commit comments