@@ -108,6 +108,19 @@ def _bias_add(x, b, data_format):
108108
109109def batch_normalization (x , mean , variance , offset , scale , variance_epsilon , data_format , name = None ):
110110 """Data Format aware version of tf.nn.batch_normalization."""
111+ if data_format == 'channels_last' :
112+ mean = tf .reshape (mean , [1 ] * (len (x .shape ) - 1 ) + [- 1 ])
113+ variance = tf .reshape (variance , [1 ] * (len (x .shape ) - 1 ) + [- 1 ])
114+ offset = tf .reshape (offset , [1 ] * (len (x .shape ) - 1 ) + [- 1 ])
115+ scale = tf .reshape (scale , [1 ] * (len (x .shape ) - 1 ) + [- 1 ])
116+ elif data_format == 'channels_first' :
117+ mean = tf .reshape (mean , [1 ] + [- 1 ] + [1 ] * (len (x .shape ) - 2 ))
118+ variance = tf .reshape (variance , [1 ] + [- 1 ] + [1 ] * (len (x .shape ) - 2 ))
119+ offset = tf .reshape (offset , [1 ] + [- 1 ] + [1 ] * (len (x .shape ) - 2 ))
120+ scale = tf .reshape (scale , [1 ] + [- 1 ] + [1 ] * (len (x .shape ) - 2 ))
121+ else :
122+ raise ValueError ('invalid data_format: %s' % data_format )
123+
111124 with ops .name_scope (name , 'batchnorm' , [x , mean , variance , scale , offset ]):
112125 inv = math_ops .rsqrt (variance + variance_epsilon )
113126 if scale is not None :
@@ -204,6 +217,9 @@ def __init__(
204217 self .moving_var_init = moving_var_init
205218 self .num_features = num_features
206219
220+ self .channel_axis = - 1 if data_format == 'channels_last' else 1
221+ self .axes = None
222+
207223 if num_features is not None :
208224 if not isinstance (self , BatchNorm1d ) and not isinstance (self , BatchNorm2d ) and not isinstance (self ,
209225 BatchNorm3d ):
@@ -233,21 +249,23 @@ def __repr__(self):
233249
234250 def _get_param_shape (self , inputs_shape ):
235251 if self .data_format == 'channels_last' :
236- axis = len ( inputs_shape ) - 1
252+ axis = - 1
237253 elif self .data_format == 'channels_first' :
238254 axis = 1
239255 else :
240256 raise ValueError ('data_format should be either %s or %s' % ('channels_last' , 'channels_first' ))
241257
242258 channels = inputs_shape [axis ]
243- params_shape = [1 ] * len (inputs_shape )
244- params_shape [axis ] = channels
259+ params_shape = [channels ]
245260
246- axes = [i for i in range (len (inputs_shape )) if i != axis ]
247- return params_shape , axes
261+ return params_shape
262+
263+ def _check_input_shape (self , inputs ):
264+ if inputs .ndim <= 1 :
265+ raise ValueError ('expected input at least 2D, but got {}D input' .format (inputs .ndim ))
248266
249267 def build (self , inputs_shape ):
250- params_shape , self .axes = self ._get_param_shape (inputs_shape )
268+ params_shape = [ self .num_features ] if self . num_features is not None else self ._get_param_shape (inputs_shape )
251269
252270 self .beta , self .gamma = None , None
253271 if self .beta_init :
@@ -264,7 +282,12 @@ def build(self, inputs_shape):
264282 )
265283
266284 def forward (self , inputs ):
267- mean , var = tf .nn .moments (inputs , self .axes , keepdims = True )
285+ self ._check_input_shape (inputs )
286+
287+ if self .axes is None :
288+ self .axes = [i for i in range (len (inputs .shape )) if i != self .channel_axis ]
289+
290+ mean , var = tf .nn .moments (inputs , self .axes , keepdims = False )
268291 if self .is_train :
269292 # update moving_mean and moving_var
270293 self .moving_mean = moving_averages .assign_moving_average (
@@ -282,8 +305,8 @@ def forward(self, inputs):
282305
283306
284307class BatchNorm1d (BatchNorm ):
285- """The :class:`BatchNorm1d` applies Batch Normalization over 3D input (a mini-batch of 1D
286- inputs with additional channel dimension), of shape (N, L, C) or (N, C, L).
308+ """The :class:`BatchNorm1d` applies Batch Normalization over 2D/ 3D input (a mini-batch of 1D
309+ inputs (optional) with additional channel dimension), of shape (N, C) or (N, L, C) or (N, C, L).
287310 See more details in :class:`BatchNorm`.
288311
289312 Examples
@@ -298,24 +321,9 @@ class BatchNorm1d(BatchNorm):
298321 >>> bn = tl.layers.BatchNorm1d(num_features=32)
299322
300323 """
301-
302- def _get_param_shape (self , inputs_shape ):
303- if self .data_format == 'channels_last' :
304- axis = 2
305- elif self .data_format == 'channels_first' :
306- axis = 1
307- else :
308- raise ValueError ('data_format should be either %s or %s' % ('channels_last' , 'channels_first' ))
309-
310- if self .num_features is None :
311- channels = inputs_shape [axis ]
312- else :
313- channels = self .num_features
314- params_shape = [1 ] * 3
315- params_shape [axis ] = channels
316-
317- axes = [i for i in range (3 ) if i != axis ]
318- return params_shape , axes
324+ def _check_input_shape (self , inputs ):
325+ if inputs .ndim != 2 and inputs .ndim != 3 :
326+ raise ValueError ('expected input to be 2D or 3D, but got {}D input' .format (inputs .ndim ))
319327
320328
321329class BatchNorm2d (BatchNorm ):
@@ -335,24 +343,9 @@ class BatchNorm2d(BatchNorm):
335343 >>> bn = tl.layers.BatchNorm2d(num_features=32)
336344
337345 """
338-
339- def _get_param_shape (self , inputs_shape ):
340- if self .data_format == 'channels_last' :
341- axis = 3
342- elif self .data_format == 'channels_first' :
343- axis = 1
344- else :
345- raise ValueError ('data_format should be either %s or %s' % ('channels_last' , 'channels_first' ))
346-
347- if self .num_features is None :
348- channels = inputs_shape [axis ]
349- else :
350- channels = self .num_features
351- params_shape = [1 ] * 4
352- params_shape [axis ] = channels
353-
354- axes = [i for i in range (4 ) if i != axis ]
355- return params_shape , axes
346+ def _check_input_shape (self , inputs ):
347+ if inputs .ndim != 4 :
348+ raise ValueError ('expected input to be 4D, but got {}D input' .format (inputs .ndim ))
356349
357350
358351class BatchNorm3d (BatchNorm ):
@@ -372,24 +365,9 @@ class BatchNorm3d(BatchNorm):
372365 >>> bn = tl.layers.BatchNorm3d(num_features=32)
373366
374367 """
375-
376- def _get_param_shape (self , inputs_shape ):
377- if self .data_format == 'channels_last' :
378- axis = 4
379- elif self .data_format == 'channels_first' :
380- axis = 1
381- else :
382- raise ValueError ('data_format should be either %s or %s' % ('channels_last' , 'channels_first' ))
383-
384- if self .num_features is None :
385- channels = inputs_shape [axis ]
386- else :
387- channels = self .num_features
388- params_shape = [1 ] * 5
389- params_shape [axis ] = channels
390-
391- axes = [i for i in range (5 ) if i != axis ]
392- return params_shape , axes
368+ def _check_input_shape (self , inputs ):
369+ if inputs .ndim != 5 :
370+ raise ValueError ('expected input to be 5D, but got {}D input' .format (inputs .ndim ))
393371
394372
395373class InstanceNorm (Layer ):
0 commit comments