Skip to content

Commit be50645

Browse files
ayp-googlemn-robot
authored andcommitted
Internal change
PiperOrigin-RevId: 300001033
1 parent 04fa416 commit be50645

File tree

2 files changed

+21
-155
lines changed

2 files changed

+21
-155
lines changed

morph_net/framework/grouping_op_handler_test.py

Lines changed: 13 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,6 @@ def setUp(self):
6060
self.beta_op = g.get_operation_by_name('conv1/BatchNorm/beta/read')
6161
self.beta_op_slice = orm.OpSlice(self.beta_op, orm.Slice(0, 5))
6262

63-
self.mean_op = g.get_operation_by_name(
64-
'conv1/BatchNorm/AssignMovingAvg/sub_1')
65-
self.mean_op_slice = orm.OpSlice(self.mean_op, orm.Slice(0, 5))
66-
67-
self.std_op = g.get_operation_by_name(
68-
'conv1/BatchNorm/AssignMovingAvg_1/sub_1')
69-
self.std_op_slice = orm.OpSlice(self.std_op, orm.Slice(0, 5))
70-
7163
# Create mock OpRegularizerManager with custom mapping of OpSlice and
7264
# OpGroup.
7365
self.mock_op_reg_manager = mock.create_autospec(orm.OpRegularizerManager)
@@ -78,8 +70,6 @@ def setUp(self):
7870
self.relu_op: [self.relu_op_slice],
7971
self.gamma_op: [self.gamma_op_slice],
8072
self.beta_op: [self.beta_op_slice],
81-
self.mean_op: [self.mean_op_slice],
82-
self.std_op: [self.std_op_slice],
8373
}
8474
def get_op_slices(op):
8575
return self.op_slice_dict.get(op)
@@ -92,7 +82,7 @@ def get_op_group(op_slice):
9282
self.mock_op_reg_manager.is_source_op.return_value = False
9383
self.mock_op_reg_manager.ops = [
9484
self.batch_norm_op, self.conv_op, self.relu_op, self.gamma_op,
95-
self.beta_op, self.mean_op, self.std_op]
85+
self.beta_op]
9686

9787
def testAssignGrouping_NoNeighborGroups(self):
9888
# No ops have groups.
@@ -109,39 +99,30 @@ def testAssignGrouping_NoNeighborGroups(self):
10999
mock.call(self.gamma_op),
110100
mock.call(self.beta_op),
111101
mock.call(self.relu_op),
112-
mock.call(self.mean_op),
113-
mock.call(self.std_op),
114102
# Initial slice data.
115103
mock.call(self.batch_norm_op),
116104
mock.call(self.conv_op),
117105
mock.call(self.gamma_op),
118106
mock.call(self.beta_op),
119107
mock.call(self.relu_op),
120-
mock.call(self.mean_op),
121-
mock.call(self.std_op),
122108
# Reslicing.
123109
mock.call(self.conv_op),
124110
mock.call(self.gamma_op),
125111
mock.call(self.beta_op),
126112
mock.call(self.batch_norm_op),
127113
mock.call(self.relu_op),
128-
mock.call(self.mean_op),
129-
mock.call(self.std_op),
130114
# Refreshing slice data.
131115
mock.call(self.conv_op),
132116
mock.call(self.gamma_op),
133117
mock.call(self.beta_op),
134-
mock.call(self.relu_op),
135-
mock.call(self.mean_op),
136-
mock.call(self.std_op)])
118+
mock.call(self.relu_op)])
137119

138120
# Verify manager does not group.
139121
self.mock_op_reg_manager.group_op_slices.assert_not_called()
140122

141123
# Verify manager processes grouping for Conv2D, ReLU, and batch norm ops.
142124
self.mock_op_reg_manager.process_ops.assert_called_once_with(
143-
[self.relu_op, self.mean_op, self.std_op, self.conv_op, self.gamma_op,
144-
self.beta_op])
125+
[self.relu_op, self.conv_op, self.gamma_op, self.beta_op])
145126
self.mock_op_reg_manager.process_ops_last.assert_called_once_with(
146127
[self.batch_norm_op])
147128

@@ -167,43 +148,35 @@ def testAssignGrouping_AllInputsGrouped(self):
167148
mock.call(self.gamma_op),
168149
mock.call(self.beta_op),
169150
mock.call(self.relu_op),
170-
mock.call(self.mean_op),
171-
mock.call(self.std_op),
172151
# Initial slice data.
173152
mock.call(self.batch_norm_op),
174153
mock.call(self.conv_op),
175154
mock.call(self.gamma_op),
176155
mock.call(self.beta_op),
177156
mock.call(self.relu_op),
178-
mock.call(self.mean_op),
179-
mock.call(self.std_op),
180157
# Reslicing.
181158
mock.call(self.conv_op),
182159
mock.call(self.gamma_op),
183160
mock.call(self.beta_op),
184161
mock.call(self.batch_norm_op),
185162
mock.call(self.relu_op),
186-
mock.call(self.mean_op),
187-
mock.call(self.std_op),
188163
# Refreshing slice data.
189164
mock.call(self.conv_op),
190165
mock.call(self.gamma_op),
191166
mock.call(self.beta_op),
192167
mock.call(self.relu_op),
193-
mock.call(self.mean_op),
194-
mock.call(self.std_op),
195168
# Group batch norm op.
196169
mock.call(self.batch_norm_op)])
197170

