
Commit 5d92d38

fix _IPEXLinear with bias=False (#150)
1 parent f93dea0

File tree

intel_pytorch_extension_py/weight_prepack.py
tests/cpu/test_weight_prepack.py

2 files changed: 6 additions & 4 deletions


intel_pytorch_extension_py/weight_prepack.py

Lines changed: 2 additions & 0 deletions

@@ -133,6 +133,8 @@ def __init__(self, dense_module):
             self.master_bias = dense_module.master_bias
         elif hasattr(dense_module, 'bias_trail'):
             self.bias_trail = dense_module.bias_trail
+        else:
+            self.register_parameter('bias', None)

     def forward(self, x):
         return torch.ops.torch_ipex.ipex_linear(
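
A minimal sketch of the failure mode this hunk fixes (hedged: _IPEXLinear internals are simplified here, and the class name below is hypothetical). torch.nn.Linear(bias=False) registers its bias as None; a wrapper that only copies the bias when one exists never defines self.bias, so forward() raises AttributeError. Registering the parameter as None, as the fix does, keeps the standard torch.nn.Module convention.

    import torch
    import torch.nn.functional as F

    class PrepackedLinearSketch(torch.nn.Module):
        """Illustrative stand-in for _IPEXLinear, not the real implementation."""
        def __init__(self, dense_module):
            super(PrepackedLinearSketch, self).__init__()
            self.weight = dense_module.weight
            if dense_module.bias is not None:
                # bias=True path: reuse the dense module's bias parameter.
                self.bias = dense_module.bias
            else:
                # Same pattern as the fix: self.bias now exists and is None,
                # matching what torch.nn.Linear(bias=False) itself does.
                self.register_parameter('bias', None)

        def forward(self, x):
            # Without the else-branch above, accessing self.bias here
            # would raise AttributeError for bias=False modules.
            return F.linear(x, self.weight, self.bias)

    dense = torch.nn.Linear(16, 32, bias=False)
    model = PrepackedLinearSketch(dense)
    print(model(torch.randn(8, 16)).shape)  # torch.Size([8, 32])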

tests/cpu/test_weight_prepack.py

Lines changed: 4 additions & 4 deletions

@@ -311,9 +311,9 @@ def test_resnext50_32x4d(self):

     def test_linear_inference(self):
         class L(torch.nn.Module):
-            def __init__(self, in_f, out_f):
+            def __init__(self, in_f, out_f, bias):
                 super(L, self).__init__()
-                self.linear = torch.nn.Linear(in_f, out_f)
+                self.linear = torch.nn.Linear(in_f, out_f, bias=bias)

             def forward(self, x):
                 return self.linear(x)
@@ -325,7 +325,7 @@ def forward(self, x):
         options = itertools.product([True, False], input_shapes)
         for bias, x_shape in options:
             x = torch.randn(x_shape, dtype=torch.float32)
-            model = L(in_features, out_features).float().eval()
+            model = L(in_features, out_features, bias).float().eval()
             for dtype in [torch.float32, torch.bfloat16]:
                 x1 = x.clone().requires_grad_()
                 x2 = x.clone().requires_grad_()
@@ -355,7 +355,7 @@ def test_linear_training(self):
         for out_features, bias, x_shape in options:
             in_features = x_shape[-1]
             x = torch.randn(x_shape, dtype=torch.float32)
-            model = torch.nn.Linear(in_features, out_features).float().train()
+            model = torch.nn.Linear(in_features, out_features, bias=bias).float().train()
             for dtype in [torch.float32, torch.bfloat16]:
                 x1 = x.clone().requires_grad_()
                 x2 = x.clone().requires_grad_()
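
A hedged sketch of the parametrization pattern the updated tests use: iterating the cross product of bias options and input shapes so both the bias=True and bias=False paths of the linear module are exercised. The shapes and sizes below are illustrative, not the values from the test file.

    import itertools
    import torch

    # Assumed example sizes; the real test defines its own.
    in_features, out_features = 16, 32
    input_shapes = [(8, in_features), (2, 4, in_features)]

    # Cross product: (True, shape1), (True, shape2), (False, shape1), ...
    for bias, x_shape in itertools.product([True, False], input_shapes):
        x = torch.randn(x_shape, dtype=torch.float32)
        model = torch.nn.Linear(in_features, out_features, bias=bias).float().eval()
        y = model(x)
        assert y.shape[-1] == out_features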
