Skip to content

Commit 34413ca

Browse files
shraman-rcmn-robot
authored andcommitted
Check for tf.Tensor type in tpu_util.write_to_variable.
PiperOrigin-RevId: 321891941
1 parent d41667f commit 34413ca

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

morph_net/framework/tpu_util.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ def maybe_convert_to_variable(tensor):
9797

9898
def write_to_variable(tensor, fail_if_exists=True):
9999
"""Saves a tensor for later retrieval on CPU."""
100+
if not isinstance(tensor, tf.Tensor):
101+
raise ValueError('Expected tf.Tensor but got {}'.format(type(tensor)))
102+
100103
# Only relevant for debugging.
101104
debug_name = 'tpu_util__' + tensor.name.split(':')[0]
102105

morph_net/framework/tpu_util_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,5 +104,11 @@ def test_write_to_variable(self):
104104
tpu_util.write_to_variable(bar)
105105
self.assertLen(set(tpu_util.var_store.values()), 2)
106106

107+
def test_write_to_variable_check_inputs(self):
108+
variable = tf.get_variable('x', shape=[1], dtype=tf.float32)
109+
with self.assertRaises(ValueError):
110+
tpu_util.write_to_variable(variable)
111+
112+
107113
if __name__ == '__main__':
108114
tf.test.main()

0 commit comments

Comments
 (0)