@@ -2046,17 +2046,24 @@ def test_mixed_ndim_error(self):
20462046 def test_static_shape_inference (self ):
20472047 a = at .tensor (dtype = "int8" , shape = (2 , 3 ))
20482048 b = at .tensor (dtype = "int8" , shape = (2 , 5 ))
2049- assert at .join (1 , a , b ).type .shape == (2 , 8 )
2050- assert at .join (- 1 , a , b ).type .shape == (2 , 8 )
2049+
2050+ res = at .join (1 , a , b ).type .shape
2051+ assert res == (2 , 8 )
2052+ assert all (isinstance (s , int ) for s in res )
2053+
2054+ res = at .join (- 1 , a , b ).type .shape
2055+ assert res == (2 , 8 )
2056+ assert all (isinstance (s , int ) for s in res )
20512057
20522058 # Check early informative errors from static shape info
20532059 with pytest .raises (ValueError , match = "must match exactly" ):
20542060 at .join (0 , at .ones ((2 , 3 )), at .ones ((2 , 5 )))
20552061
20562062 # Check partial inference
20572063 d = at .tensor (dtype = "int8" , shape = (2 , None ))
2058- assert at .join (1 , a , b , d ).type .shape == (2 , None )
2059- return
2064+ res = at .join (1 , a , b , d ).type .shape
2065+ assert res == (2 , None )
2066+ assert isinstance (res [0 ], int )
20602067
20612068 def test_split_0elem (self ):
20622069 rng = np .random .default_rng (seed = utt .fetch_seed ())
0 commit comments