
Commit 5e3fe02

Fix grad strides warning when using ddp (#375)
* Check for xformers install when using flash attention in diffusion unet
* Ensure tensors are contiguous
* Formatting fix
1 parent 4aee21a commit 5e3fe02

File tree

1 file changed: +5 -2 lines changed


generative/networks/nets/diffusion_model_unet.py

Lines changed: 5 additions & 2 deletions
@@ -334,9 +334,9 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
             x = block(x, context=context)
 
         if self.spatial_dims == 2:
-            x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2)
+            x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
         if self.spatial_dims == 3:
-            x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3)
+            x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous()
 
         x = self.proj_out(x)
         return x + residual
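
The `.contiguous()` calls are the substance of the fix: `permute` only rearranges strides over the same storage, so the result (and the gradients flowing back through it) can have a memory layout that does not match the flat, contiguous buckets DistributedDataParallel allocates for gradient reduction, which is what produces the grad strides warning named in the commit title. A minimal standalone sketch (toy shapes, not code from this repository) of what the added calls do:

```python
import torch

# Toy shapes chosen only for illustration; the real module uses the
# transformer block's batch/height/width/inner_dim.
batch, height, width, inner_dim = 2, 8, 8, 64
x = torch.randn(batch * height * width, inner_dim)

# permute() returns a view with rearranged strides, not a copy.
y = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2)
print(y.is_contiguous())  # False: strides no longer follow row-major order for this shape
print(y.stride())         # (4096, 1, 512, 64)

# .contiguous() materialises a copy with standard row-major strides,
# giving downstream gradients the layout DDP's gradient buckets expect.
z = y.contiguous()
print(z.is_contiguous())  # True
print(z.stride())         # (4096, 64, 8, 1)
```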
@@ -1702,6 +1702,9 @@ def __init__(
                 "`num_channels`."
             )
 
+        if use_flash_attention and not has_xformers:
+            raise ValueError("use_flash_attention is True but xformers is not installed.")
+
         if use_flash_attention is True and not torch.cuda.is_available():
             raise ValueError(
                 "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
