Skip to content

Commit dcbfe91

Browse files
authored
WOQ bug fix: at::narrow is always called on result (#2339)
1 parent 710080f commit dcbfe91

File tree

1 file changed

+6 -17 lines changed

1 file changed

+6
-17
lines changed

csrc/cpu/jit/cpu/kernels/LinearWoqPacked.cpp

Lines changed: 6 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -263,23 +263,7 @@ at::Tensor run(ContextLinearWoq& context, const at::Tensor& input) {
263 263
w_k,
264 264
" respectively.");
265 265
auto input_ = input.contiguous();
266-
if (context.weight_shape_[0] != context.at_weight_.size(0)) {
267-
auto res = woq_linear_kernel(
268-
input_,
269-
context.at_weight_,
270-
context.scales_list_,
271-
context.zero_points_list_,
272-
context.bias_list_,
273-
context.is_int4_,
274-
context.group_size_,
275-
context.lowp_mode_,
276-
context.num_concats_,
277-
context.act_quant_mode_);
278-
// weight shape is [N by K], output shape is [M by N] or [batch by M by N]
279-
int64_t N = context.weight_shape_[0];
280-
return at::narrow(res, /*dim*/ -1, /*start*/ 0, /*end*/ N);
281-
}
282-
return woq_linear_kernel(
266+
auto res = woq_linear_kernel(
283 267
input_,
284 268
context.at_weight_,
285 269
context.scales_list_,
@@ -290,6 +274,11 @@ at::Tensor run(ContextLinearWoq& context, const at::Tensor& input) {
290 274
context.lowp_mode_,
291 275
context.num_concats_,
292 276
context.act_quant_mode_);
277+
if (res.size(-1) != context.weight_shape_[0]) {
278+
int64_t N = context.weight_shape_[0];
279+
return at::narrow(res, /*dim*/ -1, /*start*/ 0, /*end*/ N);
280+
}
281+
return res;
293 282
}
294 283

295 284
// Called by IpexWoqLinearOpContext::run_eltwise

0 commit comments

Comments
 (0)