|
| 1 | +import torch |
| 2 | +import torch.nn.functional as F |
| 3 | +from torch.nn.utils.rnn import pad_sequence |
| 4 | +from torch import nn, einsum |
| 5 | + |
| 6 | +from einops import rearrange, repeat |
| 7 | +from einops.layers.torch import Rearrange |
| 8 | + |
| 9 | +# helpers |
| 10 | + |
| 11 | +def exists(val): |
| 12 | + return val is not None |
| 13 | + |
| 14 | +def pair(t): |
| 15 | + return t if isinstance(t, tuple) else (t, t) |
| 16 | + |
| 17 | +# adaptive token sampling functions and classes |
| 18 | + |
| 19 | +def log(t, eps = 1e-6): |
| 20 | + return torch.log(t + eps) |
| 21 | + |
| 22 | +def sample_gumbel(shape, device, dtype, eps = 1e-6): |
| 23 | + u = torch.empty(shape, device = device, dtype = dtype).uniform_(0, 1) |
| 24 | + return -log(-log(u, eps), eps) |
| 25 | + |
| 26 | +def batched_index_select(values, indices, dim = 1): |
| 27 | + value_dims = values.shape[(dim + 1):] |
| 28 | + values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices)) |
| 29 | + indices = indices[(..., *((None,) * len(value_dims)))] |
| 30 | + indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims) |
| 31 | + value_expand_len = len(indices_shape) - (dim + 1) |
| 32 | + values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)] |
| 33 | + |
| 34 | + value_expand_shape = [-1] * len(values.shape) |
| 35 | + expand_slice = slice(dim, (dim + value_expand_len)) |
| 36 | + value_expand_shape[expand_slice] = indices.shape[expand_slice] |
| 37 | + values = values.expand(*value_expand_shape) |
| 38 | + |
| 39 | + dim += value_expand_len |
| 40 | + return values.gather(dim, indices) |
| 41 | + |
| 42 | +class AdaptiveTokenSampling(nn.Module): |
| 43 | + def __init__(self, output_num_tokens, eps = 1e-6): |
| 44 | + super().__init__() |
| 45 | + self.eps = eps |
| 46 | + self.output_num_tokens = output_num_tokens |
| 47 | + |
| 48 | + def forward(self, attn, value, mask): |
| 49 | + heads, output_num_tokens, eps, device, dtype = attn.shape[1], self.output_num_tokens, self.eps, attn.device, attn.dtype |
| 50 | + |
| 51 | + # first get the attention values for CLS token to all other tokens |
| 52 | + |
| 53 | + cls_attn = attn[..., 0, 1:] |
| 54 | + |
| 55 | + # calculate the norms of the values, for weighting the scores, as described in the paper |
| 56 | + |
| 57 | + value_norms = value[..., 1:, :].norm(dim = -1) |
| 58 | + |
| 59 | + # weigh the attention scores by the norm of the values, sum across all heads |
| 60 | + |
| 61 | + cls_attn = einsum('b h n, b h n -> b n', cls_attn, value_norms) |
| 62 | + |
| 63 | + # normalize to 1 |
| 64 | + |
| 65 | + normed_cls_attn = cls_attn / (cls_attn.sum(dim = -1, keepdim = True) + eps) |
| 66 | + |
| 67 | + # instead of using inverse transform sampling, going to invert the softmax and use gumbel-max sampling instead |
| 68 | + |
| 69 | + pseudo_logits = log(normed_cls_attn) |
| 70 | + |
| 71 | + # mask out pseudo logits for gumbel-max sampling |
| 72 | + |
| 73 | + mask_without_cls = mask[:, 1:] |
| 74 | + mask_value = -torch.finfo(attn.dtype).max / 2 |
| 75 | + pseudo_logits = pseudo_logits.masked_fill(~mask_without_cls, mask_value) |
| 76 | + |
| 77 | + # expand k times, k being the adaptive sampling number |
| 78 | + |
| 79 | + pseudo_logits = repeat(pseudo_logits, 'b n -> b k n', k = output_num_tokens) |
| 80 | + pseudo_logits = pseudo_logits + sample_gumbel(pseudo_logits.shape, device = device, dtype = dtype) |
| 81 | + |
| 82 | + # gumble-max and add one to reserve 0 for padding / mask |
| 83 | + |
| 84 | + sampled_token_ids = pseudo_logits.argmax(dim = -1) + 1 |
| 85 | + |
| 86 | + # calculate unique using torch.unique and then pad the sequence from the right |
| 87 | + |
| 88 | + unique_sampled_token_ids_list = [torch.unique(t, sorted = True) for t in torch.unbind(sampled_token_ids)] |
| 89 | + unique_sampled_token_ids = pad_sequence(unique_sampled_token_ids_list, batch_first = True) |
| 90 | + |
| 91 | + # calculate the new mask, based on the padding |
| 92 | + |
| 93 | + new_mask = unique_sampled_token_ids != 0 |
| 94 | + |
| 95 | + # CLS token never gets masked out (gets a value of True) |
| 96 | + |
| 97 | + new_mask = F.pad(new_mask, (1, 0), value = True) |
| 98 | + |
| 99 | + # prepend a 0 token id to keep the CLS attention scores |
| 100 | + |
| 101 | + unique_sampled_token_ids = F.pad(unique_sampled_token_ids, (1, 0), value = 0) |
| 102 | + expanded_unique_sampled_token_ids = repeat(unique_sampled_token_ids, 'b n -> b h n', h = heads) |
| 103 | + |
| 104 | + # gather the new attention scores |
| 105 | + |
| 106 | + new_attn = batched_index_select(attn, expanded_unique_sampled_token_ids, dim = 2) |
| 107 | + |
| 108 | + # return the sampled attention scores, new mask (denoting padding), as well as the sampled token indices (for the residual) |
| 109 | + return new_attn, new_mask, unique_sampled_token_ids |
| 110 | + |
| 111 | +# classes |
| 112 | + |
| 113 | +class PreNorm(nn.Module): |
| 114 | + def __init__(self, dim, fn): |
| 115 | + super().__init__() |
| 116 | + self.norm = nn.LayerNorm(dim) |
| 117 | + self.fn = fn |
| 118 | + def forward(self, x, **kwargs): |
| 119 | + return self.fn(self.norm(x), **kwargs) |
| 120 | + |
| 121 | +class FeedForward(nn.Module): |
| 122 | + def __init__(self, dim, hidden_dim, dropout = 0.): |
| 123 | + super().__init__() |
| 124 | + self.net = nn.Sequential( |
| 125 | + nn.Linear(dim, hidden_dim), |
| 126 | + nn.GELU(), |
| 127 | + nn.Dropout(dropout), |
| 128 | + nn.Linear(hidden_dim, dim), |
| 129 | + nn.Dropout(dropout) |
| 130 | + ) |
| 131 | + def forward(self, x): |
| 132 | + return self.net(x) |
| 133 | + |
| 134 | +class Attention(nn.Module): |
| 135 | + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., output_num_tokens = None): |
| 136 | + super().__init__() |
| 137 | + inner_dim = dim_head * heads |
| 138 | + self.heads = heads |
| 139 | + self.scale = dim_head ** -0.5 |
| 140 | + |
| 141 | + self.attend = nn.Softmax(dim = -1) |
| 142 | + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) |
| 143 | + |
| 144 | + self.output_num_tokens = output_num_tokens |
| 145 | + self.ats = AdaptiveTokenSampling(output_num_tokens) if exists(output_num_tokens) else None |
| 146 | + |
| 147 | + self.to_out = nn.Sequential( |
| 148 | + nn.Linear(inner_dim, dim), |
| 149 | + nn.Dropout(dropout) |
| 150 | + ) |
| 151 | + |
| 152 | + def forward(self, x, *, mask): |
| 153 | + num_tokens = x.shape[1] |
| 154 | + |
| 155 | + qkv = self.to_qkv(x).chunk(3, dim = -1) |
| 156 | + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) |
| 157 | + |
| 158 | + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale |
| 159 | + |
| 160 | + if exists(mask): |
| 161 | + dots_mask = rearrange(mask, 'b i -> b 1 i 1') * rearrange(mask, 'b j -> b 1 1 j') |
| 162 | + mask_value = -torch.finfo(dots.dtype).max |
| 163 | + dots = dots.masked_fill(~dots_mask, mask_value) |
| 164 | + |
| 165 | + attn = self.attend(dots) |
| 166 | + |
| 167 | + sampled_token_ids = None |
| 168 | + |
| 169 | + # if adaptive token sampling is enabled |
| 170 | + # and number of tokens is greater than the number of output tokens |
| 171 | + if exists(self.output_num_tokens) and (num_tokens - 1) > self.output_num_tokens: |
| 172 | + attn, mask, sampled_token_ids = self.ats(attn, v, mask = mask) |
| 173 | + |
| 174 | + out = torch.matmul(attn, v) |
| 175 | + out = rearrange(out, 'b h n d -> b n (h d)') |
| 176 | + |
| 177 | + return self.to_out(out), mask, sampled_token_ids |
| 178 | + |
| 179 | +class Transformer(nn.Module): |
| 180 | + def __init__(self, dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout = 0.): |
| 181 | + super().__init__() |
| 182 | + assert len(max_tokens_per_depth) == depth, 'max_tokens_per_depth must be a tuple of length that is equal to the depth of the transformer' |
| 183 | + assert sorted(max_tokens_per_depth, reverse = True) == list(max_tokens_per_depth), 'max_tokens_per_depth must be in decreasing order' |
| 184 | + assert min(max_tokens_per_depth) > 0, 'max_tokens_per_depth must have at least 1 token at any layer' |
| 185 | + |
| 186 | + self.layers = nn.ModuleList([]) |
| 187 | + for _, output_num_tokens in zip(range(depth), max_tokens_per_depth): |
| 188 | + self.layers.append(nn.ModuleList([ |
| 189 | + PreNorm(dim, Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout)), |
| 190 | + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) |
| 191 | + ])) |
| 192 | + |
| 193 | + def forward(self, x): |
| 194 | + b, n, device = *x.shape[:2], x.device |
| 195 | + |
| 196 | + # use mask to keep track of the paddings when sampling tokens |
| 197 | + # as the duplicates (when sampling) are just removed, as mentioned in the paper |
| 198 | + mask = torch.ones((b, n), device = device, dtype = torch.bool) |
| 199 | + |
| 200 | + token_ids = torch.arange(n, device = device) |
| 201 | + token_ids = repeat(token_ids, 'n -> b n', b = b) |
| 202 | + |
| 203 | + for attn, ff in self.layers: |
| 204 | + attn_out, mask, sampled_token_ids = attn(x, mask = mask) |
| 205 | + |
| 206 | + # when token sampling, one needs to then gather the residual tokens with the sampled token ids |
| 207 | + if exists(sampled_token_ids): |
| 208 | + x = batched_index_select(x, sampled_token_ids, dim = 1) |
| 209 | + token_ids = batched_index_select(token_ids, sampled_token_ids, dim = 1) |
| 210 | + |
| 211 | + x = x + attn_out |
| 212 | + |
| 213 | + x = ff(x) + x |
| 214 | + |
| 215 | + return x, token_ids |
| 216 | + |
| 217 | +class ViT(nn.Module): |
| 218 | + def __init__(self, *, image_size, patch_size, num_classes, dim, depth, max_tokens_per_depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): |
| 219 | + super().__init__() |
| 220 | + image_height, image_width = pair(image_size) |
| 221 | + patch_height, patch_width = pair(patch_size) |
| 222 | + |
| 223 | + assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' |
| 224 | + |
| 225 | + num_patches = (image_height // patch_height) * (image_width // patch_width) |
| 226 | + patch_dim = channels * patch_height * patch_width |
| 227 | + |
| 228 | + self.to_patch_embedding = nn.Sequential( |
| 229 | + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), |
| 230 | + nn.Linear(patch_dim, dim), |
| 231 | + ) |
| 232 | + |
| 233 | + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) |
| 234 | + self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) |
| 235 | + self.dropout = nn.Dropout(emb_dropout) |
| 236 | + |
| 237 | + self.transformer = Transformer(dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout) |
| 238 | + |
| 239 | + self.mlp_head = nn.Sequential( |
| 240 | + nn.LayerNorm(dim), |
| 241 | + nn.Linear(dim, num_classes) |
| 242 | + ) |
| 243 | + |
| 244 | + def forward(self, img, return_sampled_token_ids = False): |
| 245 | + x = self.to_patch_embedding(img) |
| 246 | + b, n, _ = x.shape |
| 247 | + |
| 248 | + cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) |
| 249 | + x = torch.cat((cls_tokens, x), dim=1) |
| 250 | + x += self.pos_embedding[:, :(n + 1)] |
| 251 | + x = self.dropout(x) |
| 252 | + |
| 253 | + x, token_ids = self.transformer(x) |
| 254 | + |
| 255 | + logits = self.mlp_head(x[:, 0]) |
| 256 | + |
| 257 | + if return_sampled_token_ids: |
| 258 | + # remove CLS token and decrement by 1 to make -1 the padding |
| 259 | + token_ids = token_ids[:, 1:] - 1 |
| 260 | + return logits, token_ids |
| 261 | + |
| 262 | + return logits |
0 commit comments