Commit d1163ba

[Inductor][float8] Register qconv-binary fusion pass for float8

1 parent add0e37

2 files changed: +63 −20 lines
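In short, this commit teaches the Inductor x86 qconv-binary fusion to also match float8 (e4m3) quantized Conv2d->Add(->ReLU) graphs, reusing the onednn qconv2d_pointwise binary kernels. As a rough, hedged sketch of the kind of graph being matched (the function name, shapes, and simple per-tensor scales below are illustrative assumptions, not the pass's pattern code):

import torch
import torch.nn.functional as F

# Hedged illustration of the float8 Conv2d->Add(->ReLU) pattern targeted by the
# fusion: dequantize -> conv2d -> add -> relu -> quantize, which the pass maps
# onto a single onednn qconv2d_pointwise binary kernel.
def conv_add_relu_fp8(x_fp8, w_fp8, other, x_scale, w_scale, out_scale, bias=None):
    x = x_fp8.to(torch.float32) * x_scale           # dequantize fp8 activation
    w = w_fp8.to(torch.float32) * w_scale           # dequantize fp8 weight
    y = F.conv2d(x, w, bias) + other                # conv2d followed by binary add
    y = torch.relu(y)                               # optional unary post-op ("sum" + "relu")
    return (y / out_scale).to(torch.float8_e4m3fn)  # requantize output to fp8

The registered patterns correspond to the "sum" and "sum"+"relu" post-ops visible in the diff below, with the output requantized to float8 when is_fp8 is set.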

test/quantization/pt2e/test_x86inductor_fusion.py

Lines changed: 44 additions & 9 deletions
@@ -770,7 +770,7 @@ def test_qconv2d_silu_fp8_mixed_bf16_cpu(self):
         )
 
     def _qconv2d_add_test_helper(
-        self, device="cpu", use_relu=False, int8_mixed_bf16=False
+        self, device="cpu", use_relu=False, mixed_bf16=False, is_fp8=False
     ):
         r"""
         This testcase will quantize a Conv2d->Add pattern as:
@@ -844,11 +844,12 @@ def matcher_check_fn():
             (v,),
             matcher_check_fn,
             check_quantization=True,
-            check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32,
+            check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32,
+            is_fp8=is_fp8,
         )
 
     def _qconv2d_add_test_helper2(
-        self, device="cpu", use_relu=False, int8_mixed_bf16=False
+        self, device="cpu", use_relu=False, mixed_bf16=False, is_fp8=False
     ):
         r"""
         This testcase will quantize two Conv2d->Add patterns as:
@@ -907,8 +908,11 @@ def forward(self, x, x2, x3):
                 res = self.relu2(res)
                 return res
 
+        add_fn_list = quantization_add_fn_list
+        if not is_fp8:
+            add_fn_list = add_fn_list + quantization_inplace_add_fn_list
         for add_fn, swap_inputs in itertools.product(
-            quantization_add_fn_list + quantization_inplace_add_fn_list, [False, True]
+            add_fn_list, [False, True]
         ):
             mod = M(add_fn, use_relu, swap_inputs).eval().to(device=device)
             x = torch.randn(
@@ -941,7 +945,8 @@ def matcher_check_fn():
                 (x, x2, x3),
                 matcher_check_fn,
                 check_quantization=True,
-                check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32,
+                check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32,
+                is_fp8=is_fp8,
             )
 
     @skipIfNoDynamoSupport
@@ -950,25 +955,55 @@ def test_qconv2d_add_cpu(self):
         self._qconv2d_add_test_helper()
         self._qconv2d_add_test_helper2()
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_add_fp8_cpu(self):
+        self._qconv2d_add_test_helper(is_fp8=True)
+        self._qconv2d_add_test_helper2(is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
     def test_qconv2d_add_int8_mixed_bf16(self):
-        self._qconv2d_add_test_helper(int8_mixed_bf16=True)
-        self._qconv2d_add_test_helper2(int8_mixed_bf16=True)
+        self._qconv2d_add_test_helper(mixed_bf16=True)
+        self._qconv2d_add_test_helper2(mixed_bf16=True)
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_add_fp8_mixed_bf16(self):
+        self._qconv2d_add_test_helper(mixed_bf16=True, is_fp8=True)
+        self._qconv2d_add_test_helper2(mixed_bf16=True, is_fp8=True)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     def test_qconv2d_add_relu_cpu(self):
         self._qconv2d_add_test_helper(use_relu=True)
         self._qconv2d_add_test_helper2(use_relu=True)
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_add_relu_fp8_cpu(self):
+        self._qconv2d_add_test_helper(use_relu=True, is_fp8=True)
+        self._qconv2d_add_test_helper2(use_relu=True, is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
     def test_qconv2d_add_relu_int8_mixed_bf16(self):
-        self._qconv2d_add_test_helper(use_relu=True, int8_mixed_bf16=True)
-        self._qconv2d_add_test_helper2(use_relu=True, int8_mixed_bf16=True)
+        self._qconv2d_add_test_helper(use_relu=True, mixed_bf16=True)
+        self._qconv2d_add_test_helper2(use_relu=True, mixed_bf16=True)
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_add_relu_fp8_mixed_bf16(self):
+        self._qconv2d_add_test_helper(use_relu=True, mixed_bf16=True, is_fp8=True)
+        self._qconv2d_add_test_helper2(use_relu=True, mixed_bf16=True, is_fp8=True)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
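For orientation, a minimal hedged sketch of the Conv2d->Add module shape these helpers quantize and compile (the test's real M class additionally covers a second conv branch, in-place add variants, and swapped inputs; the class and shapes here are illustrative, not the test's code):

import torch

class ConvAdd(torch.nn.Module):
    # Illustrative Conv2d -> Add (-> ReLU) module; not the test's exact M class.
    def __init__(self, use_relu: bool = False):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 6, kernel_size=3, padding=1)
        self.relu = torch.nn.ReLU() if use_relu else torch.nn.Identity()

    def forward(self, x, other):
        # The fusion matches the add that consumes the quantized conv output.
        return self.relu(self.conv(x) + other)

The new float8 cases can be run individually, e.g. pytest test/quantization/pt2e/test_x86inductor_fusion.py -k test_qconv2d_add_fp8_cpu, subject to the ONEDNN and float8 skip conditions above.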

torchao/quantization/pt2e/inductor_passes/x86.py

Lines changed: 19 additions & 11 deletions
@@ -726,6 +726,7 @@ def fn(match):
             return False
         binary_node_inputs = next(iter(compute_node.users)).args
         assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs"
+        is_fp8 = match.kwargs["x"].meta["val"].dtype is torch.float8_e4m3fn
         if output_dtype in [torch.float32, torch.bfloat16]:
             extra_input_of_binary_node = None
             for arg in binary_node_inputs:
@@ -734,7 +735,7 @@ def fn(match):
                     break
             assert extra_input_of_binary_node is not None
             # Extra input of binary node comes from dequant pattern
-            if extra_input_from_dequant and (
+            if not is_fp8 and extra_input_from_dequant and (
                 (not isinstance(extra_input_of_binary_node, torch.fx.Node))
                 or (
                     extra_input_of_binary_node.target
@@ -3228,37 +3229,44 @@ def _register_qconv_unary_fusion():
 
 
 def _register_qconv_binary_fusion():
-    for int8_mixed_bf16_with_inplace_add in [False, True]:
+    for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product([False, True], [False, True]):
+        qconv_binary_op = (
+            torch.ops.onednn.qconv2d_pointwise.binary_tensor
+            if x_scale_zp_are_tensors
+            else torch.ops.onednn.qconv2d_pointwise.binary
+        )
         # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
         swap_binary_inputs_list = [False, True]
         binary_replace_patterns = {}
-        for swap_inputs in swap_binary_inputs_list:
+        for swap_inputs, is_fp8 in itertools.product(swap_binary_inputs_list, [False, True]):
             binary_replace_patterns.update(
                 {
                     PostOpAttr(
                         "sum", 1.0, "none", [], ""
                     ): generate_pattern_with_output_quant(
                         generate_pattern_with_binary(
                             aten.add.Tensor,
-                            get_qconv_pt2e_pattern(users=1),
+                            get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
                             dequantize_accum_pattern,
                             int8_mixed_bf16_with_inplace_add,
                             swap_inputs=swap_inputs,
                         ),
+                        is_fp8=is_fp8,
                     ),
                     PostOpAttr(
                         "sum", 1.0, "relu", [], ""
                     ): generate_pattern_with_output_quant(
                         generate_pattern_with_unary(
                             generate_pattern_with_binary(
                                 aten.add.Tensor,
-                                get_qconv_pt2e_pattern(users=1),
+                                get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
                                 dequantize_accum_pattern,
                                 int8_mixed_bf16_with_inplace_add,
                                 swap_inputs=swap_inputs,
                             ),
                             aten.relu.default,
                         ),
+                        is_fp8=is_fp8,
                     ),
                 }
             )
@@ -3267,7 +3275,7 @@ def _register_qconv_binary_fusion():
             _register_qconv_post_op_fusion_pass(
                 patterns,
                 3,  # pass_number
-                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                qconv_binary_op,  # computation_op
                 binary_unary_attr,  # binary_unary_attr
             )
 
@@ -3279,7 +3287,7 @@ def _register_qconv_binary_fusion():
                     PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary(
                         generate_pattern_with_binary(
                             aten.add.Tensor,
-                            get_qconv_pt2e_pattern(users=1),
+                            get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
                             KeywordArg("accum_after_dequant"),
                             int8_mixed_bf16_with_inplace_add,
                             swap_inputs=swap_inputs,
@@ -3297,14 +3305,14 @@ def _register_qconv_binary_fusion():
                 _register_qconv_post_op_fusion_pass(
                     patterns,
                     3,  # pass_number
-                    torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                    qconv_binary_op,  # computation_op
                     binary_unary_attr,  # binary_unary_attr
                 )
             else:
                 _register_qconv_post_op_fusion_pass(
                     patterns,
                     4,  # pass_number
-                    torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                    qconv_binary_op,  # computation_op
                     binary_unary_attr,  # binary_unary_attr
                 )
 
@@ -3317,7 +3325,7 @@ def _register_qconv_binary_fusion():
                         "sum", 1.0, "none", [], ""
                     ): generate_pattern_with_binary(
                         aten.add.Tensor,
-                        get_qconv_pt2e_pattern(users=1),
+                        get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
                         KeywordArg("accum_after_dequant"),
                         int8_mixed_bf16_with_inplace_add,
                         swap_inputs=swap_inputs,
@@ -3332,7 +3340,7 @@ def _register_qconv_binary_fusion():
             _register_qconv_post_op_fusion_pass(
                 patterns,
                 4 if int8_mixed_bf16_with_inplace_add else 5,  # pass_number
-                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                qconv_binary_op,  # computation_op
                 binary_unary_attr,  # binary_unary_attr
             )
 
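The guard change above hinges on reading the dtype of the matched input's example value from FX node metadata. A standalone hedged sketch of that check (the helper name is ours, not the pass's):

import torch

def _node_is_fp8(node: torch.fx.Node) -> bool:
    # FX/Inductor store an example value for each node in node.meta["val"];
    # its dtype distinguishes float8_e4m3fn activations from int8 ones.
    val = node.meta.get("val", None)
    return isinstance(val, torch.Tensor) and val.dtype is torch.float8_e4m3fn

When the activation is float8_e4m3fn, the pass skips the requirement that the binary node's extra input come from a dequant pattern, which is what the not is_fp8 condition expresses.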
