In class PrimaryCaps, the forward function is defined as follows:
def forward(self, x):
    u = [capsule(x) for capsule in self.capsules]
    u = torch.stack(u, dim=1)
    u = u.view(x.size(0), 32 * 6 * 6, -1)
    return self.squash(u)
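After u = torch.stack(u, dim=1), the shape of u is [100, 8, 32, 6, 6]. However, u.view(x.size(0), 32 * 6 * 6, -1) then flattens the wrong dimensions: because the 8 stacked outputs sit in dim 1 rather than in the last dim, each length-8 vector in the result [100, 1152, 8] holds 8 consecutive values from a single convolution's output instead of the 8 convolutions' outputs at one channel and position. I think the correct code should be u = torch.stack(u, dim=-1), which gives shape [100, 32, 6, 6, 8], so the same view groups the 8 outputs for each position together.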
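To make the difference concrete, here is a minimal sketch (not the repository's code). It assumes the shapes quoted above (batch 100, 8 convolutions, 32 channels, a 6x6 grid) and uses constant tensors as stand-ins for the convolution outputs, so that output i is filled with the value i. All variable names are my own.

```python
import torch

# Assumed sizes, taken from the shapes quoted in this issue.
batch, num_units, channels, height, width = 100, 8, 32, 6, 6

# Stand-ins for the 8 convolution outputs: output i is filled with the value i,
# so it is easy to see which output each element came from.
outputs = [
    torch.full((batch, channels, height, width), float(i))
    for i in range(num_units)
]

# Current code: stack on dim=1, then view.
u_dim1 = torch.stack(outputs, dim=1)                         # [100, 8, 32, 6, 6]
v_dim1 = u_dim1.view(batch, channels * height * width, -1)   # [100, 1152, 8]

# Proposed code: stack on the last dim, then view.
u_last = torch.stack(outputs, dim=-1)                        # [100, 32, 6, 6, 8]
v_last = u_last.view(batch, channels * height * width, -1)   # [100, 1152, 8]

# Reference vector: the 8 outputs at channel 0, position (0, 0) of sample 0.
expected = torch.stack([o[0, 0, 0, 0] for o in outputs])

matches_dim1 = torch.equal(v_dim1[0, 0], expected)
matches_last = torch.equal(v_last[0, 0], expected)

print("dim=1  first vector:", v_dim1[0, 0].tolist(), "matches:", matches_dim1)
print("dim=-1 first vector:", v_last[0, 0].tolist(), "matches:", matches_last)
```

Both variants produce a tensor of shape [100, 1152, 8], so the problem does not show up as a shape error. Under these assumptions, only the dim=-1 variant yields a first vector containing one value from each of the 8 outputs.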
Is this a bug, or am I misunderstanding it?