@@ -108,6 +108,19 @@ def _bias_add(x, b, data_format):
108108
109109def batch_normalization (x , mean , variance , offset , scale , variance_epsilon , data_format , name = None ):
110110 """Data Format aware version of tf.nn.batch_normalization."""
111+ if data_format == 'channels_last' :
112+ mean = tf .reshape (mean , [1 ] * (len (x .shape ) - 1 ) + [- 1 ])
113+ variance = tf .reshape (variance , [1 ] * (len (x .shape ) - 1 ) + [- 1 ])
114+ offset = tf .reshape (offset , [1 ] * (len (x .shape ) - 1 ) + [- 1 ])
115+ scale = tf .reshape (scale , [1 ] * (len (x .shape ) - 1 ) + [- 1 ])
116+ elif data_format == 'channels_first' :
117+ mean = tf .reshape (mean , [1 ] + [- 1 ] + [1 ] * (len (x .shape ) - 2 ))
118+ variance = tf .reshape (variance , [1 ] + [- 1 ] + [1 ] * (len (x .shape ) - 2 ))
119+ offset = tf .reshape (offset , [1 ] + [- 1 ] + [1 ] * (len (x .shape ) - 2 ))
120+ scale = tf .reshape (scale , [1 ] + [- 1 ] + [1 ] * (len (x .shape ) - 2 ))
121+ else :
122+ raise ValueError ('invalid data_format: %s' % data_format )
123+
111124 with ops .name_scope (name , 'batchnorm' , [x , mean , variance , scale , offset ]):
112125 inv = math_ops .rsqrt (variance + variance_epsilon )
113126 if scale is not None :
@@ -204,13 +217,10 @@ def __init__(
204217 self .moving_var_init = moving_var_init
205218 self .num_features = num_features
206219
220+ self .channel_axis = - 1 if data_format == 'channels_last' else 1
221+ self .axes = None
222+
207223 if num_features is not None :
208- if not isinstance (self , BatchNorm1d ) and not isinstance (self , BatchNorm2d ) and not isinstance (self ,
209- BatchNorm3d ):
210- raise ValueError (
211- "Please use BatchNorm1d or BatchNorm2d or BatchNorm3d instead of BatchNorm "
212- "if you want to specify 'num_features'."
213- )
214224 self .build (None )
215225 self ._built = True
216226
@@ -233,21 +243,23 @@ def __repr__(self):
233243
234244 def _get_param_shape (self , inputs_shape ):
235245 if self .data_format == 'channels_last' :
236- axis = len ( inputs_shape ) - 1
246+ axis = - 1
237247 elif self .data_format == 'channels_first' :
238248 axis = 1
239249 else :
240250 raise ValueError ('data_format should be either %s or %s' % ('channels_last' , 'channels_first' ))
241251
242252 channels = inputs_shape [axis ]
243- params_shape = [1 ] * len (inputs_shape )
244- params_shape [axis ] = channels
253+ params_shape = [channels ]
245254
246- axes = [i for i in range (len (inputs_shape )) if i != axis ]
247- return params_shape , axes
255+ return params_shape
256+
257+ def _check_input_shape (self , inputs ):
258+ if inputs .ndim <= 1 :
259+ raise ValueError ('expected input at least 2D, but got {}D input' .format (inputs .ndim ))
248260
249261 def build (self , inputs_shape ):
250- params_shape , self .axes = self ._get_param_shape (inputs_shape )
262+ params_shape = [ self .num_features ] if self . num_features is not None else self ._get_param_shape (inputs_shape )
251263
252264 self .beta , self .gamma = None , None
253265 if self .beta_init :
@@ -264,7 +276,12 @@ def build(self, inputs_shape):
264276 )
265277
266278 def forward (self , inputs ):
267- mean , var = tf .nn .moments (inputs , self .axes , keepdims = True )
279+ self ._check_input_shape (inputs )
280+
281+ if self .axes is None :
282+ self .axes = [i for i in range (len (inputs .shape )) if i != self .channel_axis ]
283+
284+ mean , var = tf .nn .moments (inputs , self .axes , keepdims = False )
268285 if self .is_train :
269286 # update moving_mean and moving_var
270287 self .moving_mean = moving_averages .assign_moving_average (
@@ -282,8 +299,8 @@ def forward(self, inputs):
282299
283300
284301class BatchNorm1d (BatchNorm ):
285- """The :class:`BatchNorm1d` applies Batch Normalization over 3D input (a mini-batch of 1D
286- inputs with additional channel dimension), of shape (N, L, C) or (N, C, L).
302+ """The :class:`BatchNorm1d` applies Batch Normalization over 2D/ 3D input (a mini-batch of 1D
303+ inputs (optional) with additional channel dimension), of shape (N, C) or (N, L, C) or (N, C, L).
287304 See more details in :class:`BatchNorm`.
288305
289306 Examples
@@ -299,23 +316,9 @@ class BatchNorm1d(BatchNorm):
299316
300317 """
301318
302- def _get_param_shape (self , inputs_shape ):
303- if self .data_format == 'channels_last' :
304- axis = 2
305- elif self .data_format == 'channels_first' :
306- axis = 1
307- else :
308- raise ValueError ('data_format should be either %s or %s' % ('channels_last' , 'channels_first' ))
309-
310- if self .num_features is None :
311- channels = inputs_shape [axis ]
312- else :
313- channels = self .num_features
314- params_shape = [1 ] * 3
315- params_shape [axis ] = channels
316-
317- axes = [i for i in range (3 ) if i != axis ]
318- return params_shape , axes
319+ def _check_input_shape (self , inputs ):
320+ if inputs .ndim != 2 and inputs .ndim != 3 :
321+ raise ValueError ('expected input to be 2D or 3D, but got {}D input' .format (inputs .ndim ))
319322
320323
321324class BatchNorm2d (BatchNorm ):
@@ -336,23 +339,9 @@ class BatchNorm2d(BatchNorm):
336339
337340 """
338341
339- def _get_param_shape (self , inputs_shape ):
340- if self .data_format == 'channels_last' :
341- axis = 3
342- elif self .data_format == 'channels_first' :
343- axis = 1
344- else :
345- raise ValueError ('data_format should be either %s or %s' % ('channels_last' , 'channels_first' ))
346-
347- if self .num_features is None :
348- channels = inputs_shape [axis ]
349- else :
350- channels = self .num_features
351- params_shape = [1 ] * 4
352- params_shape [axis ] = channels
353-
354- axes = [i for i in range (4 ) if i != axis ]
355- return params_shape , axes
342+ def _check_input_shape (self , inputs ):
343+ if inputs .ndim != 4 :
344+ raise ValueError ('expected input to be 4D, but got {}D input' .format (inputs .ndim ))
356345
357346
358347class BatchNorm3d (BatchNorm ):
@@ -373,23 +362,9 @@ class BatchNorm3d(BatchNorm):
373362
374363 """
375364
376- def _get_param_shape (self , inputs_shape ):
377- if self .data_format == 'channels_last' :
378- axis = 4
379- elif self .data_format == 'channels_first' :
380- axis = 1
381- else :
382- raise ValueError ('data_format should be either %s or %s' % ('channels_last' , 'channels_first' ))
383-
384- if self .num_features is None :
385- channels = inputs_shape [axis ]
386- else :
387- channels = self .num_features
388- params_shape = [1 ] * 5
389- params_shape [axis ] = channels
390-
391- axes = [i for i in range (5 ) if i != axis ]
392- return params_shape , axes
365+ def _check_input_shape (self , inputs ):
366+ if inputs .ndim != 5 :
367+ raise ValueError ('expected input to be 5D, but got {}D input' .format (inputs .ndim ))
393368
394369
395370class InstanceNorm (Layer ):
0 commit comments