File tree Expand file tree Collapse file tree 2 files changed +13
-2
lines changed Expand file tree Collapse file tree 2 files changed +13
-2
lines changed Original file line number Diff line number Diff line change 11[project ]
22name = " vector-quantize-pytorch"
3- version = " 1.20.1 "
3+ version = " 1.20.2 "
44description = " Vector Quantization - Pytorch"
55authors = [
66 { name = " Phil Wang" , email = " lucidrains@gmail.com" }
Original file line number Diff line number Diff line change 1+ from __future__ import annotations
12from typing import Callable
23
34import torch
@@ -38,6 +39,7 @@ def __init__(
3839 self ,
3940 dim ,
4041 codebook_size ,
42+ codebook_transform : Module | None = None ,
4143 init_fn : Callable = identity ,
4244 accept_image_fmap = False ,
4345 rotation_trick = True , # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
@@ -51,7 +53,11 @@ def __init__(
5153
5254 # the codebook is actually implicit from a linear layer from frozen gaussian or uniform
5355
54- self .codebook_to_codes = nn .Linear (dim , dim , bias = False )
56+ if not exists (codebook_transform ):
57+ codebook_transform = nn .Linear (dim , dim , bias = False )
58+
59+ self .codebook_to_codes = codebook_transform
60+
5561 self .register_buffer ('codebook' , codebook )
5662
5763
@@ -114,6 +120,11 @@ def forward(
114120
115121 sim_vq = SimVQ (
116122 dim = 512 ,
123+ codebook_transform = nn .Sequential (
124+ nn .Linear (512 , 1024 ),
125+ nn .ReLU (),
126+ nn .Linear (1024 , 512 )
127+ ),
117128 codebook_size = 1024 ,
118129 accept_image_fmap = True
119130 )
You can’t perform that action at this time.
0 commit comments