Skip to content

Commit 04fa416

Browse files
pkchmn-robot
authored andcommitted
Internal change
PiperOrigin-RevId: 299931864
1 parent 34c5ad1 commit 04fa416

File tree

2 files changed

+13
-70
lines changed

2 files changed

+13
-70
lines changed

morph_net/framework/batch_norm_source_op_handler_test.py

Lines changed: 5 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -65,18 +65,6 @@ def setUp(self):
6565
self.beta_op_group = orm.OpGroup(
6666
self.beta_op_slice, omit_source_op_slices=[self.beta_op_slice])
6767

68-
self.mean_op = g.get_operation_by_name(
69-
'conv1/BatchNorm/AssignMovingAvg/sub_1')
70-
self.mean_op_slice = orm.OpSlice(self.mean_op, orm.Slice(0, 5))
71-
self.mean_op_group = orm.OpGroup(
72-
self.mean_op_slice, omit_source_op_slices=[self.mean_op_slice])
73-
74-
self.std_op = g.get_operation_by_name(
75-
'conv1/BatchNorm/AssignMovingAvg_1/sub_1')
76-
self.std_op_slice = orm.OpSlice(self.std_op, orm.Slice(0, 5))
77-
self.std_op_group = orm.OpGroup(
78-
self.std_op_slice, omit_source_op_slices=[self.std_op_slice])
79-
8068
# Create custom mapping of OpSlice and OpGroup in manager.
8169
self.mock_op_reg_manager = mock.create_autospec(orm.OpRegularizerManager)
8270

@@ -91,7 +79,7 @@ def get_op_group(op_slice):
9179
self.mock_op_reg_manager.is_source_op.return_value = False
9280
self.mock_op_reg_manager.ops = [
9381
self.batch_norm_op, self.conv_op, self.relu_op, self.gamma_op,
94-
self.beta_op, self.mean_op, self.std_op]
82+
self.beta_op]
9583

9684
def testAssignGrouping_NoNeighborGroups(self):
9785
self.op_slice_dict = {
@@ -100,8 +88,6 @@ def testAssignGrouping_NoNeighborGroups(self):
10088
self.relu_op: [self.relu_op_slice],
10189
self.gamma_op: [self.gamma_op_slice],
10290
self.beta_op: [self.beta_op_slice],
103-
self.mean_op: [self.mean_op_slice],
104-
self.std_op: [self.std_op_slice],
10591
}
10692

10793
# No neighbor ops have groups.
@@ -128,7 +114,7 @@ def testAssignGrouping_NoNeighborGroups(self):
128114

129115
# Verify manager processes grouping for input and output ops.
130116
self.mock_op_reg_manager.process_ops.assert_called_once_with(
131-
[self.relu_op, self.mean_op, self.std_op, self.conv_op, self.gamma_op,
117+
[self.relu_op, self.conv_op, self.gamma_op,
132118
self.beta_op])
133119
self.mock_op_reg_manager.process_ops_last.assert_called_once_with(
134120
[self.batch_norm_op])
@@ -140,8 +126,6 @@ def testAssignGrouping_ProcessNeighborGroups(self):
140126
self.relu_op: [self.relu_op_slice],
141127
self.gamma_op: [self.gamma_op_slice],
142128
self.beta_op: [self.beta_op_slice],
143-
self.mean_op: [self.mean_op_slice],
144-
self.std_op: [self.std_op_slice],
145129
}
146130

147131
# All ops have groups.
@@ -151,8 +135,6 @@ def testAssignGrouping_ProcessNeighborGroups(self):
151135
self.relu_op_slice: self.relu_op_group,
152136
self.gamma_op_slice: self.gamma_op_group,
153137
self.beta_op_slice: self.beta_op_group,
154-
self.mean_op_slice: self.mean_op_group,
155-
self.std_op_slice: self.std_op_group,
156138
}
157139

158140
# Call handler to assign grouping.
@@ -167,8 +149,7 @@ def testAssignGrouping_ProcessNeighborGroups(self):
167149

