diff --git a/timm/layers/layer_scale.py b/timm/layers/layer_scale.py index 123073bcd1..aee0ae19cc 100644 --- a/timm/layers/layer_scale.py +++ b/timm/layers/layer_scale.py @@ -14,13 +14,14 @@ def __init__( dtype=None, ) -> None: super().__init__() + self.init_values = init_values self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.empty(dim, device=device, dtype=dtype)) + self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) self.reset_parameters() def reset_parameters(self): - torch.nn.init.ones_(self.gamma) + torch.nn.init.constant_(self.gamma, self.init_values) def forward(self, x: torch.Tensor) -> torch.Tensor: return x.mul_(self.gamma) if self.inplace else x * self.gamma @@ -38,13 +39,14 @@ def __init__( dtype=None, ): super().__init__() + self.init_values = init_values self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.empty(dim, device=device, dtype=dtype)) + self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) self.reset_parameters() def reset_parameters(self): - torch.nn.init.ones_(self.gamma) + torch.nn.init.constant_(self.gamma, self.init_values) def forward(self, x): gamma = self.gamma.view(1, -1, 1, 1)