@@ -509,6 +509,14 @@ def __init__(self, dim=-1):
509509 def forward (self , x ):
510510 return self .softmax (x )
511511
512+ class AtenBatchNormRepalce (nn .Module ):
513+ def __init__ (self ):
514+ super (AtenBatchNormRepalce , self ).__init__ ()
515+ self .bn = torch .nn .BatchNorm2d (10 )
516+
517+ def forward (self , x ):
518+ return self .bn (x )
519+
512520class AddLayerNorm (torch .nn .Module ):
513521 def __init__ (self , dim = 32 ):
514522 super (AddLayerNorm , self ).__init__ ()
@@ -925,35 +933,35 @@ def test_output_conv_bn_2d(self):
925933 ConvBatchNorm_Fixed (2 , 3 , 32 , kernel_size = 3 , stride = 1 ),
926934 torch .randn (32 , 3 , 64 , 64 ),
927935 kind_in_graph = "ipex_prepack::convolution_run" ,
928- kind_not_in_graph = "aten ::batch_norm" ,
936+ kind_not_in_graph = "ipex ::batch_norm" ,
929937 levels = ['O1' ])
930938 self ._test_output_bf16 (
931939 ConvBatchNorm_Fixed (2 , 3 , 32 , kernel_size = 3 , stride = 1 ),
932940 torch .randn (32 , 3 , 64 , 64 ),
933941 kind_in_graph = "ipex_prepack::convolution_run" ,
934- kind_not_in_graph = "aten ::batch_norm" ,
942+ kind_not_in_graph = "ipex ::batch_norm" ,
935943 prec = 0.02 ,
936944 levels = ['O1' ])
937945
938946 def test_output_bn_conv_2d (self ):
939947 self ._test_output (
940948 BatchNormConv_Fixed (2 , 3 , 32 , kernel_size = 3 , stride = 1 ),
941949 torch .randn (32 , 3 , 64 , 64 ),
942- kind_in_graph = "aten ::batch_norm" ,
950+ kind_in_graph = "ipex ::batch_norm" ,
943951 kind_not_in_graph = None )
944952
945953 def test_output_bn_conv_bn (self ):
946954 self ._test_output (
947955 BatchNorm_Conv_BatchNorm (2 , 3 , 32 , kernel_size = 3 , stride = 1 ),
948956 torch .randn (32 , 3 , 64 , 64 ),
949- kind_in_graph = "aten ::batch_norm" ,
957+ kind_in_graph = "ipex ::batch_norm" ,
950958 kind_not_in_graph = None )
951959
952960 def test_output_conv_reshape_bn_2d (self ):
953961 self ._test_output (
954962 ConvReshapeBatchNorm (2 , 3 , 32 , (64 , 16 , 62 , 62 ), kernel_size = 3 , stride = 1 ),
955963 torch .randn (32 , 3 , 64 , 64 ),
956- kind_in_graph = "aten ::batch_norm" ,
964+ kind_in_graph = "ipex ::batch_norm" ,
957965 kind_not_in_graph = None )
958966
959967 def test_output_conv_conv_concate (self ):
@@ -994,7 +1002,7 @@ def test_output_conv_bn_3d(self):
9941002 ConvBatchNorm_Fixed (3 , 3 , 32 , kernel_size = 3 , stride = 1 ),
9951003 torch .randn (32 , 3 , 32 , 32 , 32 ),
9961004 kind_in_graph = "aten::conv3d" ,
997- kind_not_in_graph = "aten ::batch_norm" )
1005+ kind_not_in_graph = "ipex ::batch_norm" )
9981006
9991007 def test_output_conv_relu_2d (self ):
10001008 self ._test_output (
@@ -1061,25 +1069,25 @@ def test_output_cascaded_conv_bn_sum_relu_2d(self):
10611069 CascadedConvBnSumRelu (2 , 3 , 64 , 32 , kernel_size = 3 , stride = 1 ),
10621070 torch .rand (32 , 3 , 64 , 64 ),
10631071 kind_in_graph = "ipex_prepack::convolution_add_relu_run" ,
1064- kind_not_in_graph = "aten ::batch_norm" )
1072+ kind_not_in_graph = "ipex ::batch_norm" )
10651073 self ._test_output_bf16 (
10661074 CascadedConvBnSumRelu (2 , 3 , 64 , 32 , kernel_size = 3 , stride = 1 ),
10671075 torch .rand (32 , 3 , 64 , 64 ),
10681076 kind_in_graph = "ipex_prepack::convolution_add_relu_run" ,
1069- kind_not_in_graph = "aten ::batch_norm" ,
1077+ kind_not_in_graph = "ipex ::batch_norm" ,
10701078 prec = 0.02 )
10711079
10721080 def test_output_cascaded_conv_bn_sum_relu_3d (self ):
10731081 self ._test_output (
10741082 CascadedConvBnSumRelu (3 , 3 , 64 , 32 , kernel_size = 3 , stride = 1 ),
10751083 torch .rand (32 , 3 , 32 , 32 , 32 ),
10761084 kind_in_graph = "ipex::conv3d_sum_relu" ,
1077- kind_not_in_graph = "aten ::batch_norm" )
1085+ kind_not_in_graph = "ipex ::batch_norm" )
10781086 self ._test_output_bf16 (
10791087 CascadedConvBnSumRelu (3 , 3 , 64 , 32 , kernel_size = 3 , stride = 1 ),
10801088 torch .rand (32 , 3 , 32 , 32 , 32 ),
10811089 kind_in_graph = "ipex::conv3d_sum_relu" ,
1082- kind_not_in_graph = "aten ::batch_norm" ,
1090+ kind_not_in_graph = "ipex ::batch_norm" ,
10831091 prec = 0.02 )
10841092
10851093 def test_output_conv_transpose2d (self ):
@@ -1346,6 +1354,17 @@ def test_ipex_softmax(self):
13461354 kind_in_graph = "ipex::softmax" ,
13471355 prec = 5e-3 )
13481356
1357+ def test_ipex_batch_norm (self ):
1358+ self ._test_output (
1359+ AtenBatchNormRepalce (),
1360+ torch .rand (10 , 10 , 4 , 4 ),
1361+ kind_in_graph = "ipex::batch_norm" )
1362+ self ._test_output_bf16 (
1363+ AtenBatchNormRepalce (),
1364+ torch .rand (10 , 10 , 4 , 4 , dtype = torch .bfloat16 ),
1365+ kind_in_graph = "ipex::batch_norm" ,
1366+ prec = 5e-3 )
1367+
13491368 def test_restore_inplace (self ):
13501369 class M (nn .Module ):
13511370 def __init__ (self , eltwise_fn , params_dict = {}):
0 commit comments