Skip to content

Commit 365b4d9

Browse files
committed
add adaptive token sampling paper
1 parent 79c864d commit 365b4d9

File tree

4 files changed

+307
-1
lines changed

4 files changed

+307
-1
lines changed

README.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,39 @@ for _ in range(100):
679679
torch.save(model.state_dict(), './pretrained-net.pt')
680680
```
681681

682+
## Adaptive Token Sampling
683+
684+
<img src="./images/ats.png" width="400px"></img>
685+
686+
This <a href="https://arxiv.org/abs/2111.15667">paper</a> proposes to use the CLS attention scores, re-weighed by the norms of the value heads, as means to discard unimportant tokens at different layers.
687+
688+
```python
689+
import torch
690+
from vit_pytorch.ats_vit import ViT
691+
692+
v = ViT(
693+
image_size = 256,
694+
patch_size = 16,
695+
num_classes = 1000,
696+
dim = 1024,
697+
depth = 6,
698+
max_tokens_per_depth = (256, 128, 64, 32, 16, 8), # a tuple that denotes the maximum number of tokens that any given layer should have. if the layer has greater than this amount, it will undergo adaptive token sampling
699+
heads = 16,
700+
mlp_dim = 2048,
701+
dropout = 0.1,
702+
emb_dropout = 0.1
703+
)
704+
705+
img = torch.randn(4, 3, 256, 256)
706+
707+
preds = v(img) # (1, 1000)
708+
709+
# you can also get a list of the final sampled patch ids
710+
# a value of -1 denotes padding
711+
712+
preds, token_ids = v(img, return_sampled_token_ids = True) # (1, 1000), (1, <=8)
713+
```
714+
682715
## Dino
683716

684717
<img src="./images/dino.png" width="350px"></img>
@@ -1119,6 +1152,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
11191152
}
11201153
```
11211154

