@@ -18,11 +18,13 @@ class Laye_BatchNorm_Test(CustomTestCase):
1818 @classmethod
1919 def setUpClass (cls ):
2020
21+ x_0_input_shape = [None , 10 ]
2122 x_1_input_shape = [None , 100 , 1 ]
2223 x_2_input_shape = [None , 100 , 100 , 3 ]
2324 x_3_input_shape = [None , 100 , 100 , 100 , 3 ]
2425 batchsize = 2
2526
27+ cls .x0 = tf .random .normal ([batchsize ] + x_0_input_shape [1 :])
2628 cls .x1 = tf .random .normal ([batchsize ] + x_1_input_shape [1 :])
2729 cls .x2 = tf .random .normal ([batchsize ] + x_2_input_shape [1 :])
2830 cls .x3 = tf .random .normal ([batchsize ] + x_3_input_shape [1 :])
@@ -36,16 +38,58 @@ def setUpClass(cls):
3638
3739 ni_2 = Input (x_2_input_shape , name = 'test_ni2' )
3840 nn_2 = Conv2d (n_filter = 32 , filter_size = (3 , 3 ), strides = (2 , 2 ), name = 'test_conv2d' )(ni_2 )
39- n2_b = BatchNorm2d (name = 'test_bn2d' )(nn_2 )
41+ n2_b = BatchNorm (name = 'test_bn2d' )(nn_2 )
4042 cls .n2_b = n2_b
4143 cls .base_2d = Model (inputs = ni_2 , outputs = n2_b , name = 'test_base_2d' )
4244
4345 ni_3 = Input (x_3_input_shape , name = 'test_ni2' )
4446 nn_3 = Conv3d (n_filter = 32 , filter_size = (3 , 3 , 3 ), strides = (2 , 2 , 2 ), name = 'test_conv3d' )(ni_3 )
45- n3_b = BatchNorm3d (name = 'test_bn3d' )(nn_3 )
47+ n3_b = BatchNorm (name = 'test_bn3d' )(nn_3 )
4648 cls .n3_b = n3_b
4749 cls .base_3d = Model (inputs = ni_3 , outputs = n3_b , name = 'test_base_3d' )
4850
51+ class bn_0d_model (Model ):
52+
53+ def __init__ (self ):
54+ super (bn_0d_model , self ).__init__ ()
55+ self .fc = Dense (32 , in_channels = 10 )
56+ self .bn = BatchNorm (num_features = 32 , name = 'test_bn1d' )
57+
58+ def forward (self , x ):
59+ x = self .bn (self .fc (x ))
60+ return x
61+
62+ dynamic_base = bn_0d_model ()
63+ cls .n0_b = dynamic_base (cls .x0 , is_train = True )
64+
65+ ## 0D ========================================================================
66+
67+ nin_0 = Input (x_0_input_shape , name = 'test_in1' )
68+
69+ n0 = Dense (32 )(nin_0 )
70+ n0 = BatchNorm1d (name = 'test_bn0d' )(n0 )
71+
72+ cls .n0 = n0
73+
74+ cls .static_0d = Model (inputs = nin_0 , outputs = n0 )
75+
76+ class bn_0d_model (Model ):
77+
78+ def __init__ (self ):
79+ super (bn_0d_model , self ).__init__ (name = 'test_bn_0d_model' )
80+ self .fc = Dense (32 , in_channels = 10 )
81+ self .bn = BatchNorm1d (num_features = 32 , name = 'test_bn1d' )
82+
83+ def forward (self , x ):
84+ x = self .bn (self .fc (x ))
85+ return x
86+
87+ cls .dynamic_0d = bn_0d_model ()
88+
89+ print ("Printing BatchNorm0d" )
90+ print (cls .static_0d )
91+ print (cls .dynamic_0d )
92+
4993 ## 1D ========================================================================
5094
5195 nin_1 = Input (x_1_input_shape , name = 'test_in1' )
@@ -147,6 +191,14 @@ def test_BatchNorm(self):
147191 self .assertEqual (self .n3_b .shape [1 :], (50 , 50 , 50 , 32 ))
148192 out = self .base_3d (self .x3 , is_train = True )
149193
194+ self .assertEqual (self .n0_b .shape [1 :], (32 ))
195+ print ("test_BatchNorm OK" )
196+
197+ def test_BatchNorm0d (self ):
198+ self .assertEqual (self .n0 .shape [1 :], (32 ))
199+ out = self .static_0d (self .x0 , is_train = True )
200+ out = self .dynamic_0d (self .x0 , is_train = True )
201+
150202 def test_BatchNorm1d (self ):
151203 self .assertEqual (self .n1 .shape [1 :], (50 , 32 ))
152204 out = self .static_1d (self .x1 , is_train = True )
@@ -189,6 +241,25 @@ def test_exception(self):
189241 self .assertIsInstance (e , ValueError )
190242 print (e )
191243
244+ def test_input_shape (self ):
245+ try :
246+ bn = BatchNorm1d (num_features = 32 )
247+ out = bn (self .x2 )
248+ except Exception as e :
249+ self .assertIsInstance (e , ValueError )
250+ print (e )
251+ try :
252+ bn = BatchNorm2d (num_features = 32 )
253+ out = bn (self .x3 )
254+ except Exception as e :
255+ self .assertIsInstance (e , ValueError )
256+ print (e )
257+ try :
258+ bn = BatchNorm3d (num_features = 32 )
259+ out = bn (self .x1 )
260+ except Exception as e :
261+ self .assertIsInstance (e , ValueError )
262+ print (e )
192263
193264if __name__ == '__main__' :
194265
0 commit comments