168150
# Verify manager groups batch norm with outputs and inputs.
169151
self.mock_op_reg_manager.group_op_slices.assert_has_calls(
170-
[mock.call([self.batch_norm_op_slice, self.relu_op_slice,
171-
self.mean_op_slice, self.std_op_slice]),
152+
[mock.call([self.batch_norm_op_slice, self.relu_op_slice]),
172153
mock.call([self.batch_norm_op_slice, self.conv_op_slice,
173154
self.gamma_op_slice, self.beta_op_slice])])
174155

@@ -208,28 +189,12 @@ def testAssignGrouping_ProcessNeighborGroupsWithSlices(self):
208189
beta_op_group2 = orm.OpGroup(
209190
beta_op_slice_2_3, omit_source_op_slices=[beta_op_slice_2_3])
210191

211-
mean_op_slice_0_2 = orm.OpSlice(self.mean_op, orm.Slice(0, 2))
212-
mean_op_slice_2_3 = orm.OpSlice(self.mean_op, orm.Slice(2, 1))
213-
mean_op_group1 = orm.OpGroup(
214-
mean_op_slice_0_2, omit_source_op_slices=[mean_op_slice_0_2])
215-
mean_op_group2 = orm.OpGroup(
216-
mean_op_slice_2_3, omit_source_op_slices=[mean_op_slice_2_3])
217-
218-
std_op_slice_0_2 = orm.OpSlice(self.std_op, orm.Slice(0, 2))
219-
std_op_slice_2_3 = orm.OpSlice(self.std_op, orm.Slice(2, 1))
220-
std_op_group1 = orm.OpGroup(
221-
std_op_slice_0_2, omit_source_op_slices=[std_op_slice_0_2])
222-
std_op_group2 = orm.OpGroup(
223-
std_op_slice_2_3, omit_source_op_slices=[std_op_slice_2_3])
224-
225192
self.op_slice_dict = {
226193
self.batch_norm_op: [batch_norm_op_slice_0_2, batch_norm_op_slice_2_3],
227194
self.conv_op: [conv_op_slice_0_2, conv_op_slice_2_3],
228195
self.relu_op: [relu_op_slice_0_2, relu_op_slice_2_3],
229196
self.gamma_op: [gamma_op_slice_0_2, gamma_op_slice_2_3],
230197
self.beta_op: [beta_op_slice_0_2, beta_op_slice_2_3],
231-
self.mean_op: [mean_op_slice_0_2, mean_op_slice_2_3],
232-
self.std_op: [std_op_slice_0_2, std_op_slice_2_3],
233198
}
234199

235200
# All OpSlice have groups.
@@ -244,10 +209,6 @@ def testAssignGrouping_ProcessNeighborGroupsWithSlices(self):
244209
gamma_op_slice_2_3: gamma_op_group2,
245210
beta_op_slice_0_2: beta_op_group1,
246211
beta_op_slice_2_3: beta_op_group2,
247-
mean_op_slice_0_2: mean_op_group1,
248-
mean_op_slice_2_3: mean_op_group2,
249-
std_op_slice_0_2: std_op_group1,
250-
std_op_slice_2_3: std_op_group2,
251212
}
252213

253214
# Call handler to assign grouping.
@@ -262,12 +223,10 @@ def testAssignGrouping_ProcessNeighborGroupsWithSlices(self):
262223

263224
# Verify manager groups batch norm with outputs and inputs by slice.
264225
self.mock_op_reg_manager.group_op_slices.assert_has_calls(
265-
[mock.call([batch_norm_op_slice_0_2, relu_op_slice_0_2,
266-
mean_op_slice_0_2, std_op_slice_0_2]),
226+
[mock.call([batch_norm_op_slice_0_2, relu_op_slice_0_2]),
267227
mock.call([batch_norm_op_slice_0_2, conv_op_slice_0_2,
268228
gamma_op_slice_0_2, beta_op_slice_0_2]),
269-
mock.call([batch_norm_op_slice_2_3, relu_op_slice_2_3,
270-
mean_op_slice_2_3, std_op_slice_2_3]),
229+
mock.call([batch_norm_op_slice_2_3, relu_op_slice_2_3]),
271230
mock.call([batch_norm_op_slice_2_3, conv_op_slice_2_3,
272231
gamma_op_slice_2_3, beta_op_slice_2_3])])
273232

@@ -282,8 +241,6 @@ def testAssignGrouping_NeighborsHaveSameGroup(self):
282241
self.relu_op: [self.batch_norm_op_slice],
283242
self.gamma_op: [self.batch_norm_op_slice],
284243
self.beta_op: [self.batch_norm_op_slice],
285-
self.mean_op: [self.batch_norm_op_slice],
286-
self.std_op: [self.batch_norm_op_slice],
287244
}
288245

289246
# All ops have the same group.
@@ -320,8 +277,6 @@ def testAssignGrouping_NeighborsHaveSameGroup_ReprocessSources(self):
320277
self.relu_op: [self.relu_op_slice],
321278
self.gamma_op: [self.gamma_op_slice],
322279
self.beta_op: [self.beta_op_slice],
323-
self.mean_op: [self.mean_op_slice],
324-
self.std_op: [self.std_op_slice],
325280
}
326281