1155+
```bibtex
1156+
@misc{fayyaz2021ats,
1157+
title = {ATS: Adaptive Token Sampling For Efficient Vision Transformers},
1158+
author = {Mohsen Fayyaz and Soroush Abbasi Kouhpayegani and Farnoush Rezaei Jafari and Eric Sommerlade and Hamid Reza Vaezi Joze and Hamed Pirsiavash and Juergen Gall},
1159+
year = {2021},
1160+
eprint = {2111.15667},
1161+
archivePrefix = {arXiv},
1162+
primaryClass = {cs.CV}
1163+
}
1164+
```
1165+
11221166
```bibtex
11231167
@misc{vaswani2017attention,
11241168
title = {Attention Is All You Need},

images/ats.png

198 KB
Loading

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vit-pytorch',
55
packages = find_packages(exclude=['examples']),
6-
version = '0.24.2',
6+
version = '0.24.3',
77
license='MIT',
88
description = 'Vision Transformer (ViT) - Pytorch',
99
author = 'Phil Wang',

vit_pytorch/ats_vit.py

Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
import torch
2+
import torch.nn.functional as F
3+
from torch.nn.utils.rnn import pad_sequence
4+
from torch import nn, einsum
5+
6+
from einops import rearrange, repeat
7+
from einops.layers.torch import Rearrange
8+
9+
# helpers
10+
11+
def exists(val):
12+
return val is not None
13+
14+
def pair(t):
15+
return t if isinstance(t, tuple) else (t, t)
16+
17+
# adaptive token sampling functions and classes
18+
19+
def log(t, eps = 1e-6):
20+
return torch.log(t + eps)
21+
22+
def sample_gumbel(shape, device, dtype, eps = 1e-6):
23+
u = torch.empty(shape, device = device, dtype = dtype).uniform_(0, 1)
24+
return -log(-log(u, eps), eps)
25+
26+
def batched_index_select(values, indices, dim = 1):
27+
value_dims = values.shape[(dim + 1):]
28+
values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
29+
indices = indices[(..., *((None,) * len(value_dims)))]
30+
indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
31+
value_expand_len = len(indices_shape) - (dim + 1)
32+
values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
33+
34+
value_expand_shape = [-1] * len(values.shape)
35+
expand_slice = slice(dim, (dim + value_expand_len))
36+
value_expand_shape[expand_slice] = indices.shape[expand_slice]
37+
values = values.expand(*value_expand_shape)
38+
39+
dim += value_expand_len
40+
return values.gather(dim, indices)
41+
42+
class AdaptiveTokenSampling(nn.Module):
43+
def __init__(self, output_num_tokens, eps = 1e-6):
44+
super().__init__()
45+
self.eps = eps
46+
self.output_num_tokens = output_num_tokens
47+
48+
def forward(self, attn, value, mask):
49+
heads, output_num_tokens, eps, device, dtype = attn.shape[1], self.output_num_tokens, self.eps, attn.device, attn.dtype
50+
51+
# first get the attention values for CLS token to all other tokens
52+
53+
cls_attn = attn[..., 0, 1:]
54+
55+
# calculate the norms of the values, for weighting the scores, as described in the paper
56+
57+
value_norms = value[..., 1:, :].norm(dim = -1)
58+
59+
# weigh the attention scores by the norm of the values, sum across all heads
60+
61+
cls_attn = einsum('b h n, b h n -> b n', cls_attn, value_norms)
62+
63+
# normalize to 1
64+
65+
normed_cls_attn = cls_attn / (cls_attn.sum(dim = -1, keepdim = True) + eps)
66+
67+
# instead of using inverse transform sampling, going to invert the softmax and use gumbel-max sampling instead
68+
69+
pseudo_logits = log(normed_cls_attn)
70+
71+
# mask out pseudo logits for gumbel-max sampling
72+
73+
mask_without_cls = mask[:, 1:]
74+
mask_value = -torch.finfo(attn.dtype).max / 2
75+
pseudo_logits = pseudo_logits.masked_fill(~mask_without_cls, mask_value)
76+
77+
# expand k times, k being the adaptive sampling number
78+
79+
pseudo_logits = repeat(pseudo_logits, 'b n -> b k n', k = output_num_tokens)
80+
pseudo_logits = pseudo_logits + sample_gumbel(pseudo_logits.shape, device = device, dtype = dtype)
81+
82+
# gumble-max and add one to reserve 0 for padding / mask
83+
84+
sampled_token_ids = pseudo_logits.argmax(dim = -1) + 1
85+
86+
# calculate unique using torch.unique and then pad the sequence from the right
87+
88+
unique_sampled_token_ids_list = [torch.unique(t, sorted = True) for t in torch.unbind(sampled_token_ids)]
89+
unique_sampled_token_ids = pad_sequence(unique_sampled_token_ids_list, batch_first = True)
90+
91+
# calculate the new mask, based on the padding
92+
93+
new_mask = unique_sampled_token_ids != 0
94+
95+
# CLS token never gets masked out (gets a value of True)
96+
97+
new_mask = F.pad(new_mask, (1, 0), value = True)
98+
99+
# prepend a 0 token id to keep the CLS attention scores
100+
101+
unique_sampled_token_ids = F.pad(unique_sampled_token_ids, (1, 0), value = 0)
102+
expanded_unique_sampled_token_ids = repeat(unique_sampled_token_ids, 'b n -> b h n', h = heads)
103+
104+
# gather the new attention scores
105+
106+
new_attn = batched_index_select(attn, expanded_unique_sampled_token_ids, dim = 2)
107+
108+
# return the sampled attention scores, new mask (denoting padding), as well as the sampled token indices (for the residual)
109+
return new_attn, new_mask, unique_sampled_token_ids
110+
111+
# classes
112+
113+
class PreNorm(nn.Module):
114+
def __init__(self, dim, fn):
115+
super().__init__()
116+
self.norm = nn.LayerNorm(dim)
117+
self.fn = fn
118+
def forward(self, x, **kwargs):
119+
return self.fn(self.norm(x), **kwargs)
120+
121+
class FeedForward(nn.Module):
122+
def __init__(self, dim, hidden_dim, dropout = 0.):
123+
super().__init__()
124+
self.net = nn.Sequential(
125+
nn.Linear(dim, hidden_dim),
126+
nn.GELU(),
127+
nn.Dropout(dropout),
128+
nn.Linear(hidden_dim, dim),
129+
nn.Dropout(dropout)
130+
)
131+
def forward(self, x):
132+
return self.net(x)
133+
134+
class Attention(nn.Module):
135+
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., output_num_tokens = None):
136+
super().__init__()
137+
inner_dim = dim_head * heads
138+
self.heads = heads
139+
self.scale = dim_head ** -0.5
140+
141+
self.attend = nn.Softmax(dim = -1)
142+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
143+
144+
self.output_num_tokens = output_num_tokens
145+
self.ats = AdaptiveTokenSampling(output_num_tokens) if exists(output_num_tokens) else None
146+
147+
self.to_out = nn.Sequential(
148+
nn.Linear(inner_dim, dim),
149+
nn.Dropout(dropout)
150+
)
151+
152+
def forward(self, x, *, mask):
153+
num_tokens = x.shape[1]
154+
155+
qkv = self.to_qkv(x).chunk(3, dim = -1)
156+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
157+
158+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
159+
160+
if exists(mask):
161+
dots_mask = rearrange(mask, 'b i -> b 1 i 1') * rearrange(mask, 'b j -> b 1 1 j')
162+
mask_value = -torch.finfo(dots.dtype).max
163+
dots = dots.masked_fill(~dots_mask, mask_value)
164+
165+
attn = self.attend(dots)
166+
167+
sampled_token_ids = None
168+
169+
# if adaptive token sampling is enabled
170+
# and number of tokens is greater than the number of output tokens
171+
if exists(self.output_num_tokens) and (num_tokens - 1) > self.output_num_tokens:
172+
attn, mask, sampled_token_ids = self.ats(attn, v, mask = mask)
173+
174+
out = torch.matmul(attn, v)
175+
out = rearrange(out, 'b h n d -> b n (h d)')
176+
177+
return self.to_out(out), mask, sampled_token_ids
178+
179+
class Transformer(nn.Module):
180+
def __init__(self, dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout = 0.):
181+
super().__init__()
182+
assert len(max_tokens_per_depth) == depth, 'max_tokens_per_depth must be a tuple of length that is equal to the depth of the transformer'
183+
assert sorted(max_tokens_per_depth, reverse = True) == list(max_tokens_per_depth), 'max_tokens_per_depth must be in decreasing order'
184+
assert min(max_tokens_per_depth) > 0, 'max_tokens_per_depth must have at least 1 token at any layer'
185+
186+
self.layers = nn.ModuleList([])
187+
for _, output_num_tokens in zip(range(depth), max_tokens_per_depth):
188+
self.layers.append(nn.ModuleList([
189+
PreNorm(dim, Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout)),
190+
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
191+
]))
192+
193+
def forward(self, x):
194+
b, n, device = *x.shape[:2], x.device
195+
196+
# use mask to keep track of the paddings when sampling tokens
197+
# as the duplicates (when sampling) are just removed, as mentioned in the paper
198+
mask = torch.ones((b, n), device = device, dtype = torch.bool)
199+
200+
token_ids = torch.arange(n, device = device)
201+
token_ids = repeat(token_ids, 'n -> b n', b = b)
202+
203+
for attn, ff in self.layers:
204+
attn_out, mask, sampled_token_ids = attn(x, mask = mask)
205+
206+
# when token sampling, one needs to then gather the residual tokens with the sampled token ids
207+
if exists(sampled_token_ids):
208+
x = batched_index_select(x, sampled_token_ids, dim = 1)
209+
token_ids = batched_index_select(token_ids, sampled_token_ids, dim = 1)
210+
211+
x = x + attn_out
212+
213+
x = ff(x) + x
214+
215+
return x, token_ids
216+
217+
class ViT(nn.Module):
218+
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, max_tokens_per_depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
219+
super().__init__()
220+
image_height, image_width = pair(image_size)
221+
patch_height, patch_width = pair(patch_size)
222+
223+
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
224+
225+
num_patches = (image_height // patch_height) * (image_width // patch_width)
226+
patch_dim = channels * patch_height * patch_width
227+
228+
self.to_patch_embedding = nn.Sequential(
229+
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
230+
nn.Linear(patch_dim, dim),
231+
)
232+
233+
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
234+
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
235+
self.dropout = nn.Dropout(emb_dropout)
236+
237+
self.transformer = Transformer(dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout)
238+
239+
self.mlp_head = nn.Sequential(
240+
nn.LayerNorm(dim),
241+
nn.Linear(dim, num_classes)
242+
)
243+
244+
def forward(self, img, return_sampled_token_ids = False):
245+
x = self.to_patch_embedding(img)
246+
b, n, _ = x.shape
247+
248+
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
249+
x = torch.cat((cls_tokens, x), dim=1)
250+
x += self.pos_embedding[:, :(n + 1)]
251+
x = self.dropout(x)
252+
253+
x, token_ids = self.transformer(x)
254+
255+
logits = self.mlp_head(x[:, 0])
256+
257+
if return_sampled_token_ids:
258+
# remove CLS token and decrement by 1 to make -1 the padding
259+
token_ids = token_ids[:, 1:] - 1
260+
return logits, token_ids
261+
262+
return logits

0 commit comments

Comments
 (0)