
Commit 60a3851

Warvito and ericspod authored
Remove asserts (#383)
* Remove asserts
  Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>

* Update generative/networks/nets/diffusion_model_unet.py
  Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
  Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>

* Edit docstring
  Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>

* Fix test
  Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>

---------

Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 639b6eb commit 60a3851
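
The commit replaces bare assert statements with explicit ValueError raises. One common motivation for this kind of change is that asserts are stripped when Python runs with optimizations (python -O) and only raise AssertionError rather than a typed exception callers can handle. A minimal sketch of the pattern applied in the diff below; the helper name is illustrative and not part of the repository:

import torch


def check_timesteps(timesteps: torch.Tensor) -> None:
    # Before: assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
    # After: an explicit check that survives `python -O` and raises a typed exception.
    if timesteps.ndim != 1:
        raise ValueError("Timesteps should be a 1d-array")


check_timesteps(torch.randint(0, 1000, (4,)))      # passes silently
# check_timesteps(torch.randint(0, 1000, (4, 1)))  # would raise ValueError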

2 files changed: +27 −5 lines

generative/networks/nets/diffusion_model_unet.py

Lines changed: 13 additions & 5 deletions

@@ -468,7 +468,8 @@ def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_peri
         embedding_dim: the dimension of the output.
         max_period: controls the minimum frequency of the embeddings.
     """
-    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+    if timesteps.ndim != 1:
+        raise ValueError("Timesteps should be a 1d-array")

     half_dim = embedding_dim // 2
     exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
@@ -491,7 +492,8 @@ class Downsample(nn.Module):
     Args:
         spatial_dims: number of spatial dimensions.
         num_channels: number of input channels.
-        use_conv: if True uses Convolution instead of Pool average to perform downsampling.
+        use_conv: if True uses Convolution instead of Pool average to perform downsampling. In case that use_conv is
+            False, the number of output channels must be the same as the number of input channels.
         out_channels: number of output channels.
         padding: controls the amount of implicit zero-paddings on both sides for padding number of points
             for each dimension.
@@ -515,12 +517,17 @@ def __init__(
                 conv_only=True,
             )
         else:
-            assert self.num_channels == self.out_channels
+            if self.num_channels != self.out_channels:
+                raise ValueError("num_channels and out_channels must be equal when use_conv=False")
             self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2)

     def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
         del emb
-        assert x.shape[1] == self.num_channels
+        if x.shape[1] != self.num_channels:
+            raise ValueError(
+                f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels "
+                f"({self.num_channels})"
+            )
         return self.op(x)


@@ -559,7 +566,8 @@ def __init__(

     def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
         del emb
-        assert x.shape[1] == self.num_channels
+        if x.shape[1] != self.num_channels:
+            raise ValueError("Input channels should be equal to num_channels")

         # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
         # https://github.com/pytorch/pytorch/issues/86679
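
With the asserts gone, a malformed timestep tensor surfaces as a ValueError rather than an AssertionError. A minimal usage sketch, assuming get_timestep_embedding is importable from the module edited above:

import torch

from generative.networks.nets.diffusion_model_unet import get_timestep_embedding

# A 1-D tensor of N timesteps yields an (N, embedding_dim) embedding.
emb = get_timestep_embedding(torch.randint(0, 1000, (4,)), embedding_dim=64)
print(emb.shape)  # torch.Size([4, 64])

# A 2-D tensor now raises ValueError instead of failing an assert.
try:
    get_timestep_embedding(torch.randint(0, 1000, (4, 1)), embedding_dim=64)
except ValueError as err:
    print(err)  # Timesteps should be a 1d-array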

tests/test_diffusion_model_unet.py

Lines changed: 14 additions & 0 deletions

@@ -240,6 +240,20 @@ def test_shape_unconditioned_models(self, input_param):
         result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long())
         self.assertEqual(result.shape, (1, 1, 16, 16))

+    def test_timestep_with_wrong_shape(self):
+        net = DiffusionModelUNet(
+            spatial_dims=2,
+            in_channels=1,
+            out_channels=1,
+            num_res_blocks=1,
+            num_channels=(8, 8, 8),
+            attention_levels=(False, False, False),
+            norm_num_groups=8,
+        )
+        with self.assertRaises(ValueError):
+            with eval_mode(net):
+                net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long())
+
     def test_shape_with_different_in_channel_out_channel(self):
         in_channels = 6
         out_channels = 3
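
The new test only exercises the timestep-shape check; the channel checks added to Downsample can be probed the same way. A minimal sketch, assuming Downsample is importable from the module above and its constructor takes the arguments listed in its docstring in the first diff:

import torch

from generative.networks.nets.diffusion_model_unet import Downsample

# With use_conv=False, mismatched channel counts now raise ValueError at construction.
try:
    Downsample(spatial_dims=2, num_channels=8, use_conv=False, out_channels=4)
except ValueError as err:
    print(err)  # num_channels and out_channels must be equal when use_conv=False

# Feeding a tensor with the wrong channel count raises ValueError in forward().
down = Downsample(spatial_dims=2, num_channels=8, use_conv=False, out_channels=8)
try:
    down(torch.rand(1, 4, 16, 16))
except ValueError as err:
    print(err)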

0 commit comments
