@@ -127,15 +127,15 @@ def ddpm_denoise_sample(
127127 else :
128128 x_t = np .random .normal (size = orig_x .shape )
129129
130- x_t = nnet .tensor (x_t , requires_grad = False , device = device )
130+ x_t = nnet .tensor (x_t , device = device )
131131 x_ts = []
132132 for t in tqdm (
133133 reversed (range (0 , self .timesteps )),
134134 desc = "ddpm denoisinig samples" ,
135135 total = self .timesteps ,
136136 ):
137137 noise = (
138- nnet .tensor (np .random .normal (size = x_t .shape ), requires_grad = False , device = device )
138+ nnet .tensor (np .random .normal (size = x_t .shape ), device = device )
139139 if t > 1
140140 else 0
141141 )
@@ -152,7 +152,6 @@ def ddpm_denoise_sample(
152152 if mask is not None :
153153 orig_x_noise = nnet .tensor (
154154 np .random .normal (size = orig_x .shape ),
155- requires_grad = False ,
156155 device = device ,
157156 )
158157
@@ -196,32 +195,40 @@ def ddim_denoise_sample(
196195 else :
197196 x_t = np .random .normal (size = orig_x .shape )
198197
198+ x_t = nnet .tensor (x_t , device = device )
199199 x_ts = []
200200 for t in tqdm (
201201 reversed (range (1 , self .timesteps )[:perform_steps ]),
202202 desc = "ddim denoisinig samples" ,
203203 total = perform_steps ,
204204 ):
205- noise = np .random .normal (size = x_t .shape ) if t > 1 else 0
206- eps = self .model .forward (x_t , np .array ([t ]) / self .timesteps , training = False ).reshape (
207- x_t .shape
205+ noise = (
206+ nnet .tensor (np .random .normal (size = x_t .shape ), device = device )
207+ if t > 1
208+ else 0
208209 )
210+ eps = self .model .forward (x_t , np .array ([t ]) / self .timesteps ).reshape (
211+ x_t .shape
212+ ).detach ()
209213
210- x0_t = (x_t - eps * np .sqrt (1 - self .alphas_cumprod [t ])) / np .sqrt (
214+ x0_t = (x_t - eps * nnet .sqrt (1 - self .alphas_cumprod [t ])) / nnet .sqrt (
211215 self .alphas_cumprod [t ]
212216 )
213217
214- sigma = eta * np .sqrt (
218+ sigma = eta * nnet .sqrt (
215219 (1 - self .alphas_cumprod [t - 1 ])
216220 / (1 - self .alphas_cumprod [t ])
217221 * (1 - self .alphas_cumprod [t ] / self .alphas_cumprod [t - 1 ])
218222 )
219- c = np .sqrt ((1 - self .alphas_cumprod [t - 1 ]) - sigma ** 2 )
223+ c = nnet .sqrt ((1 - self .alphas_cumprod [t - 1 ]) - sigma ** 2 )
220224
221- x_t = np .sqrt (self .alphas_cumprod [t - 1 ]) * x0_t - c * eps + sigma * noise
225+ x_t = nnet .sqrt (self .alphas_cumprod [t - 1 ]) * x0_t - c * eps + sigma * noise
222226
223227 if mask is not None :
224- orig_x_noise = np .random .normal (size = orig_x .shape )
228+ orig_x_noise = nnet .tensor (
229+ np .random .normal (size = orig_x .shape ),
230+ device = device ,
231+ )
225232
226233 orig_x_t = (
227234 self .sqrt_alphas_cumprod [t ] * orig_x
@@ -230,9 +237,9 @@ def ddim_denoise_sample(
230237 x_t = orig_x_t * mask + x_t * (1 - mask )
231238
232239 if t % states_step_size == 0 :
233- x_ts .append (x_t )
240+ x_ts .append (x_t . cpu (). detach (). numpy () )
234241
235- return x_t , x_ts
242+ return x_t . to ( "cpu" ). detach (). numpy () , x_ts
236243
237244 def get_images_set (
238245 self ,
0 commit comments