Skip to content

Commit 84a0fa0

Browse files
pkchmn-robot
authored andcommitted
Fix batchnorm test.
PiperOrigin-RevId: 299189647
1 parent b02c316 commit 84a0fa0

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

morph_net/framework/op_handler_util_test.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def setUp(self):
6060
self.gamma_op = g.get_operation_by_name('conv1/BatchNorm/gamma/read')
6161
self.beta_op = g.get_operation_by_name('conv1/BatchNorm/beta/read')
6262
self.decay_op = g.get_operation_by_name('conv1/BatchNorm/Const')
63-
self.epsilon_op = g.get_operation_by_name('conv1/BatchNorm/Const_1')
6463
self.mean_op = g.get_operation_by_name(
6564
'conv1/BatchNorm/AssignMovingAvg/sub_1')
6665
self.std_op = g.get_operation_by_name(
@@ -115,7 +114,7 @@ def is_passthrough(op):
115114
self.mock_op_reg_manager.is_passthrough.side_effect = is_passthrough
116115
self.mock_op_reg_manager.ops = [
117116
self.batch_norm_op, self.gamma_op, self.beta_op, self.decay_op,
118-
self.epsilon_op, self.mean_op, self.std_op, self.conv_op, self.relu_op,
117+
self.mean_op, self.std_op, self.conv_op, self.relu_op,
119118
self.relu2_op, self.relu3_op, self.relu4_op, self.unfused_batch_norm_op,
120119
self.concat_op]
121120

@@ -129,7 +128,6 @@ def testGetInputOps(self):
129128
self.mock_op_reg_manager)
130129
self.assertEqual(expected_inputs, input_ops)
131130
self.assertNotIn(self.decay_op, input_ops)
132-
self.assertNotIn(self.epsilon_op, input_ops)
133131

134132
def testGetOutputOps(self):
135133
# For batch norm, the expected outputs are mean, std, and ReLU.
@@ -149,7 +147,6 @@ def testGetOpsWithoutGroups(self):
149147
self.gamma_op: [orm.OpSlice(self.gamma_op, None)],
150148
self.beta_op: [orm.OpSlice(self.beta_op, None)],
151149
self.decay_op: [orm.OpSlice(self.decay_op, None)],
152-
self.epsilon_op: [orm.OpSlice(self.epsilon_op, None)],
153150
}
154151

155152
# Only batch norm and conv ops have groups.
@@ -159,9 +156,9 @@ def testGetOpsWithoutGroups(self):
159156
}
160157

161158
all_ops = [self.batch_norm_op, self.conv_op, self.gamma_op, self.beta_op,
162-
self.decay_op, self.epsilon_op]
159+
self.decay_op]
163160
# Batch norm and conv ops have groups. The other ops do not have groups.
164-
expected_ops = [self.gamma_op, self.beta_op, self.decay_op, self.epsilon_op]
161+
expected_ops = [self.gamma_op, self.beta_op, self.decay_op]
165162
self.assertEqual(
166163
expected_ops,
167164
op_handler_util.get_ops_without_groups(
@@ -171,7 +168,7 @@ def testRemoveNonPassthroughOps(self):
171168
self._passthrough_ops = (self.gamma_op, self.decay_op, self.std_op)
172169

173170
all_ops = [self.batch_norm_op, self.conv_op, self.gamma_op, self.beta_op,
174-
self.decay_op, self.epsilon_op, self.mean_op]
171+
self.decay_op, self.mean_op]
175172
expected_ops = [self.gamma_op, self.decay_op]
176173

177174
self.assertListEqual(

0 commit comments

Comments
 (0)