You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
assert (height%wsz_h) ==0and (width%wsz_w) ==0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz_h}, {wsz_w})'
161
161
162
162
q, k, v=self.to_q(x), self.to_k(x), self.to_v(x)
163
163
@@ -167,7 +167,7 @@ def forward(self, x):
167
167
168
168
# divide into window (and split out heads) for efficient self attention
169
169
170
-
q, k, v=map(lambdat: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h=heads, w1=wsz, w2=wsz), (q, k, v))
170
+
q, k, v=map(lambdat: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h=heads, w1=wsz_h, w2=wsz_w), (q, k, v))
171
171
172
172
# similarity
173
173
@@ -183,7 +183,7 @@ def forward(self, x):
183
183
184
184
# reshape the windows back to full feature map (and merge heads)
185
185
186
-
out=rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x=height//wsz, y=width//wsz, w1=wsz, w2=wsz)
186
+
out=rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x=height//wsz_h, y=width//wsz_w, w1=wsz_h, w2=wsz_w)
0 commit comments