198171
# Verify manager groups batch norm with input ops.
199-
self.mock_op_reg_manager.group_op_slices.assert_called_once_with(
200-
[self.batch_norm_op_slice, self.conv_op_slice, self.gamma_op_slice,
201-
self.beta_op_slice])
172+
self.mock_op_reg_manager.group_op_slices.assert_has_calls(
173+
[mock.call([self.batch_norm_op_slice, self.relu_op_slice]),
174+
mock.call([self.batch_norm_op_slice, self.conv_op_slice,
175+
self.gamma_op_slice, self.beta_op_slice])])
202176

203177
# Verify manager processes grouping for mean_op and std_op which do not have
204178
# groups.
205-
self.mock_op_reg_manager.process_ops.assert_called_once_with(
206-
[self.mean_op, self.std_op])
179+
self.mock_op_reg_manager.process_ops.assert_not_called()
207180
self.mock_op_reg_manager.process_ops_last.assert_not_called()
208181

209182
def testAssignGrouping_AllOutputsGrouped(self):
@@ -213,8 +186,6 @@ def testAssignGrouping_AllOutputsGrouped(self):
213186
self.conv_op_slice: self.conv_op_group,
214187
self.relu_op_slice: self.relu_op_group,
215188
self.gamma_op_slice: self.conv_op_group,
216-
self.mean_op_slice: self.relu_op_group,
217-
self.std_op_slice: self.relu_op_group,
218189
}
219190

220191
# Call handler to assign grouping.
@@ -228,31 +199,23 @@ def testAssignGrouping_AllOutputsGrouped(self):
228199
mock.call(self.gamma_op),
229200
mock.call(self.beta_op),
230201
mock.call(self.relu_op),
231-
mock.call(self.mean_op),
232-
mock.call(self.std_op),
233202
# Initial slice data.
234203
mock.call(self.batch_norm_op),
235204
mock.call(self.conv_op),
236205
mock.call(self.gamma_op),
237206
mock.call(self.beta_op),
238207
mock.call(self.relu_op),
239-
mock.call(self.mean_op),
240-
mock.call(self.std_op),
241208
# Reslicing.
242209
mock.call(self.conv_op),
243210
mock.call(self.gamma_op),
244211
mock.call(self.beta_op),
245212
mock.call(self.batch_norm_op),
246213
mock.call(self.relu_op),
247-
mock.call(self.mean_op),
248-
mock.call(self.std_op),
249214
# Refreshing slice data.
250215
mock.call(self.conv_op),
251216
mock.call(self.gamma_op),
252217
mock.call(self.beta_op),
253-
mock.call(self.relu_op),
254-
mock.call(self.mean_op),
255-
mock.call(self.std_op)])
218+
mock.call(self.relu_op)])
256219

257220
# Verify manager does not group.
258221
self.mock_op_reg_manager.group_op_slices.assert_not_called()
@@ -271,8 +234,6 @@ def testAssignGrouping_AllNeighborsGrouped(self):
271234
self.relu_op_slice: self.relu_op_group,
272235
self.gamma_op_slice: self.conv_op_group,
273236
self.beta_op_slice: self.conv_op_group,
274-
self.mean_op_slice: self.relu_op_group,
275-
self.std_op_slice: self.relu_op_group,
276237
}
277238

278239
# Call handler to assign grouping.
@@ -286,38 +247,29 @@ def testAssignGrouping_AllNeighborsGrouped(self):
286247
mock.call(self.gamma_op),
287248
mock.call(self.beta_op),
288249
mock.call(self.relu_op),
289-
mock.call(self.mean_op),
290-
mock.call(self.std_op),
291250
# Initial slice data.
292251
mock.call(self.batch_norm_op),
293252
mock.call(self.conv_op),
294253
mock.call(self.gamma_op),
295254
mock.call(self.beta_op),
296255
mock.call(self.relu_op),
297-
mock.call(self.mean_op),
298-
mock.call(self.std_op),
299256
# Reslicing.
300257
mock.call(self.conv_op),
301258
mock.call(self.gamma_op),
302259
mock.call(self.beta_op),
303260
mock.call(self.batch_norm_op),
304261
mock.call(self.relu_op),
305-
mock.call(self.mean_op),
306-
mock.call(self.std_op),
307262
# Refreshing slice data.
308263
mock.call(self.conv_op),
309264
mock.call(self.gamma_op),
310265
mock.call(self.beta_op),
311266
mock.call(self.relu_op),
312-
mock.call(self.mean_op),
313-
mock.call(self.std_op),
314267
# Group batch norm op.
315268
mock.call(self.batch_norm_op)])
316269

