|
| 1 | +from typing import Callable |
| 2 | + |
| 3 | +import torch |
| 4 | +from torch import nn |
| 5 | +from torch.nn import Module |
| 6 | +import torch.nn.functional as F |
| 7 | + |
| 8 | +from einx import get_at |
| 9 | +from einops import rearrange, pack, unpack |
| 10 | + |
| 11 | +from vector_quantize_pytorch.vector_quantize_pytorch import rotate_from_to |
| 12 | + |
| 13 | +# helper functions |
| 14 | + |
| 15 | +def exists(v): |
| 16 | + return v is not None |
| 17 | + |
| 18 | +def identity(t): |
| 19 | + return t |
| 20 | + |
| 21 | +def default(v, d): |
| 22 | + return v if exists(v) else d |
| 23 | + |
| 24 | +def pack_one(t, pattern): |
| 25 | + packed, packed_shape = pack([t], pattern) |
| 26 | + |
| 27 | + def inverse(out, inv_pattern = None): |
| 28 | + inv_pattern = default(inv_pattern, pattern) |
| 29 | + out, = unpack(out, packed_shape, inv_pattern) |
| 30 | + return out |
| 31 | + |
| 32 | + return packed, inverse |
| 33 | + |
| 34 | +# class |
| 35 | + |
| 36 | +class SimVQ(Module): |
| 37 | + def __init__( |
| 38 | + self, |
| 39 | + dim, |
| 40 | + codebook_size, |
| 41 | + init_fn: Callable = identity, |
| 42 | + accept_image_fmap = False, |
| 43 | + rotation_trick = True, # works even better with rotation trick turned on, with no asymmetric commit loss or straight through |
| 44 | + commit_loss_input_to_quantize_weight = 0.25, |
| 45 | + ): |
| 46 | + super().__init__() |
| 47 | + self.accept_image_fmap = accept_image_fmap |
| 48 | + |
| 49 | + codebook = torch.randn(codebook_size, dim) * (dim ** -0.5) |
| 50 | + codebook = init_fn(codebook) |
| 51 | + |
| 52 | + # the codebook is actually implicit from a linear layer from frozen gaussian or uniform |
| 53 | + |
| 54 | + self.codebook_to_codes = nn.Linear(dim, dim, bias = False) |
| 55 | + self.register_buffer('codebook', codebook) |
| 56 | + |
| 57 | + |
| 58 | + # whether to use rotation trick from Fifty et al. |
| 59 | + # https://arxiv.org/abs/2410.06424 |
| 60 | + |
| 61 | + self.rotation_trick = rotation_trick |
| 62 | + self.register_buffer('zero', torch.tensor(0.), persistent = False) |
| 63 | + |
| 64 | + # commit loss weighting - weighing input to quantize a bit less is crucial for it to work |
| 65 | + |
| 66 | + self.commit_loss_input_to_quantize_weight = commit_loss_input_to_quantize_weight |
| 67 | + |
| 68 | + def forward( |
| 69 | + self, |
| 70 | + x |
| 71 | + ): |
| 72 | + if self.accept_image_fmap: |
| 73 | + x = rearrange(x, 'b d h w -> b h w d') |
| 74 | + x, inverse_pack = pack_one(x, 'b * d') |
| 75 | + |
| 76 | + implicit_codebook = self.codebook_to_codes(self.codebook) |
| 77 | + |
| 78 | + with torch.no_grad(): |
| 79 | + dist = torch.cdist(x, implicit_codebook) |
| 80 | + indices = dist.argmin(dim = -1) |
| 81 | + |
| 82 | + # select codes |
| 83 | + |
| 84 | + quantized = get_at('[c] d, b n -> b n d', implicit_codebook, indices) |
| 85 | + |
| 86 | + if self.rotation_trick: |
| 87 | + # rotation trick from @cfifty |
| 88 | + |
| 89 | + quantized = rotate_from_to(quantized, x) |
| 90 | + |
| 91 | + commit_loss = self.zero |
| 92 | + else: |
| 93 | + # commit loss and straight through, as was done in the paper |
| 94 | + |
| 95 | + commit_loss = ( |
| 96 | + F.mse_loss(x, quantized.detach()) * self.commit_loss_input_to_quantize_weight + |
| 97 | + F.mse_loss(x.detach(), quantized) |
| 98 | + ) |
| 99 | + |
| 100 | + quantized = (quantized - x).detach() + x |
| 101 | + |
| 102 | + if self.accept_image_fmap: |
| 103 | + quantized = inverse_pack(quantized) |
| 104 | + quantized = rearrange(quantized, 'b h w d-> b d h w') |
| 105 | + |
| 106 | + indices = inverse_pack(indices, 'b *') |
| 107 | + |
| 108 | + return quantized, indices, commit_loss |
| 109 | + |
| 110 | +# main |
| 111 | + |
| 112 | +if __name__ == '__main__': |
| 113 | + |
| 114 | + x = torch.randn(1, 512, 32, 32) |
| 115 | + |
| 116 | + sim_vq = SimVQ( |
| 117 | + dim = 512, |
| 118 | + codebook_size = 1024, |
| 119 | + accept_image_fmap = True |
| 120 | + ) |
| 121 | + |
| 122 | + quantized, indices, commit_loss = sim_vq(x) |
| 123 | + |
| 124 | + assert x.shape == quantized.shape |
0 commit comments