@@ -27,8 +27,7 @@ vq = VectorQuantize(
2727
2828x = torch.randn(1 , 1024 , 256 )
2929quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1)
30- print (quantized.shape, indices.shape, commit_loss.shape)
31- # > torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
30+
3231```
3332
3433## Residual VQ
@@ -49,13 +48,13 @@ x = torch.randn(1, 1024, 256)
4948
5049quantized, indices, commit_loss = residual_vq(x)
5150print (quantized.shape, indices.shape, commit_loss.shape)
52- # > torch.Size([ 1, 1024, 256]) torch.Size([ 1, 1024, 8]) torch.Size([ 1, 8] )
51+ # ( 1, 1024, 256), ( 1, 1024, 8), ( 1, 8)
5352
5453# if you need all the codes across the quantization layers, just pass return_all_codes = True
5554
5655quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True )
57- print (all_codes.shape)
58- # > torch.Size([ 8, 1, 1024, 256] )
56+
57+ # ( 8, 1, 1024, 256)
5958```
6059
6160Furthermore, <a href="https://arxiv.org/abs/2203.01941">this paper</a> uses Residual-VQ to construct the RQ-VAE, for generating high resolution images with more compressed codes.
@@ -77,8 +76,8 @@ residual_vq = ResidualVQ(
7776
7877x = torch.randn(1 , 1024 , 256 )
7978quantized, indices, commit_loss = residual_vq(x)
80- print (quantized.shape, indices.shape, commit_loss.shape)
81- # > torch.Size([ 1, 1024, 256]) torch.Size([ 1, 1024, 8]) torch.Size([ 1, 8] )
79+
80+ # ( 1, 1024, 256), ( 1, 1024, 8), ( 1, 8)
8281```
8382
8483<a href="https://arxiv.org/abs/2305.02765">A recent paper</a> further proposes to do residual VQ on groups of the feature dimension, showing equivalent results to Encodec while using far fewer codebooks. You can use it by importing `GroupedResidualVQ`
@@ -97,9 +96,8 @@ residual_vq = GroupedResidualVQ(
9796x = torch.randn(1 , 1024 , 256 )
9897
9998quantized, indices, commit_loss = residual_vq(x)
100- print (quantized.shape, indices.shape, commit_loss.shape)
101- # > torch.Size([1, 1024, 256]) torch.Size([2, 1, 1024, 8]) torch.Size([2, 1, 8])
10299
100+ # (1, 1024, 256), (2, 1, 1024, 8), (2, 1, 8)
103101```
104102
105103## Initialization
@@ -120,8 +118,8 @@ residual_vq = ResidualVQ(
120118
121119x = torch.randn(1 , 1024 , 256 )
122120quantized, indices, commit_loss = residual_vq(x)
123- print (quantized.shape, indices.shape, commit_loss.shape)
124- # > torch.Size([ 1, 1024, 256]) torch.Size([ 1, 1024, 4]) torch.Size([ 1, 4] )
121+
122+ # ( 1, 1024, 256), ( 1, 1024, 4), ( 1, 4)
125123```
126124
127125## Increasing codebook usage
@@ -144,8 +142,8 @@ vq = VectorQuantize(
144142
145143x = torch.randn(1 , 1024 , 256 )
146144quantized, indices, commit_loss = vq(x)
147- print (quantized.shape, indices.shape, commit_loss.shape)
148- # > torch.Size([ 1, 1024, 256]) torch.Size([ 1, 1024]) torch.Size([1] )
145+
146+ # ( 1, 1024, 256), ( 1, 1024), (1, )
149147```
150148
151149### Cosine similarity
@@ -164,8 +162,8 @@ vq = VectorQuantize(
164162
165163x = torch.randn(1 , 1024 , 256 )
166164quantized, indices, commit_loss = vq(x)
167- print (quantized.shape, indices.shape, commit_loss.shape)
168- # > torch.Size([ 1, 1024, 256]) torch.Size([ 1, 1024]) torch.Size([1] )
165+
166+ # ( 1, 1024, 256), ( 1, 1024), (1, )
169167```
170168
171169### Expiring stale codes
@@ -184,8 +182,8 @@ vq = VectorQuantize(
184182
185183x = torch.randn(1 , 1024 , 256 )
186184quantized, indices, commit_loss = vq(x)
187- print (quantized.shape, indices.shape, commit_loss.shape)
188- # > torch.Size([ 1, 1024, 256]) torch.Size([ 1, 1024]) torch.Size([1] )
185+
186+ # ( 1, 1024, 256), ( 1, 1024), (1, )
189187```
190188
191189### Orthogonal regularization loss
@@ -209,9 +207,8 @@ vq = VectorQuantize(
209207
210208img_fmap = torch.randn(1 , 256 , 32 , 32 )
211209quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32), (1,)
210+
212211# loss now contains the orthogonal regularization loss with the weight as assigned
213- print (quantized.shape, indices.shape, loss.shape)
214- # > torch.Size([1, 256, 32, 32]) torch.Size([1, 32, 32]) torch.Size([1])
215212```
216213
217214### Multi-headed VQ
@@ -235,8 +232,8 @@ vq = VectorQuantize(
235232
236233img_fmap = torch.randn(1 , 256 , 32 , 32 )
237234quantized, indices, loss = vq(img_fmap)
238- print (quantized.shape, indices.shape, loss.shape)
239- # > torch.Size([ 1, 256, 32, 32]) torch.Size([ 1, 32, 32, 8]) torch.Size([1] )
235+
236+ # ( 1, 256, 32, 32), ( 1, 32, 32, 8), (1, )
240237
241238```
242239
@@ -259,8 +256,8 @@ quantizer = RandomProjectionQuantizer(
259256
260257x = torch.randn(1 , 1024 , 512 )
261258indices = quantizer(x)
262- print (indices.shape)
263- # > torch.Size([ 1, 1024, 16] )
259+
260+ # ( 1, 1024, 16)
264261```
265262
266263This repository also automatically synchronizes the codebooks in a multi-process setting. If somehow it isn't, please open an issue. You can override whether to synchronize codebooks or not by setting `sync_codebook = True | False`
@@ -285,16 +282,14 @@ Thanks goes out to [@sekstini](https://github.com/sekstini) for porting over thi
285282import torch
286283from vector_quantize_pytorch import FSQ
287284
288- levels = [8 ,5 ,5 ,5 ] # see 4.1 and A.4.1 in the paper
289- quantizer = FSQ(levels)
285+ quantizer = FSQ(
286+ levels = [8 , 5 , 5 , 5 ]
287+ )
290288
291289x = torch.randn(1 , 1024 , 4 ) # 4 since there are 4 levels
292290xhat, indices = quantizer(x)
293291
294- print (xhat.shape)
295- # > torch.Size([1, 1024, 4])
296- print (indices.shape)
297- # > torch.Size([1, 1024])
292+ # (1, 1024, 4), (1, 1024)
298293
299294assert torch.all(xhat == quantizer.indices_to_codes(indices))
300295```
@@ -318,12 +313,12 @@ x = torch.randn(1, 1024, 256)
318313residual_fsq.eval()
319314
320315quantized, indices = residual_fsq(x)
321- print (quantized.shape, indices.shape)
322- # > torch.Size([ 1, 1024, 256]) torch.Size([ 1, 1024, 8] )
316+
317+ # ( 1, 1024, 256), ( 1, 1024, 8)
323318
324319quantized_out = residual_fsq.get_output_from_indices(indices)
325- print (quantized_out.shape)
326- # > torch.Size([ 1, 1024, 256] )
320+
321+ # ( 1, 1024, 256)
327322
328323assert torch.all(quantized == quantized_out)
329324```
@@ -357,8 +352,8 @@ quantizer = LFQ(
357352image_feats = torch.randn(1 , 16 , 32 , 32 )
358353
359354quantized, indices, entropy_aux_loss = quantizer(image_feats, inv_temperature = 100 .) # you may want to experiment with temperature
360- print (quantized.shape, indices.shape, entropy_aux_loss.shape)
361- # > torch.Size([ 1, 16, 32, 32]) torch.Size([ 1, 32, 32]) torch.Size([] )
355+
356+ # ( 1, 16, 32, 32), ( 1, 32, 32), ( )
362357
363358assert (quantized == quantizer.indices_to_codes(indices)).all()
364359```
@@ -379,13 +374,12 @@ quantizer = LFQ(
379374seq = torch.randn(1 , 32 , 16 )
380375quantized, * _ = quantizer(seq)
381376
382- # assert seq.shape == quantized.shape
377+ assert seq.shape == quantized.shape
383378
384- # video_feats = torch.randn(1, 16, 10, 32, 32)
385- # quantized, *_ = quantizer(video_feats)
386-
387- # assert video_feats.shape == quantized.shape
379+ video_feats = torch.randn(1 , 16 , 10 , 32 , 32 )
380+ quantized, * _ = quantizer(video_feats)
388381
382+ assert video_feats.shape == quantized.shape
389383```
390384
391385Or support multiple codebooks
@@ -403,8 +397,8 @@ quantizer = LFQ(
403397image_feats = torch.randn(1 , 16 , 32 , 32 )
404398
405399quantized, indices, entropy_aux_loss = quantizer(image_feats)
406- print (quantized.shape, indices.shape, entropy_aux_loss.shape)
407- # > torch.Size([ 1, 16, 32, 32]) torch.Size([ 1, 32, 32, 4]) torch.Size([] )
400+
401+ # ( 1, 16, 32, 32), ( 1, 32, 32, 4), ( )
408402
409403assert image_feats.shape == quantized.shape
410404assert (quantized == quantizer.indices_to_codes(indices)).all()
@@ -427,12 +421,12 @@ x = torch.randn(1, 1024, 256)
427421residual_lfq.eval()
428422
429423quantized, indices, commit_loss = residual_lfq(x)
430- print (quantized.shape, indices.shape, commit_loss.shape)
431- # > torch.Size([ 1, 1024, 256]) torch.Size([ 1, 1024, 8]) torch.Size([8] )
424+
425+ # ( 1, 1024, 256), ( 1, 1024, 8), (8 )
432426
433427quantized_out = residual_lfq.get_output_from_indices(indices)
434- print (quantized_out.shape)
435- # > torch.Size([ 1, 1024, 256] )
428+
429+ # ( 1, 1024, 256)
436430
437431assert torch.all(quantized == quantized_out)
438432```
@@ -460,8 +454,8 @@ quantizer = LatentQuantize(
460454image_feats = torch.randn(1 , 16 , 32 , 32 )
461455
462456quantized, indices, loss = quantizer(image_feats)
463- print (quantized.shape, indices.shape, loss.shape)
464- # > torch.Size([ 1, 16, 32, 32]) torch.Size([ 1, 32, 32]) torch.Size([] )
457+
458+ # ( 1, 16, 32, 32), ( 1, 32, 32), ( )
465459
466460assert image_feats.shape == quantized.shape
467461assert (quantized == quantizer.indices_to_codes(indices)).all()
@@ -483,13 +477,13 @@ quantizer = LatentQuantize(
483477
484478seq = torch.randn(1 , 32 , 16 )
485479quantized, * _ = quantizer(seq)
486- print (quantized.shape)
487- # > torch.Size([ 1, 32, 16] )
480+
481+ # ( 1, 32, 16)
488482
489483video_feats = torch.randn(1 , 16 , 10 , 32 , 32 )
490484quantized, * _ = quantizer(video_feats)
491- print (quantized.shape)
492- # > torch.Size([ 1, 16, 10, 32, 32] )
485+
486+ # ( 1, 16, 10, 32, 32)
493487
494488```
495489
@@ -499,23 +493,22 @@ Or support multiple codebooks
499493import torch
500494from vector_quantize_pytorch import LatentQuantize
501495
502- levels = [ 4 , 8 , 16 ]
503- dim = 9
504- num_codebooks = 3
505-
506- model = LatentQuantize(levels, dim, num_codebooks = num_codebooks )
496+ model = LatentQuantize(
497+ levels = [ 4 , 8 , 16 ],
498+ dim = 9 ,
499+ num_codebooks = 3
500+ )
507501
508502input_tensor = torch.randn(2 , 3 , dim)
509503output_tensor, indices, loss = model(input_tensor)
510- print (output_tensor.shape, indices.shape, loss.shape)
511- # > torch.Size([ 2, 3, 9]) torch.Size([ 2, 3, 3]) torch.Size([] )
504+
505+ # ( 2, 3, 9), ( 2, 3, 3), ( )
512506
513507assert output_tensor.shape == input_tensor.shape
514508assert indices.shape == (2 , 3 , num_codebooks)
515509assert loss.item() >= 0
516510```
517511
518-
519512## Citations
520513
521514``` bibtex
0 commit comments