File tree Expand file tree Collapse file tree 1 file changed +10
-2
lines changed Expand file tree Collapse file tree 1 file changed +10
-2
lines changed Original file line number Diff line number Diff line change 1414train_iter = 10000
1515num_codes = 256
1616seed = 1234
17- rotation_trick = True
17+
18+ rotation_trick = True # rotation trick instead ot straight-through
19+ use_mlp = True # use a one layer mlp with relu instead of linear
20+
1821device = "cuda" if torch .cuda .is_available () else "cpu"
1922
2023def SimVQAutoEncoder (** vq_kwargs ):
@@ -77,7 +80,12 @@ def iterate_dataset(data_loader):
7780
7881model = SimVQAutoEncoder (
7982 codebook_size = num_codes ,
80- rotation_trick = rotation_trick
83+ rotation_trick = rotation_trick ,
84+ codebook_transform = nn .Sequential (
85+ nn .Linear (32 , 128 ),
86+ nn .ReLU (),
87+ nn .Linear (128 , 32 ),
88+ ) if use_mlp else None
8189).to (device )
8290
8391opt = torch .optim .AdamW (model .parameters (), lr = lr )
You can’t perform that action at this time.
0 commit comments