Skip to content

Commit c2b2db2

Browse files
committed
fix window size of none for scalable vit for rectangular images
1 parent 719048d commit c2b2db2

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vit-pytorch',
55
packages = find_packages(exclude=['examples']),
6-
version = '0.28.1',
6+
version = '0.28.2',
77
license='MIT',
88
description = 'Vision Transformer (ViT) - Pytorch',
99
author = 'Phil Wang',

vit_pytorch/scalable_vit.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ def __init__(
156156
def forward(self, x):
157157
height, width, heads, wsz = *x.shape[-2:], self.heads, self.window_size
158158

159-
wsz = default(wsz, height) # take height as window size if not given
160-
assert (height % wsz) == 0 and (width % wsz) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz})'
159+
wsz_h, wsz_w = default(wsz, height), default(wsz, width)
160+
assert (height % wsz_h) == 0 and (width % wsz_w) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz_h}, {wsz_w})'
161161

162162
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
163163

@@ -167,7 +167,7 @@ def forward(self, x):
167167

168168
# divide into window (and split out heads) for efficient self attention
169169

170-
q, k, v = map(lambda t: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h = heads, w1 = wsz, w2 = wsz), (q, k, v))
170+
q, k, v = map(lambda t: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h = heads, w1 = wsz_h, w2 = wsz_w), (q, k, v))
171171

172172
# similarity
173173

@@ -183,7 +183,7 @@ def forward(self, x):
183183

184184
# reshape the windows back to full feature map (and merge heads)
185185

186-
out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
186+
out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz_h, y = width // wsz_w, w1 = wsz_h, w2 = wsz_w)
187187

188188
# add LIM output
189189

0 commit comments

Comments
 (0)