
Commit d9967be

received emails from confused researchers re: pytest-examples this morning. get rid of it
1 parent 85f03c5 commit d9967be

5 files changed: +333 -87 lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
@@ -16,4 +16,4 @@ jobs:
       run: |
         rye sync
     - name: Run pytest
-      run: rye run pytest --cov=. tests/test_examples_readme.py
+      run: rye run pytest --cov=. tests/

README.md

Lines changed: 52 additions & 58 deletions
@@ -27,8 +27,7 @@ vq = VectorQuantize(
 
 x = torch.randn(1, 1024, 256)
 quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1)
-print(quantized.shape, indices.shape, commit_loss.shape)
-#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
+
 ```
 
 ## Residual VQ
@@ -49,13 +48,13 @@ x = torch.randn(1, 1024, 256)
 
 quantized, indices, commit_loss = residual_vq(x)
 print(quantized.shape, indices.shape, commit_loss.shape)
-#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([1, 8])
+# (1, 1024, 256), (1, 1024, 8), (1, 8)
 
 # if you need all the codes across the quantization layers, just pass return_all_codes = True
 
 quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)
-print(all_codes.shape)
-#> torch.Size([8, 1, 1024, 256])
+
+# (8, 1, 1024, 256)
 ```
 
 Furthermore, <a href="https://arxiv.org/abs/2203.01941">this paper</a> uses Residual-VQ to construct the RQ-VAE, for generating high resolution images with more compressed codes.
@@ -97,8 +96,8 @@ residual_vq = GroupedResidualVQ(
 x = torch.randn(1, 1024, 256)
 
 quantized, indices, commit_loss = residual_vq(x)
-print(quantized.shape, indices.shape, commit_loss.shape)
-#> torch.Size([1, 1024, 256]) torch.Size([2, 1, 1024, 8]) torch.Size([2, 1, 8])
+
+# (1, 1024, 256), (2, 1, 1024, 8), (2, 1, 8)
 
 ```
 
@@ -120,8 +119,8 @@ residual_vq = ResidualVQ(
 
 x = torch.randn(1, 1024, 256)
 quantized, indices, commit_loss = residual_vq(x)
-print(quantized.shape, indices.shape, commit_loss.shape)
-#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 4]) torch.Size([1, 4])
+
+# (1, 1024, 256), (1, 1024, 4), (1, 4)
 ```
 
 ## Increasing codebook usage
@@ -144,8 +143,8 @@ vq = VectorQuantize(
 
 x = torch.randn(1, 1024, 256)
 quantized, indices, commit_loss = vq(x)
-print(quantized.shape, indices.shape, commit_loss.shape)
-#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
+
+# (1, 1024, 256), (1, 1024), (1,)
 ```
 
 ### Cosine similarity
@@ -164,8 +163,8 @@ vq = VectorQuantize(
 
 x = torch.randn(1, 1024, 256)
 quantized, indices, commit_loss = vq(x)
-print(quantized.shape, indices.shape, commit_loss.shape)
-#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
+
+# (1, 1024, 256), (1, 1024), (1,)
 ```
 
 ### Expiring stale codes
@@ -184,8 +183,8 @@ vq = VectorQuantize(
 
 x = torch.randn(1, 1024, 256)
 quantized, indices, commit_loss = vq(x)
-print(quantized.shape, indices.shape, commit_loss.shape)
-#> torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
+
+# (1, 1024, 256), (1, 1024), (1,)
 ```
 
 ### Orthogonal regularization loss
@@ -209,9 +208,8 @@ vq = VectorQuantize(
 
 img_fmap = torch.randn(1, 256, 32, 32)
 quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32), (1,)
+
 # loss now contains the orthogonal regularization loss with the weight as assigned
-print(quantized.shape, indices.shape, loss.shape)
-#> torch.Size([1, 256, 32, 32]) torch.Size([1, 32, 32]) torch.Size([1])
 ```
 
 ### Multi-headed VQ
@@ -235,8 +233,8 @@ vq = VectorQuantize(
 
 img_fmap = torch.randn(1, 256, 32, 32)
 quantized, indices, loss = vq(img_fmap)
-print(quantized.shape, indices.shape, loss.shape)
-#> torch.Size([1, 256, 32, 32]) torch.Size([1, 32, 32, 8]) torch.Size([1])
+
+# (1, 256, 32, 32), (1, 32, 32, 8), (1,)
 
 ```
 
@@ -259,8 +257,8 @@ quantizer = RandomProjectionQuantizer(
 
 x = torch.randn(1, 1024, 512)
 indices = quantizer(x)
-print(indices.shape)
-#> torch.Size([1, 1024, 16])
+
+# (1, 1024, 16)
 ```
 
 This repository should also automatically synchronizing the codebooks in a multi-process setting. If somehow it isn't, please open an issue. You can override whether to synchronize codebooks or not by setting `sync_codebook = True | False`
@@ -285,16 +283,14 @@ Thanks goes out to [@sekstini](https://github.com/sekstini) for porting over thi
 import torch
 from vector_quantize_pytorch import FSQ
 
-levels = [8,5,5,5] # see 4.1 and A.4.1 in the paper
-quantizer = FSQ(levels)
+quantizer = FSQ(
+    levels = [8, 5, 5, 5]
+)
 
 x = torch.randn(1, 1024, 4) # 4 since there are 4 levels
 xhat, indices = quantizer(x)
 
-print(xhat.shape)
-#> torch.Size([1, 1024, 4])
-print(indices.shape)
-#> torch.Size([1, 1024])
+# (1, 1024, 4), (1, 1024)
 
 assert torch.all(xhat == quantizer.indices_to_codes(indices))
 ```
@@ -318,12 +314,12 @@ x = torch.randn(1, 1024, 256)
 residual_fsq.eval()
 
 quantized, indices = residual_fsq(x)
-print(quantized.shape, indices.shape)
-#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8])
+
+# (1, 1024, 256), (1, 1024, 8)
 
 quantized_out = residual_fsq.get_output_from_indices(indices)
-print(quantized_out.shape)
-#> torch.Size([1, 1024, 256])
+
+# (1, 1024, 256)
 
 assert torch.all(quantized == quantized_out)
 ```
@@ -357,8 +353,8 @@ quantizer = LFQ(
 image_feats = torch.randn(1, 16, 32, 32)
 
 quantized, indices, entropy_aux_loss = quantizer(image_feats, inv_temperature=100.) # you may want to experiment with temperature
-print(quantized.shape, indices.shape, entropy_aux_loss.shape)
-#> torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32]) torch.Size([])
+
+# (1, 16, 32, 32), (1, 32, 32), ()
 
 assert (quantized == quantizer.indices_to_codes(indices)).all()
 ```
@@ -379,13 +375,12 @@ quantizer = LFQ(
 seq = torch.randn(1, 32, 16)
 quantized, *_ = quantizer(seq)
 
-# assert seq.shape == quantized.shape
+assert seq.shape == quantized.shape
 
-# video_feats = torch.randn(1, 16, 10, 32, 32)
-# quantized, *_ = quantizer(video_feats)
-
-# assert video_feats.shape == quantized.shape
+video_feats = torch.randn(1, 16, 10, 32, 32)
+quantized, *_ = quantizer(video_feats)
 
+assert video_feats.shape == quantized.shape
 ```
 
 Or support multiple codebooks
@@ -403,8 +398,8 @@ quantizer = LFQ(
 image_feats = torch.randn(1, 16, 32, 32)
 
 quantized, indices, entropy_aux_loss = quantizer(image_feats)
-print(quantized.shape, indices.shape, entropy_aux_loss.shape)
-#> torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32, 4]) torch.Size([])
+
+# (1, 16, 32, 32), (1, 32, 32, 4), ()
 
 assert image_feats.shape == quantized.shape
 assert (quantized == quantizer.indices_to_codes(indices)).all()
@@ -427,12 +422,12 @@ x = torch.randn(1, 1024, 256)
 residual_lfq.eval()
 
 quantized, indices, commit_loss = residual_lfq(x)
-print(quantized.shape, indices.shape, commit_loss.shape)
-#> torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([8])
+
+# (1, 1024, 256), (1, 1024, 8), (8)
 
 quantized_out = residual_lfq.get_output_from_indices(indices)
-print(quantized_out.shape)
-#> torch.Size([1, 1024, 256])
+
+# (1, 1024, 256)
 
 assert torch.all(quantized == quantized_out)
 ```
@@ -460,8 +455,8 @@ quantizer = LatentQuantize(
 image_feats = torch.randn(1, 16, 32, 32)
 
 quantized, indices, loss = quantizer(image_feats)
-print(quantized.shape, indices.shape, loss.shape)
-#> torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32]) torch.Size([])
+
+# (1, 16, 32, 32), (1, 32, 32), ()
 
 assert image_feats.shape == quantized.shape
 assert (quantized == quantizer.indices_to_codes(indices)).all()
@@ -483,13 +478,13 @@ quantizer = LatentQuantize(
 
 seq = torch.randn(1, 32, 16)
 quantized, *_ = quantizer(seq)
-print(quantized.shape)
-#> torch.Size([1, 32, 16])
+
+# (1, 32, 16)
 
 video_feats = torch.randn(1, 16, 10, 32, 32)
 quantized, *_ = quantizer(video_feats)
-print(quantized.shape)
-#> torch.Size([1, 16, 10, 32, 32])
+
+# (1, 16, 10, 32, 32)
 
 ```
 
@@ -499,23 +494,22 @@ Or support multiple codebooks
 import torch
 from vector_quantize_pytorch import LatentQuantize
 
-levels = [4, 8, 16]
-dim = 9
-num_codebooks = 3
-
-model = LatentQuantize(levels, dim, num_codebooks=num_codebooks)
+model = LatentQuantize(
+    levels = [4, 8, 16],
+    dim = 9,
+    num_codebooks = 3
+)
 
 input_tensor = torch.randn(2, 3, dim)
 output_tensor, indices, loss = model(input_tensor)
-print(output_tensor.shape, indices.shape, loss.shape)
-#> torch.Size([2, 3, 9]) torch.Size([2, 3, 3]) torch.Size([])
+
+# (2, 3, 9), (2, 3, 3), ()
 
 assert output_tensor.shape == input_tensor.shape
 assert indices.shape == (2, 3, num_codebooks)
 assert loss.item() >= 0
 ```
 
-
 ## Citations
 
 ```bibtex
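
The net effect on the README is that every example keeps its runnable code but states the expected shapes as an ordinary comment instead of a `print(...)` / `#> ...` pair checked by pytest-examples. For reference, a minimal example in the new style looks roughly like the sketch below; the constructor arguments are assumptions for illustration, since the hunks above do not show them:

```python
import torch
from vector_quantize_pytorch import VectorQuantize

# constructor arguments here are illustrative, not taken from this diff
vq = VectorQuantize(
    dim = 256,
    codebook_size = 512
)

x = torch.randn(1, 1024, 256)

# shapes are annotated inline rather than asserted by a README test
quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1)
```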

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -44,7 +44,6 @@ managed = true
 dev-dependencies = [
     "ruff>=0.4.2",
     "pytest>=8.2.0",
-    "pytest-examples>=0.0.10",
     "pytest-cov>=5.0.0",
 ]
 
tests/test_examples_readme.py

Lines changed: 0 additions & 27 deletions
This file was deleted.
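
For context, pytest-examples (the dev dependency dropped above) discovers fenced code blocks in markdown files, runs them, and compares `#> ...` comment lines against the captured print output, which is why the README previously carried those `print(...)` / `#>` pairs. A test module wired up that way typically looks roughly like the sketch below; this is an illustration of the usual pattern, not a reconstruction of the deleted file:

```python
# Sketch of typical pytest-examples wiring; illustrative only,
# not the contents of the deleted tests/test_examples_readme.py.
import pytest
from pytest_examples import find_examples, CodeExample, EvalExample

# collect every fenced code block in the README as a parametrized test case
@pytest.mark.parametrize('example', find_examples('README.md'), ids=str)
def test_readme_examples(example: CodeExample, eval_example: EvalExample):
    # run the block and check its `#>` lines against the printed output
    eval_example.run_print_check(example)
```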
