File tree Expand file tree Collapse file tree 2 files changed +9
-0
lines changed Expand file tree Collapse file tree 2 files changed +9
-0
lines changed Original file line number Diff line number Diff line change @@ -97,6 +97,9 @@ def maybe_convert_to_variable(tensor):
9797
9898def write_to_variable (tensor , fail_if_exists = True ):
9999 """Saves a tensor for later retrieval on CPU."""
100+ if not isinstance (tensor , tf .Tensor ):
101+ raise ValueError ('Expected tf.Tensor but got {}' .format (type (tensor )))
102+
100103 # Only relevant for debugging.
101104 debug_name = 'tpu_util__' + tensor .name .split (':' )[0 ]
102105
Original file line number Diff line number Diff line change @@ -104,5 +104,11 @@ def test_write_to_variable(self):
104104 tpu_util .write_to_variable (bar )
105105 self .assertLen (set (tpu_util .var_store .values ()), 2 )
106106
107+ def test_write_to_variable_check_inputs (self ):
108+ variable = tf .get_variable ('x' , shape = [1 ], dtype = tf .float32 )
109+ with self .assertRaises (ValueError ):
110+ tpu_util .write_to_variable (variable )
111+
112+
107113if __name__ == '__main__' :
108114 tf .test .main ()
You can’t perform that action at this time.
0 commit comments