@@ -726,6 +726,7 @@ def fn(match):
726726 return False
727727 binary_node_inputs = next (iter (compute_node .users )).args
728728 assert len (binary_node_inputs ) == 2 , "Expects binary node with 2 inputs"
729+ is_fp8 = match .kwargs ["x" ].meta ["val" ].dtype is torch .float8_e4m3fn
729730 if output_dtype in [torch .float32 , torch .bfloat16 ]:
730731 extra_input_of_binary_node = None
731732 for arg in binary_node_inputs :
@@ -734,7 +735,7 @@ def fn(match):
734735 break
735736 assert extra_input_of_binary_node is not None
736737 # Extra input of binary node comes from dequant pattern
737- if extra_input_from_dequant and (
738+ if not is_fp8 and extra_input_from_dequant and (
738739 (not isinstance (extra_input_of_binary_node , torch .fx .Node ))
739740 or (
740741 extra_input_of_binary_node .target
@@ -3228,37 +3229,44 @@ def _register_qconv_unary_fusion():
32283229
32293230
32303231def _register_qconv_binary_fusion ():
3231- for int8_mixed_bf16_with_inplace_add in [False , True ]:
3232+ for int8_mixed_bf16_with_inplace_add , x_scale_zp_are_tensors in itertools .product ([False , True ], [False , True ]):
3233+ qconv_binary_op = (
3234+ torch .ops .onednn .qconv2d_pointwise .binary_tensor
3235+ if x_scale_zp_are_tensors
3236+ else torch .ops .onednn .qconv2d_pointwise .binary
3237+ )
32323238 # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
32333239 swap_binary_inputs_list = [False , True ]
32343240 binary_replace_patterns = {}
3235- for swap_inputs in swap_binary_inputs_list :
3241+ for swap_inputs , is_fp8 in itertools . product ( swap_binary_inputs_list , [ False , True ]) :
32363242 binary_replace_patterns .update (
32373243 {
32383244 PostOpAttr (
32393245 "sum" , 1.0 , "none" , [], ""
32403246 ): generate_pattern_with_output_quant (
32413247 generate_pattern_with_binary (
32423248 aten .add .Tensor ,
3243- get_qconv_pt2e_pattern (users = 1 ),
3249+ get_qconv_pt2e_pattern (x_scale_zp_are_tensors , 1 ),
32443250 dequantize_accum_pattern ,
32453251 int8_mixed_bf16_with_inplace_add ,
32463252 swap_inputs = swap_inputs ,
32473253 ),
3254+ is_fp8 = is_fp8 ,
32483255 ),
32493256 PostOpAttr (
32503257 "sum" , 1.0 , "relu" , [], ""
32513258 ): generate_pattern_with_output_quant (
32523259 generate_pattern_with_unary (
32533260 generate_pattern_with_binary (
32543261 aten .add .Tensor ,
3255- get_qconv_pt2e_pattern (users = 1 ),
3262+ get_qconv_pt2e_pattern (x_scale_zp_are_tensors , 1 ),
32563263 dequantize_accum_pattern ,
32573264 int8_mixed_bf16_with_inplace_add ,
32583265 swap_inputs = swap_inputs ,
32593266 ),
32603267 aten .relu .default ,
32613268 ),
3269+ is_fp8 = is_fp8 ,
32623270 ),
32633271 }
32643272 )
@@ -3267,7 +3275,7 @@ def _register_qconv_binary_fusion():
32673275 _register_qconv_post_op_fusion_pass (
32683276 patterns ,
32693277 3 , # pass_number
3270- torch . ops . onednn . qconv2d_pointwise . binary , # computation_op
3278+ qconv_binary_op , # computation_op
32713279 binary_unary_attr , # binary_unary_attr
32723280 )
32733281
@@ -3279,7 +3287,7 @@ def _register_qconv_binary_fusion():
32793287 PostOpAttr ("sum" , 1.0 , "relu" , [], "" ): generate_pattern_with_unary (
32803288 generate_pattern_with_binary (
32813289 aten .add .Tensor ,
3282- get_qconv_pt2e_pattern (users = 1 ),
3290+ get_qconv_pt2e_pattern (x_scale_zp_are_tensors , 1 ),
32833291 KeywordArg ("accum_after_dequant" ),
32843292 int8_mixed_bf16_with_inplace_add ,
32853293 swap_inputs = swap_inputs ,
@@ -3297,14 +3305,14 @@ def _register_qconv_binary_fusion():
32973305 _register_qconv_post_op_fusion_pass (
32983306 patterns ,
32993307 3 , # pass_number
3300- torch . ops . onednn . qconv2d_pointwise . binary , # computation_op
3308+ qconv_binary_op , # computation_op
33013309 binary_unary_attr , # binary_unary_attr
33023310 )
33033311 else :
33043312 _register_qconv_post_op_fusion_pass (
33053313 patterns ,
33063314 4 , # pass_number
3307- torch . ops . onednn . qconv2d_pointwise . binary , # computation_op
3315+ qconv_binary_op , # computation_op
33083316 binary_unary_attr , # binary_unary_attr
33093317 )
33103318
@@ -3317,7 +3325,7 @@ def _register_qconv_binary_fusion():
33173325 "sum" , 1.0 , "none" , [], ""
33183326 ): generate_pattern_with_binary (
33193327 aten .add .Tensor ,
3320- get_qconv_pt2e_pattern (users = 1 ),
3328+ get_qconv_pt2e_pattern (x_scale_zp_are_tensors , 1 ),
33213329 KeywordArg ("accum_after_dequant" ),
33223330 int8_mixed_bf16_with_inplace_add ,
33233331 swap_inputs = swap_inputs ,
@@ -3332,7 +3340,7 @@ def _register_qconv_binary_fusion():
33323340 _register_qconv_post_op_fusion_pass (
33333341 patterns ,
33343342 4 if int8_mixed_bf16_with_inplace_add else 5 , # pass_number
3335- torch . ops . onednn . qconv2d_pointwise . binary , # computation_op
3343+ qconv_binary_op , # computation_op
33363344 binary_unary_attr , # binary_unary_attr
33373345 )
33383346
0 commit comments