@@ -1800,16 +1800,61 @@ def func():
18001800 # since results are random, compare the shapes only
18011801 self ._run_test_case (func , [_OUTPUT ], {}, check_value = False , check_shape = True )
18021802
1803- @unittest .skip ("TF RandomUniformInt is not supported" )
18041803 def test_randomuniform_int (self ):
18051804 def func ():
1806- shape = tf .constant ([2 , 3 ], name = "shape" )
1807- x_ = random_uniform (shape , name = "rand" , dtype = tf .int32 , maxval = 10 )
1805+ shape = tf .constant ([100 , 3 ], name = "shape" )
1806+ x_ = random_uniform (shape , name = "rand" , dtype = tf .int32 , minval = 2 , maxval = 10 )
18081807 x_ = tf .identity (x_ , name = "output1" )
18091808 x_ = tf .identity (x_ , name = "output2" )
18101809 return tf .identity (x_ , name = _TFOUTPUT )
18111810 # since results are random, compare the shapes only
1812- self ._run_test_case (func , [_OUTPUT ], {}, check_value = False , check_shape = True )
1811+ g = self ._run_test_case (func , [_OUTPUT ], {}, check_value = False , check_shape = True )
1812+ results = self .run_backend (g , [_OUTPUT ], {})
1813+ numbers = set (results [0 ].flatten ())
1814+ self .assertEqual (sorted (numbers ), list (range (2 , 10 )))
1815+
1816+ def test_randomuniform_int_nonconst_max (self ):
1817+ m_val = np .array (8 , dtype = np .int32 )
1818+ def func (m ):
1819+ shape = tf .constant ([100 , 3 ], name = "shape" )
1820+ x_ = random_uniform (shape , name = "rand" , dtype = tf .int32 , minval = 0 , maxval = m )
1821+ x_ = tf .identity (x_ , name = "output1" )
1822+ x_ = tf .identity (x_ , name = "output2" )
1823+ return tf .identity (x_ , name = _TFOUTPUT )
1824+ g = self ._run_test_case (func , [_OUTPUT ], {_INPUT : m_val }, check_value = False , check_shape = True )
1825+ results = self .run_backend (g , [_OUTPUT ], {_INPUT : m_val })
1826+ numbers = set (results [0 ].flatten ())
1827+ self .assertEqual (sorted (numbers ), list (range (8 )))
1828+
1829+ def test_randomuniform_int_nonconst_min_max (self ):
1830+ n_val = np .array (2 , dtype = np .int32 )
1831+ m_val = np .array (10 , dtype = np .int32 )
1832+ def func (n , m ):
1833+ shape = tf .constant ([100 , 3 ], name = "shape" )
1834+ x_ = random_uniform (shape , name = "rand" , dtype = tf .int32 , minval = n , maxval = m )
1835+ x_ = tf .identity (x_ , name = "output1" )
1836+ x_ = tf .identity (x_ , name = "output2" )
1837+ return tf .identity (x_ , name = _TFOUTPUT )
1838+ g = self ._run_test_case (func , [_OUTPUT ], {_INPUT : n_val , _INPUT1 : m_val }, check_value = False , check_shape = True )
1839+ results = self .run_backend (g , [_OUTPUT ], {_INPUT : n_val , _INPUT1 : m_val })
1840+ numbers = set (results [0 ].flatten ())
1841+ self .assertEqual (sorted (numbers ), list (range (2 , 10 )))
1842+
1843+ @check_opset_min_version (9 , "RandomUniformLike" )
1844+ def test_randomuniform_int_nonconst_min_max_shape (self ):
1845+ n_val = np .array (2 , dtype = np .int32 )
1846+ m_val = np .array (10 , dtype = np .int32 )
1847+ s_val = np .array ([100 , 3 ], dtype = np .int64 )
1848+ def func (n , m , s ):
1849+ x_ = random_uniform (s , name = "rand" , dtype = tf .int32 , minval = n , maxval = m )
1850+ x_ = tf .identity (x_ , name = "output1" )
1851+ x_ = tf .identity (x_ , name = "output2" )
1852+ return tf .identity (x_ , name = _TFOUTPUT )
1853+ g = self ._run_test_case (func , [_OUTPUT ], {_INPUT : n_val , _INPUT1 : m_val , _INPUT2 : s_val },
1854+ check_value = False , check_shape = True )
1855+ results = self .run_backend (g , [_OUTPUT ], {_INPUT : n_val , _INPUT1 : m_val , _INPUT2 : s_val })
1856+ numbers = set (results [0 ].flatten ())
1857+ self .assertEqual (sorted (numbers ), list (range (2 , 10 )))
18131858
18141859 @skip_caffe2_backend ()
18151860 @check_opset_after_tf_version ("2.2" , 9 , "RandomUniform" )
0 commit comments