@@ -60,7 +60,6 @@ def setUp(self):
6060 self .gamma_op = g .get_operation_by_name ('conv1/BatchNorm/gamma/read' )
6161 self .beta_op = g .get_operation_by_name ('conv1/BatchNorm/beta/read' )
6262 self .decay_op = g .get_operation_by_name ('conv1/BatchNorm/Const' )
63- self .epsilon_op = g .get_operation_by_name ('conv1/BatchNorm/Const_1' )
6463 self .mean_op = g .get_operation_by_name (
6564 'conv1/BatchNorm/AssignMovingAvg/sub_1' )
6665 self .std_op = g .get_operation_by_name (
@@ -115,7 +114,7 @@ def is_passthrough(op):
115114 self .mock_op_reg_manager .is_passthrough .side_effect = is_passthrough
116115 self .mock_op_reg_manager .ops = [
117116 self .batch_norm_op , self .gamma_op , self .beta_op , self .decay_op ,
118- self .epsilon_op , self . mean_op , self .std_op , self .conv_op , self .relu_op ,
117+ self .mean_op , self .std_op , self .conv_op , self .relu_op ,
119118 self .relu2_op , self .relu3_op , self .relu4_op , self .unfused_batch_norm_op ,
120119 self .concat_op ]
121120
@@ -129,7 +128,6 @@ def testGetInputOps(self):
129128 self .mock_op_reg_manager )
130129 self .assertEqual (expected_inputs , input_ops )
131130 self .assertNotIn (self .decay_op , input_ops )
132- self .assertNotIn (self .epsilon_op , input_ops )
133131
134132 def testGetOutputOps (self ):
135133 # For batch norm, the expected outputs are mean, std, and ReLU.
@@ -149,7 +147,6 @@ def testGetOpsWithoutGroups(self):
149147 self .gamma_op : [orm .OpSlice (self .gamma_op , None )],
150148 self .beta_op : [orm .OpSlice (self .beta_op , None )],
151149 self .decay_op : [orm .OpSlice (self .decay_op , None )],
152- self .epsilon_op : [orm .OpSlice (self .epsilon_op , None )],
153150 }
154151
155152 # Only batch norm and conv ops have groups.
@@ -159,9 +156,9 @@ def testGetOpsWithoutGroups(self):
159156 }
160157
161158 all_ops = [self .batch_norm_op , self .conv_op , self .gamma_op , self .beta_op ,
162- self .decay_op , self . epsilon_op ]
159+ self .decay_op ]
163160 # Batch norm and conv ops have groups. The other ops do not have groups.
164- expected_ops = [self .gamma_op , self .beta_op , self .decay_op , self . epsilon_op ]
161+ expected_ops = [self .gamma_op , self .beta_op , self .decay_op ]
165162 self .assertEqual (
166163 expected_ops ,
167164 op_handler_util .get_ops_without_groups (
@@ -171,7 +168,7 @@ def testRemoveNonPassthroughOps(self):
171168 self ._passthrough_ops = (self .gamma_op , self .decay_op , self .std_op )
172169
173170 all_ops = [self .batch_norm_op , self .conv_op , self .gamma_op , self .beta_op ,
174- self .decay_op , self .epsilon_op , self . mean_op ]
171+ self .decay_op , self .mean_op ]
175172 expected_ops = [self .gamma_op , self .decay_op ]
176173
177174 self .assertListEqual (
0 commit comments