 from torchvision import datasets, transforms
 from torch.utils.data import DataLoader

-from vector_quantize_pytorch import LFQ
+from vector_quantize_pytorch import LFQ, Sequential

 lr = 3e-4
 train_iter = 1000
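The `Sequential` imported alongside `LFQ` is the library's own helper rather than `torch.nn.Sequential`: the quantizer returns a tuple `(quantized, indices, entropy_aux_loss)`, which a plain `nn.Sequential` would pass as-is into the next `Conv2d` and crash. A minimal sketch of the idea behind such a wrapper, assuming the quantizer is the only tuple-returning module in the stack (`TupleSequential` is a hypothetical name, not the library's implementation):

```python
import torch.nn as nn

# Hypothetical sketch of a tuple-threading Sequential; the real
# vector_quantize_pytorch.Sequential may differ in its details.
class TupleSequential(nn.Module):
    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        aux = ()
        for layer in self.layers:
            out = layer(x)
            if isinstance(out, tuple):
                # e.g. LFQ returns (quantized, indices, entropy_aux_loss):
                # keep the tensor flowing, carry the extras to the output
                x, *rest = out
                aux = tuple(rest)
            else:
                x = out
        return (x, *aux) if aux else x
```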
@@ -22,46 +22,31 @@

 device = "cuda" if torch.cuda.is_available() else "cpu"

-class LFQAutoEncoder(nn.Module):
-    def __init__(
-        self,
-        codebook_size,
-        **vq_kwargs
-    ):
-        super().__init__()
-        assert log2(codebook_size).is_integer()
-        quantize_dim = int(log2(codebook_size))
-
-        self.encode = nn.Sequential(
-            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
-            nn.MaxPool2d(kernel_size=2, stride=2),
-            nn.GELU(),
-            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
-            nn.MaxPool2d(kernel_size=2, stride=2),
-            # In general norm layers are commonly used in Resnet-based encoder/decoders
-            # explicitly add one here with affine=False to avoid introducing new parameters
-            nn.GroupNorm(4, 32, affine=False),
-            nn.Conv2d(32, quantize_dim, kernel_size=1),
-        )
-
-        self.quantize = LFQ(dim=quantize_dim, **vq_kwargs)
-
-        self.decode = nn.Sequential(
-            nn.Conv2d(quantize_dim, 32, kernel_size=3, stride=1, padding=1),
-            nn.Upsample(scale_factor=2, mode="nearest"),
-            nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
-            nn.GELU(),
-            nn.Upsample(scale_factor=2, mode="nearest"),
-            nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
-        )
-        return
-
-    def forward(self, x):
-        x = self.encode(x)
-        x, indices, entropy_aux_loss = self.quantize(x)
-        x = self.decode(x)
-        return x.clamp(-1, 1), indices, entropy_aux_loss
-
+def LFQAutoEncoder(
+    codebook_size,
+    **vq_kwargs
+):
+    assert log2(codebook_size).is_integer()
+    quantize_dim = int(log2(codebook_size))
+
+    return Sequential(
+        nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
+        nn.MaxPool2d(kernel_size=2, stride=2),
+        nn.GELU(),
+        nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
+        nn.MaxPool2d(kernel_size=2, stride=2),
+        # In general norm layers are commonly used in Resnet-based encoder/decoders
+        # explicitly add one here with affine=False to avoid introducing new parameters
+        nn.GroupNorm(4, 32, affine=False),
+        nn.Conv2d(32, quantize_dim, kernel_size=1),
+        LFQ(dim=quantize_dim, **vq_kwargs),
+        nn.Conv2d(quantize_dim, 32, kernel_size=3, stride=1, padding=1),
+        nn.Upsample(scale_factor=2, mode="nearest"),
+        nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
+        nn.GELU(),
+        nn.Upsample(scale_factor=2, mode="nearest"),
+        nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
+    )

 def train(model, train_loader, train_iterations=1000):
     def iterate_dataset(data_loader):
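After the refactor the model is built by a plain function, but the forward pass is unchanged from the caller's perspective: `Sequential` threads the LFQ layer's extra outputs through to the end, so `model(x)` still yields the same three values the old `forward` returned. A quick usage sketch, assuming MNIST-shaped `1×28×28` inputs as in the rest of this example:

```python
model = LFQAutoEncoder(codebook_size=2 ** 8)  # 8-bit codes, so quantize_dim = 8
x = torch.randn(4, 1, 28, 28)                 # MNIST-sized dummy batch
out, indices, entropy_aux_loss = model(x)     # same tuple as the old forward()
```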
@@ -78,6 +63,7 @@ def iterate_dataset(data_loader):
         opt.zero_grad()
         x, _ = next(iterate_dataset(train_loader))
         out, indices, entropy_aux_loss = model(x)
+        out = out.clamp(-1., 1.)

         rec_loss = F.l1_loss(out, x)
         (rec_loss + entropy_aux_loss).backward()
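With no `forward` method left to clamp in, the clamp moves out of the model and into the training step. The `[-1., 1.]` bounds match the input range produced by the `Normalize((0.5,), (0.5,))` transform defined below, which maps `ToTensor()` output through `(x - 0.5) / 0.5`. A quick sanity check of that mapping (illustrative snippet, not part of the commit):

```python
import torch
from torchvision import transforms

norm = transforms.Normalize((0.5,), (0.5,))
pixels = torch.tensor([[[0.0, 0.5, 1.0]]])  # ToTensor() range is [0, 1]
print(norm(pixels))                         # tensor([[[-1., 0., 1.]]])
```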
@@ -88,7 +74,6 @@ def iterate_dataset(data_loader):
             + f"entropy aux loss: {entropy_aux_loss.item():.3f} | "
             + f"active %: {indices.unique().numel() / codebook_size * 100:.3f}"
         )
-    return

 transform = transforms.Compose(
     [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]