@@ -311,9 +311,9 @@ def test_resnext50_32x4d(self):
311311
312312 def test_linear_inference (self ):
313313 class L (torch .nn .Module ):
314- def __init__ (self , in_f , out_f ):
314+ def __init__ (self , in_f , out_f , bias ):
315315 super (L , self ).__init__ ()
316- self .linear = torch .nn .Linear (in_f , out_f )
316+ self .linear = torch .nn .Linear (in_f , out_f , bias = bias )
317317
318318 def forward (self , x ):
319319 return self .linear (x )
@@ -325,7 +325,7 @@ def forward(self, x):
325325 options = itertools .product ([True , False ], input_shapes )
326326 for bias , x_shape in options :
327327 x = torch .randn (x_shape , dtype = torch .float32 )
328- model = L (in_features , out_features ).float ().eval ()
328+ model = L (in_features , out_features , bias ).float ().eval ()
329329 for dtype in [torch .float32 , torch .bfloat16 ]:
330330 x1 = x .clone ().requires_grad_ ()
331331 x2 = x .clone ().requires_grad_ ()
@@ -355,7 +355,7 @@ def test_linear_training(self):
355355 for out_features , bias , x_shape in options :
356356 in_features = x_shape [- 1 ]
357357 x = torch .randn (x_shape , dtype = torch .float32 )
358- model = torch .nn .Linear (in_features , out_features ).float ().train ()
358+ model = torch .nn .Linear (in_features , out_features , bias = bias ).float ().train ()
359359 for dtype in [torch .float32 , torch .bfloat16 ]:
360360 x1 = x .clone ().requires_grad_ ()
361361 x2 = x .clone ().requires_grad_ ()
0 commit comments