327282
self.op_group_dict = {
@@ -353,9 +308,6 @@ def is_source_op(op):
353308
[self.batch_norm_op_slice, self.conv_op_slice, self.gamma_op_slice,
354309
self.beta_op_slice])
355310

356-
# Verify manager adds ungrouped output ops to queue.
357-
self.mock_op_reg_manager.process_ops.assert_called_once_with(
358-
[self.mean_op, self.std_op])
359311
self.mock_op_reg_manager.process_ops_last.assert_not_called()
360312

361313
def testCreateRegularizer(self):

morph_net/framework/op_handler_util_test.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,6 @@ def setUp(self):
5959

6060
self.gamma_op = g.get_operation_by_name('conv1/BatchNorm/gamma/read')
6161
self.beta_op = g.get_operation_by_name('conv1/BatchNorm/beta/read')
62-
self.decay_op = g.get_operation_by_name('conv1/BatchNorm/Const')
63-
self.mean_op = g.get_operation_by_name(
64-
'conv1/BatchNorm/AssignMovingAvg/sub_1')
65-
self.std_op = g.get_operation_by_name(
66-
'conv1/BatchNorm/AssignMovingAvg_1/sub_1')
6762

6863
self.relu_op = g.get_operation_by_name('conv1/Relu')
6964
self.relu_op_slice = orm.OpSlice(self.relu_op, None)
@@ -113,8 +108,8 @@ def is_passthrough(op):
113108
self.mock_op_reg_manager.get_op_group.side_effect = get_op_group
114109
self.mock_op_reg_manager.is_passthrough.side_effect = is_passthrough
115110
self.mock_op_reg_manager.ops = [
116-
self.batch_norm_op, self.gamma_op, self.beta_op, self.decay_op,
117-
self.mean_op, self.std_op, self.conv_op, self.relu_op,
111+
self.batch_norm_op, self.gamma_op, self.beta_op,
112+
self.conv_op, self.relu_op,
118113
self.relu2_op, self.relu3_op, self.relu4_op, self.unfused_batch_norm_op,
119114
self.concat_op]
120115

@@ -127,11 +122,10 @@ def testGetInputOps(self):
127122
input_ops = op_handler_util.get_input_ops(self.batch_norm_op,
128123
self.mock_op_reg_manager)
129124
self.assertEqual(expected_inputs, input_ops)
130-
self.assertNotIn(self.decay_op, input_ops)
131125

132126
def testGetOutputOps(self):
133127
# For batch norm, the expected outputs are mean, std, and ReLU.
134-
expected_outputs = [self.relu_op, self.mean_op, self.std_op]
128+
expected_outputs = [self.relu_op]
135129

136130
# Check for expected output ops.
137131
self.assertEqual(
@@ -146,7 +140,6 @@ def testGetOpsWithoutGroups(self):
146140
self.conv_op: [self.conv_op_slice],
147141
self.gamma_op: [orm.OpSlice(self.gamma_op, None)],
148142
self.beta_op: [orm.OpSlice(self.beta_op, None)],
149-
self.decay_op: [orm.OpSlice(self.decay_op, None)],
150143
}
151144

152145
# Only batch norm and conv ops have groups.
@@ -155,21 +148,19 @@ def testGetOpsWithoutGroups(self):
155148
self.conv_op_slice: self.conv_op_group
156149
}
157150

158-
all_ops = [self.batch_norm_op, self.conv_op, self.gamma_op, self.beta_op,
159-
self.decay_op]
151+
all_ops = [self.batch_norm_op, self.conv_op, self.gamma_op, self.beta_op]
160152
# Batch norm and conv ops have groups. The other ops do not have groups.
161-
expected_ops = [self.gamma_op, self.beta_op, self.decay_op]
153+
expected_ops = [self.gamma_op, self.beta_op]
162154
self.assertEqual(
163155
expected_ops,
164156
op_handler_util.get_ops_without_groups(
165157
all_ops, self.mock_op_reg_manager))
166158

167159
def testRemoveNonPassthroughOps(self):
168-
self._passthrough_ops = (self.gamma_op, self.decay_op, self.std_op)
160+
self._passthrough_ops = (self.gamma_op,)
169161

170-
all_ops = [self.batch_norm_op, self.conv_op, self.gamma_op, self.beta_op,
171-
self.decay_op, self.mean_op]
172-
expected_ops = [self.gamma_op, self.decay_op]
162+
all_ops = [self.batch_norm_op, self.conv_op, self.gamma_op, self.beta_op]
163+
expected_ops = [self.gamma_op]
173164

174165
self.assertListEqual(
175166
expected_ops,

0 commit comments

Comments
 (0)