Commit 3cf5f5b

add SimMIM
1 parent c5a4616 commit 3cf5f5b

File tree

4 files changed: +139 -1 lines changed

README.md

Lines changed: 52 additions & 0 deletions
@@ -19,6 +19,7 @@
 - [RegionViT](#regionvit)
 - [NesT](#nest)
 - [Masked Autoencoder](#masked-autoencoder)
+- [Simple Masked Image Modeling](#simple-masked-image-modeling)
 - [Masked Patch Prediction](#masked-patch-prediction)
 - [Dino](#dino)
 - [Accessing Attention](#accessing-attention)
@@ -519,6 +520,46 @@ img = torch.randn(1, 3, 224, 224)
pred = nest(img) # (1, 1000)
```

## Simple Masked Image Modeling

<img src="./images/simmim.png" width="400px"/>

This <a href="https://arxiv.org/abs/2111.09886">paper</a> proposes a simple masked image modeling (SimMIM) scheme: only a linear projection from the masked tokens to pixel space, followed by an L1 loss against the pixel values of the masked patches. Results are competitive with other, more complicated approaches.

You can use it as follows:

```python
import torch
from vit_pytorch import ViT
from vit_pytorch.simmim import SimMIM

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mim = SimMIM(
    encoder = v,
    masking_ratio = 0.5  # they found 50% to yield the best results
)

images = torch.randn(8, 3, 256, 256)

loss = mim(images)
loss.backward()

# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn

torch.save(v.state_dict(), './trained-vit.pt')
```
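To make the "for loop" above concrete, here is a minimal pretraining sketch. It is not part of this commit, and the optimizer, learning rate, and random-tensor "dataset" are placeholder assumptions; substitute a real image dataloader.

```python
# continues from the snippet above (v and mim already constructed)
import torch
from torch.optim import AdamW

# hypothetical optimizer and hyperparameters, for illustration only
opt = AdamW(mim.parameters(), lr = 3e-4)

for _ in range(100):                       # stand-in for many passes over a real dataset
    images = torch.randn(8, 3, 256, 256)   # stand-in for a real batch of images
    loss = mim(images)

    opt.zero_grad()
    loss.backward()
    opt.step()

# the pretrained encoder can then be saved, fine-tuned, or used directly,
# e.g. preds = v(images)  # (8, 1000)
torch.save(v.state_dict(), './trained-vit.pt')
```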
## Masked Autoencoder

<img src="./images/mae.png" width="400px"/>
@@ -1026,6 +1067,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```

```bibtex
@misc{xie2021simmim,
    title   = {SimMIM: A Simple Framework for Masked Image Modeling},
    author  = {Zhenda Xie and Zheng Zhang and Yue Cao and Yutong Lin and Jianmin Bao and Zhuliang Yao and Qi Dai and Han Hu},
    year    = {2021},
    eprint  = {2111.09886},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```

```bibtex
@misc{vaswani2017attention,
    title   = {Attention Is All You Need},

images/simmim.png

365 KB

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.22.0',
+  version = '0.23.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',

vit_pytorch/simmim.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat

from vit_pytorch.vit import Transformer

class SimMIM(nn.Module):
    def __init__(
        self,
        *,
        encoder,
        masking_ratio = 0.5
    ):
        super().__init__()
        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        self.masking_ratio = masking_ratio

        # extract some hyperparameters and functions from encoder (vision transformer to be trained)

        self.encoder = encoder
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
        self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
        pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]

        # simple linear head

        self.mask_token = nn.Parameter(torch.randn(encoder_dim))
        self.to_pixels = nn.Linear(encoder_dim, pixel_values_per_patch)

    def forward(self, img):
        device = img.device

        # get patches

        patches = self.to_patch(img)
        batch, num_patches, *_ = patches.shape

        # for indexing purposes

        batch_range = torch.arange(batch, device = device)[:, None]

        # get positions (position 0 of the embedding belongs to the cls token, so skip it)

        pos_emb = self.encoder.pos_embedding[:, 1:(num_patches + 1)]

        # patch to encoder tokens and add positions

        tokens = self.patch_to_emb(patches)
        tokens = tokens + pos_emb

        # prepare mask tokens

        mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_patches)
        mask_tokens = mask_tokens + pos_emb

        # calculate the number of patches to be masked, and get the positions (indices) to be masked

        num_masked = int(self.masking_ratio * num_patches)
        masked_indices = torch.rand(batch, num_patches, device = device).topk(k = num_masked, dim = -1).indices
        masked_bool_mask = torch.zeros((batch, num_patches), device = device).scatter_(-1, masked_indices, 1).bool()

        # mask tokens

        tokens = torch.where(masked_bool_mask[..., None], mask_tokens, tokens)

        # attend with vision transformer

        encoded = self.encoder.transformer(tokens)

        # get the masked tokens

        encoded_mask_tokens = encoded[batch_range, masked_indices]

        # small linear projection for predicted pixel values

        pred_pixel_values = self.to_pixels(encoded_mask_tokens)

        # get the masked patches for the final reconstruction loss

        masked_patches = patches[batch_range, masked_indices]

        # calculate reconstruction loss

        recon_loss = F.l1_loss(pred_pixel_values, masked_patches) / num_masked
        return recon_loss
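A small standalone sketch (not from the commit) of the masking mechanics in `forward` above: `torch.rand` plus `topk` draws `num_masked` distinct random positions per sample, `scatter_` turns those indices into a boolean mask, and `torch.where` swaps in the mask token. Sizes here are made up for illustration.

```python
import torch
from einops import repeat

batch, num_patches, dim = 2, 8, 4   # hypothetical tiny sizes
num_masked = 4                      # masking_ratio of 0.5 over 8 patches

tokens = torch.randn(batch, num_patches, dim)   # stand-in for patch embeddings + positions
mask_token = torch.randn(dim)                   # the single learned mask token

# topk over uniform noise yields num_masked distinct random positions per sample
masked_indices = torch.rand(batch, num_patches).topk(k = num_masked, dim = -1).indices

# scatter ones at those positions, then cast to a boolean mask
masked_bool_mask = torch.zeros((batch, num_patches)).scatter_(-1, masked_indices, 1).bool()

# broadcast the mask token to every position, keep it only where masked
mask_tokens = repeat(mask_token, 'd -> b n d', b = batch, n = num_patches)
tokens = torch.where(masked_bool_mask[..., None], mask_tokens, tokens)

print(masked_bool_mask.sum(dim = -1))   # tensor([4, 4]), exactly num_masked per sample
```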
