@@ -58,20 +58,20 @@ def __init__(
         quantize_dropout = False,
         quantize_dropout_cutoff_index = 0,
         quantize_dropout_multiple_of = 1,
-        accept_image_fmap = False,
+        channel_first = False,
         rotation_trick = True, # rotation trick from @cfifty, on top of sim vq
         **sim_vq_kwargs
     ):
         super().__init__()
         assert heads == 1, 'residual vq is not compatible with multi-headed codes'
 
-        self.accept_image_fmap = accept_image_fmap
+        self.channel_first = channel_first
 
         self.num_quantizers = num_quantizers
 
         # define sim vq across layers
 
-        self.layers = ModuleList([SimVQ(dim = dim, codebook_size = codebook_size, rotation_trick = rotation_trick, accept_image_fmap = accept_image_fmap, **sim_vq_kwargs) for _ in range(num_quantizers)])
+        self.layers = ModuleList([SimVQ(dim = dim, codebook_size = codebook_size, rotation_trick = rotation_trick, channel_first = channel_first, **sim_vq_kwargs) for _ in range(num_quantizers)])
 
 
         # quantize dropout
@@ -100,7 +100,7 @@ def get_codes_from_indices(self, indices):
 
         batch, quantize_dim = indices.shape[0], indices.shape[-1]
 
-        # may also receive indices in the shape of 'b h w q' (accept_image_fmap)
+        # may also receive indices in the shape of 'b h w q' (images)
 
         indices, inverse = pack_one(indices, 'b * q')
 
@@ -122,11 +122,11 @@ def get_codes_from_indices(self, indices):
 
         all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.)
 
-        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
+        # if (channel_first = True) then return shape (quantize, batch, height, width, dimension)
 
         all_codes = inverse(all_codes, 'q b * d')
 
-        if self.accept_image_fmap:
+        if self.channel_first:
             all_codes = rearrange(all_codes, 'q b ... d -> q b d ...')
 
         return all_codes
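A quick shape note for the hunk above (the sizes below are hypothetical, chosen only to illustrate the comment in the diff; `get_codes_from_indices` itself lives in the surrounding file):

```python
# Hypothetical sizes illustrating the returned shapes.
# With channel_first = True, indices arrive as 'b h w q':
#   indices.shape == (1, 32, 32, 4)      # batch, height, width, num quantizers
# and all_codes comes back channel-first per quantizer, 'q b d h w':
#   all_codes.shape == (4, 1, 512, 32, 32)
# With channel_first = False, the packed 'q b * d' shape is kept, i.e. 'q b n d'.
```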
@@ -139,23 +139,17 @@ def get_output_from_indices(self, indices):
     def forward(
         self,
         x,
-        indices: Tensor | list[Tensor] | None = None,
         return_all_codes = False,
         rand_quantize_dropout_fixed_seed = None
     ):
-        num_quant, quant_dropout_multiple_of, return_loss, device = self.num_quantizers, self.quantize_dropout_multiple_of, exists(indices), x.device
-
-        assert not (self.accept_image_fmap and exists(indices))
+        num_quant, quant_dropout_multiple_of, device = self.num_quantizers, self.quantize_dropout_multiple_of, x.device
 
         quantized_out = 0.
         residual = x
 
         all_losses = []
         all_indices = []
 
-        if isinstance(indices, list):
-            indices = torch.stack(indices)
-
-        should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
+        should_quantize_dropout = self.training and self.quantize_dropout
 
         # sample a layer index at which to dropout further residual quantization
@@ -175,7 +169,7 @@ def forward(
             if quant_dropout_multiple_of != 1:
                 rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1
 
-            null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
+            null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.channel_first else tuple(x.shape[:2])
             null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long)
             null_loss = torch.full((1,), 0., device = device, dtype = x.dtype)
 
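For context, a minimal usage sketch of the renamed flag. This assumes `ResidualSimVQ` is exported from `vector_quantize_pytorch` as in the repository README; the tensor sizes are illustrative:

```python
import torch
from vector_quantize_pytorch import ResidualSimVQ  # assumed import path

# channel_first = True replaces the old accept_image_fmap flag:
# inputs are taken as (batch, dim, height, width)
residual_sim_vq = ResidualSimVQ(
    dim = 512,
    num_quantizers = 4,
    codebook_size = 1024,
    channel_first = True
)

images = torch.randn(1, 512, 32, 32)

# after this commit, forward no longer accepts an `indices` argument
quantized, indices, commit_loss = residual_sim_vq(images)

assert quantized.shape == images.shape   # (1, 512, 32, 32)
assert indices.shape == (1, 32, 32, 4)   # 'b h w q', one code id per quantizer
```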