@@ -27,8 +27,7 @@ vq = VectorQuantize(
 
 x = torch.randn(1, 1024, 256)
 quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1)
-print(quantized.shape, indices.shape, commit_loss.shape)
-# > torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
+
 ```
 
 ## Residual VQ
@@ -49,13 +48,13 @@ x = torch.randn(1, 1024, 256)
 
 quantized, indices, commit_loss = residual_vq(x)
 print(quantized.shape, indices.shape, commit_loss.shape)
-# > torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([1, 8])
+# (1, 1024, 256), (1, 1024, 8), (1, 8)
 
 # if you need all the codes across the quantization layers, just pass return_all_codes = True
 
 quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)
-print(all_codes.shape)
-# > torch.Size([8, 1, 1024, 256])
+
+# (8, 1, 1024, 256)
 ```
 
 Furthermore, <a href="https://arxiv.org/abs/2203.01941">this paper</a> uses Residual-VQ to construct the RQ-VAE, for generating high resolution images with more compressed codes.
@@ -97,8 +96,8 @@ residual_vq = GroupedResidualVQ(
 x = torch.randn(1, 1024, 256)
 
 quantized, indices, commit_loss = residual_vq(x)
-print(quantized.shape, indices.shape, commit_loss.shape)
-# > torch.Size([1, 1024, 256]) torch.Size([2, 1, 1024, 8]) torch.Size([2, 1, 8])
+
+# (1, 1024, 256), (2, 1, 1024, 8), (2, 1, 8)
 
 ```
 
@@ -120,8 +119,8 @@ residual_vq = ResidualVQ(
 
 x = torch.randn(1, 1024, 256)
 quantized, indices, commit_loss = residual_vq(x)
-print(quantized.shape, indices.shape, commit_loss.shape)
-# > torch.Size([1, 1024, 256]) torch.Size([1, 1024, 4]) torch.Size([1, 4])
+
+# (1, 1024, 256), (1, 1024, 4), (1, 4)
 ```
 
 ## Increasing codebook usage
@@ -144,8 +143,8 @@ vq = VectorQuantize(
 
 x = torch.randn(1, 1024, 256)
 quantized, indices, commit_loss = vq(x)
-print(quantized.shape, indices.shape, commit_loss.shape)
-# > torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
+
+# (1, 1024, 256), (1, 1024), (1,)
 ```
 
 ### Cosine similarity
@@ -164,8 +163,8 @@ vq = VectorQuantize(
 
 x = torch.randn(1, 1024, 256)
 quantized, indices, commit_loss = vq(x)
-print(quantized.shape, indices.shape, commit_loss.shape)
-# > torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
+
+# (1, 1024, 256), (1, 1024), (1,)
 ```
 
 ### Expiring stale codes
@@ -184,8 +183,8 @@ vq = VectorQuantize(
 
 x = torch.randn(1, 1024, 256)
 quantized, indices, commit_loss = vq(x)
-print(quantized.shape, indices.shape, commit_loss.shape)
-# > torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
+
+# (1, 1024, 256), (1, 1024), (1,)
 ```
 
 ### Orthogonal regularization loss
@@ -209,9 +208,8 @@ vq = VectorQuantize(
 
 img_fmap = torch.randn(1, 256, 32, 32)
 quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32), (1,)
+
 # loss now contains the orthogonal regularization loss with the weight as assigned
-print(quantized.shape, indices.shape, loss.shape)
-# > torch.Size([1, 256, 32, 32]) torch.Size([1, 32, 32]) torch.Size([1])
 ```
 
 ### Multi-headed VQ
@@ -235,8 +233,8 @@ vq = VectorQuantize(
 
 img_fmap = torch.randn(1, 256, 32, 32)
 quantized, indices, loss = vq(img_fmap)
-print(quantized.shape, indices.shape, loss.shape)
-# > torch.Size([1, 256, 32, 32]) torch.Size([1, 32, 32, 8]) torch.Size([1])
+
+# (1, 256, 32, 32), (1, 32, 32, 8), (1,)
 
 ```
 
@@ -259,8 +257,8 @@ quantizer = RandomProjectionQuantizer(
 
 x = torch.randn(1, 1024, 512)
 indices = quantizer(x)
-print(indices.shape)
-# > torch.Size([1, 1024, 16])
+
+# (1, 1024, 16)
 ```
 
 This repository should also automatically synchronize the codebooks in a multi-process setting. If somehow it isn't, please open an issue. You can override whether to synchronize codebooks or not by setting `sync_codebook = True | False`
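
As a minimal sketch of overriding that behavior, assuming `sync_codebook` is accepted directly as a constructor keyword per the note above:

```python
import torch
from vector_quantize_pytorch import VectorQuantize

# assumption: the quantizer takes a `sync_codebook` keyword, as described above;
# set it to False to skip synchronization, or True to force syncing across processes
vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    sync_codebook = False
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1,)
```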
@@ -285,16 +283,14 @@ Thanks goes out to [@sekstini](https://github.com/sekstini) for porting over thi
 import torch
 from vector_quantize_pytorch import FSQ
 
-levels = [8,5,5,5] # see 4.1 and A.4.1 in the paper
-quantizer = FSQ(levels)
+quantizer = FSQ(
+    levels = [8, 5, 5, 5]
+)
 
 x = torch.randn(1, 1024, 4) # 4 since there are 4 levels
 xhat, indices = quantizer(x)
 
-print(xhat.shape)
-# > torch.Size([1, 1024, 4])
-print(indices.shape)
-# > torch.Size([1, 1024])
+# (1, 1024, 4), (1, 1024)
 
 assert torch.all(xhat == quantizer.indices_to_codes(indices))
 ```
@@ -318,12 +314,12 @@ x = torch.randn(1, 1024, 256)
 residual_fsq.eval()
 
 quantized, indices = residual_fsq(x)
-print(quantized.shape, indices.shape)
-# > torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8])
+
+# (1, 1024, 256), (1, 1024, 8)
 
 quantized_out = residual_fsq.get_output_from_indices(indices)
-print(quantized_out.shape)
-# > torch.Size([1, 1024, 256])
+
+# (1, 1024, 256)
 
 assert torch.all(quantized == quantized_out)
 ```
@@ -357,8 +353,8 @@ quantizer = LFQ(
 image_feats = torch.randn(1, 16, 32, 32)
 
 quantized, indices, entropy_aux_loss = quantizer(image_feats, inv_temperature = 100.) # you may want to experiment with temperature
-print(quantized.shape, indices.shape, entropy_aux_loss.shape)
-# > torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32]) torch.Size([])
+
+# (1, 16, 32, 32), (1, 32, 32), ()
 
 assert (quantized == quantizer.indices_to_codes(indices)).all()
 ```
@@ -379,13 +375,12 @@ quantizer = LFQ(
 seq = torch.randn(1, 32, 16)
 quantized, *_ = quantizer(seq)
 
-# assert seq.shape == quantized.shape
+assert seq.shape == quantized.shape
 
-# video_feats = torch.randn(1, 16, 10, 32, 32)
-# quantized, *_ = quantizer(video_feats)
-
-# assert video_feats.shape == quantized.shape
+video_feats = torch.randn(1, 16, 10, 32, 32)
+quantized, *_ = quantizer(video_feats)
 
+assert video_feats.shape == quantized.shape
 ```
 
 Or support multiple codebooks
@@ -403,8 +398,8 @@ quantizer = LFQ(
 image_feats = torch.randn(1, 16, 32, 32)
 
 quantized, indices, entropy_aux_loss = quantizer(image_feats)
-print(quantized.shape, indices.shape, entropy_aux_loss.shape)
-# > torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32, 4]) torch.Size([])
+
+# (1, 16, 32, 32), (1, 32, 32, 4), ()
 
 assert image_feats.shape == quantized.shape
 assert (quantized == quantizer.indices_to_codes(indices)).all()
@@ -427,12 +422,12 @@ x = torch.randn(1, 1024, 256)
 residual_lfq.eval()
 
 quantized, indices, commit_loss = residual_lfq(x)
-print(quantized.shape, indices.shape, commit_loss.shape)
-# > torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([8])
+
+# (1, 1024, 256), (1, 1024, 8), (8)
 
 quantized_out = residual_lfq.get_output_from_indices(indices)
-print(quantized_out.shape)
-# > torch.Size([1, 1024, 256])
+
+# (1, 1024, 256)
 
 assert torch.all(quantized == quantized_out)
 ```
@@ -460,8 +455,8 @@ quantizer = LatentQuantize(
 image_feats = torch.randn(1, 16, 32, 32)
 
 quantized, indices, loss = quantizer(image_feats)
-print(quantized.shape, indices.shape, loss.shape)
-# > torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32]) torch.Size([])
+
+# (1, 16, 32, 32), (1, 32, 32), ()
 
 assert image_feats.shape == quantized.shape
 assert (quantized == quantizer.indices_to_codes(indices)).all()
@@ -483,13 +478,13 @@ quantizer = LatentQuantize(
 
 seq = torch.randn(1, 32, 16)
 quantized, *_ = quantizer(seq)
-print(quantized.shape)
-# > torch.Size([1, 32, 16])
+
+# (1, 32, 16)
 
 video_feats = torch.randn(1, 16, 10, 32, 32)
 quantized, *_ = quantizer(video_feats)
-print(quantized.shape)
-# > torch.Size([1, 16, 10, 32, 32])
+
+# (1, 16, 10, 32, 32)
 
 ```
 
@@ -499,23 +494,22 @@ Or support multiple codebooks
 import torch
 from vector_quantize_pytorch import LatentQuantize
 
-levels = [4, 8, 16]
-dim = 9
-num_codebooks = 3
-
-model = LatentQuantize(levels, dim, num_codebooks = num_codebooks)
+model = LatentQuantize(
+    levels = [4, 8, 16],
+    dim = 9,
+    num_codebooks = 3
+)
 
 input_tensor = torch.randn(2, 3, dim)
 output_tensor, indices, loss = model(input_tensor)
-print(output_tensor.shape, indices.shape, loss.shape)
-# > torch.Size([2, 3, 9]) torch.Size([2, 3, 3]) torch.Size([])
+
+# (2, 3, 9), (2, 3, 3), ()
 
 assert output_tensor.shape == input_tensor.shape
 assert indices.shape == (2, 3, num_codebooks)
 assert loss.item() >= 0
 ```
 
-
 ## Citations
 
 ```bibtex