317270
# Verify manager groups batch norm with inputs and outputs.
318271
self.mock_op_reg_manager.group_op_slices.assert_has_calls(
319-
[mock.call([self.batch_norm_op_slice, self.relu_op_slice,
320-
self.mean_op_slice, self.std_op_slice]),
272+
[mock.call([self.batch_norm_op_slice, self.relu_op_slice]),
321273
mock.call([self.batch_norm_op_slice, self.conv_op_slice,
322274
self.gamma_op_slice, self.beta_op_slice])])
323275

@@ -333,8 +285,6 @@ def testAssignGrouping_AllNeighborsGroupedSameGroup(self):
333285
self.relu_op_slice: self.batch_norm_op_group,
334286
self.gamma_op_slice: self.batch_norm_op_group,
335287
self.beta_op_slice: self.batch_norm_op_group,
336-
self.mean_op_slice: self.batch_norm_op_group,
337-
self.std_op_slice: self.batch_norm_op_group,
338288
}
339289

340290
# Call handler to assign grouping.
@@ -348,31 +298,23 @@ def testAssignGrouping_AllNeighborsGroupedSameGroup(self):
348298
mock.call(self.gamma_op),
349299
mock.call(self.beta_op),
350300
mock.call(self.relu_op),
351-
mock.call(self.mean_op),
352-
mock.call(self.std_op),
353301
# Initial slice data.
354302
mock.call(self.batch_norm_op),
355303
mock.call(self.conv_op),
356304
mock.call(self.gamma_op),
357305
mock.call(self.beta_op),
358306
mock.call(self.relu_op),
359-
mock.call(self.mean_op),
360-
mock.call(self.std_op),
361307
# Reslicing.
362308
mock.call(self.conv_op),
363309
mock.call(self.gamma_op),
364310
mock.call(self.beta_op),
365311
mock.call(self.batch_norm_op),
366312
mock.call(self.relu_op),
367-
mock.call(self.mean_op),
368-
mock.call(self.std_op),
369313
# Refreshing slice data.
370314
mock.call(self.conv_op),
371315
mock.call(self.gamma_op),
372316
mock.call(self.beta_op),
373317
mock.call(self.relu_op),
374-
mock.call(self.mean_op),
375-
mock.call(self.std_op),
376318
# Group batch norm op.
377319
mock.call(self.batch_norm_op)])
378320

@@ -400,8 +342,6 @@ def is_passthrough(op):
400342
self.relu_op_slice: self.relu_op_group,
401343
self.gamma_op_slice: self.conv_op_group,
402344
self.beta_op_slice: self.conv_op_group,
403-
self.mean_op_slice: self.relu_op_group,
404-
self.std_op_slice: self.relu_op_group,
405345
}
406346

407347
# Call handler to assign grouping.
@@ -415,37 +355,27 @@ def is_passthrough(op):
415355
mock.call(self.gamma_op),
416356
mock.call(self.beta_op),
417357
mock.call(self.relu_op),
418-
mock.call(self.mean_op),
419-
mock.call(self.std_op),
420358
# Initial slice data.
421359
mock.call(self.batch_norm_op),
422360
mock.call(self.conv_op),
423361
mock.call(self.gamma_op),
424362
mock.call(self.beta_op),
425-
mock.call(self.mean_op),
426-
mock.call(self.std_op),
427363
# Reslicing.
428364
mock.call(self.conv_op),
429365
mock.call(self.gamma_op),
430366
mock.call(self.beta_op),
431367
mock.call(self.batch_norm_op),
432-
mock.call(self.mean_op),
433-
mock.call(self.std_op),
434368
# Refreshing slice data.
435369
mock.call(self.conv_op),
436370
mock.call(self.gamma_op),
437371
mock.call(self.beta_op),
438-
mock.call(self.mean_op),
439-
mock.call(self.std_op),
440372
# Group batch norm op.
441373
mock.call(self.batch_norm_op)])
442374

443375
# Verify manager groups batch norm with inputs and outputs. ReLU is not
444376
# part of the grouping.
445377
self.mock_op_reg_manager.group_op_slices.assert_has_calls(
446-
[mock.call([self.batch_norm_op_slice, self.mean_op_slice,
447-
self.std_op_slice]),
448-
mock.call([self.batch_norm_op_slice, self.conv_op_slice,
378+
[mock.call([self.batch_norm_op_slice, self.conv_op_slice,
449379
self.gamma_op_slice, self.beta_op_slice])])
450380

451381
# Verify manager does not process any additional ops.
@@ -454,12 +384,11 @@ def is_passthrough(op):
454384

455385
def testGetInputOutputOpSlices(self):
456386
input_ops = [self.conv_op, self.gamma_op, self.beta_op]
457-
output_ops = [self.mean_op, self.std_op, self.relu_op]
387+
output_ops = [self.relu_op]
458388

459389
expected_input_op_slices = [
460390
[self.conv_op_slice], [self.gamma_op_slice], [self.beta_op_slice]]
461-
expected_output_op_slices = [
462-
[self.mean_op_slice], [self.std_op_slice], [self.relu_op_slice]]
391+
expected_output_op_slices = [[self.relu_op_slice]]
463392

464393
# Instantiate handler.
465394
handler = grouping_op_handler.GroupingOpHandler()

0 commit comments

Comments
 (0)